DauroCamilo committed
Commit
59e2f12
verified
1 Parent(s): c00e046

Update main.py

Files changed (1)
  1. main.py +24 -20
main.py CHANGED
@@ -1,49 +1,53 @@
- import os
- import torch
- os.environ["HF_HOME"] = "/tmp/hf"
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
-
  from fastapi import FastAPI
  from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModel, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  from fastapi.responses import StreamingResponse
+ import torch
  import threading

  app = FastAPI()

- model_id = "GEB-AGI/geb-1.3b"
- model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ # Load the SmolLM2 model and tokenizer from the Hugging Face Hub
+ model_id = "HuggingFaceTB/SmolLM2-135M"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)

+ # Request body model
  class ChatRequest(BaseModel):
      message: str

  @app.post("/chat/stream")
  async def chat_stream(request: ChatRequest):
-     prompt = f"Responde en español de forma clara y breve como un asistente IA.\nUsuario: {request.message}\nIA:"
+     prompt = f"""Responde en español de forma clara y breve como un asistente IA.
+ Usuario: {request.message}
+ IA:"""

-     # 1. Tokenize into plain tokens (no padding, no encode)
-     tokens = tokenizer.tokenize(prompt)
-     token_ids = tokenizer.convert_tokens_to_ids(tokens)
-     # 2. Manually add the special tokens
-     input_ids = tokenizer.build_inputs_with_special_tokens(token_ids)
-     input_ids = torch.tensor([input_ids])
+     # Tokenize the input
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]

-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     # Streamer that yields generated tokens as they are produced
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
+
+     # Start generation in a separate thread
      generation_kwargs = dict(
-         input_ids=input_ids,
-         max_new_tokens=48,
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         max_new_tokens=48,  # Adjust this value for more/fewer tokens
          temperature=0.7,
          top_p=0.9,
          do_sample=True,
          streamer=streamer,
-         pad_token_id=getattr(tokenizer, "eos_token_id", None),
+         pad_token_id=tokenizer.eos_token_id
      )
      thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()

+     # StreamingResponse expects a generator that yields text
      async def event_generator():
          for new_text in streamer:
              yield new_text

      return StreamingResponse(event_generator(), media_type="text/plain")
+
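
For quick verification of the changed endpoint, a minimal client sketch follows. It assumes the app is served locally with `uvicorn main:app --port 8000` (host, port, and the use of the `requests` library are assumptions, not part of this commit) and reads the plain-text stream chunk by chunk.

import requests

# Hypothetical local test: assumes the server was started with
#   uvicorn main:app --port 8000
resp = requests.post(
    "http://localhost:8000/chat/stream",
    json={"message": "Hola, ¿qué sabes hacer?"},
    stream=True,  # keep the connection open and read the body incrementally
)
resp.raise_for_status()
resp.encoding = "utf-8"  # decode the text/plain stream as UTF-8

# Print each streamed text chunk as soon as it arrives
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)

Because the server writes whatever TextIteratorStreamer yields straight into the response body, the client sees partial text as soon as each token batch is decoded rather than waiting for the full 48-token completion.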