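"""FastAPI service that streams chat completions from GEB-AGI/geb-1.3b as plain text."""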
import os

# Point the Hugging Face caches at a writable directory (on Spaces, /tmp is writable).
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"  # legacy variable, still read by older transformers releases

import threading

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()

# geb-1.3b ships custom modeling code, so it is loaded via AutoModel with
# trust_remote_code=True; the generic causal-LM loader is kept below for reference.
# model_id = "GEB-AGI/geb-1.3b"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
model = AutoModel.from_pretrained("GEB-AGI/geb-1.3b", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("GEB-AGI/geb-1.3b", trust_remote_code=True)

class ChatRequest(BaseModel):
    message: str

@app.post("/chat")  # route decorator was missing in the original; the "/chat" path is an assumption
async def chat_stream(request: ChatRequest):
    # Use the chat template, with a clear system instruction to always answer in Spanish.
    messages = [
        {
            "role": "system",
            "content": "Eres un asistente IA amigable y responde siempre en español, de forma breve y clara.",
        },
        {"role": "user", "content": request.message},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    # TextIteratorStreamer yields decoded text chunks as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=48,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread so tokens can be streamed as they are produced.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    async def event_generator():
        # Iterating the streamer is blocking; acceptable for a demo, though it
        # occupies the event loop while tokens arrive.
        for new_text in streamer:
            yield new_text

    return StreamingResponse(event_generator(), media_type="text/plain")
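

# A minimal local-run sketch, not part of the original file: on Spaces the server is
# normally launched by the Space's entrypoint, and both the /chat path above and
# port 7860 (the Spaces default) are assumptions here.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the streaming endpoint (-N disables curl's buffering):
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hola, ¿quién eres?"}'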