DauroCamilo committed
Commit 8a846d6 (verified)
1 Parent(s): 2605bf3

TinyLlama/TinyLlama-1.1B-Chat-v1.0

Files changed (1)
  1. main.py +16 -22
main.py CHANGED
@@ -1,59 +1,53 @@
 import os
-
 os.environ["HF_HOME"] = "/tmp/hf"
-os.environ["HF_DATASETS_CACHE"] = "/tmp/hf/datasets"
-os.environ["HF_METRICS_CACHE"] = "/tmp/hf/metrics"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
+
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from fastapi.responses import StreamingResponse
-import torch
 import threading
 
 app = FastAPI()
 
-# Load the Phi-2 model and tokenizer (uses the Hugging Face Hub model)
-model_id = "microsoft/phi-2"
+model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 
-# Input model
 class ChatRequest(BaseModel):
     message: str
 
 @app.post("/chat/stream")
 async def chat_stream(request: ChatRequest):
-    prompt = f"""Responde en español de forma clara y breve como un asistente IA.
-Usuario: {request.message}
-IA:"""
-
-    # Tokenize the input
+    # Use the chat template, with a clear instruction in Spanish
+    messages = [
+        {
+            "role": "system",
+            "content": "Eres un asistente IA amigable y responde siempre en español, de forma breve y clara.",
+        },
+        {"role": "user", "content": request.message},
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
 
-    # Streamer to receive generated tokens incrementally
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-    # Start generation in a separate thread
     generation_kwargs = dict(
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
-        max_new_tokens=48,  # Adjust this value for more/fewer tokens
+        max_new_tokens=48,
         temperature=0.7,
         top_p=0.9,
         do_sample=True,
         streamer=streamer,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.eos_token_id,
     )
     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    # StreamingResponse expects a generator that yields text
     async def event_generator():
         for new_text in streamer:
             yield new_text
 
     return StreamingResponse(event_generator(), media_type="text/plain")
-
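
The new endpoint streams plain text, so any HTTP client that reads the response body incrementally can print tokens as the server generates them. A minimal client sketch using requests; the host, port, and example message are assumptions, not part of the commit:

import requests

def stream_chat(message: str, base_url: str = "http://localhost:8000") -> None:
    # POST the JSON body that ChatRequest expects; stream=True keeps the
    # connection open so chunks can be read as they are generated.
    with requests.post(f"{base_url}/chat/stream", json={"message": message}, stream=True) as response:
        response.raise_for_status()
        # requests assumes ISO-8859-1 for text/plain without a charset;
        # the generated text is UTF-8, so set the encoding explicitly.
        response.encoding = "utf-8"
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)

if __name__ == "__main__":
    stream_chat("¿Qué es FastAPI?")  # hypothetical example message

Serve the app first, e.g. with uvicorn main:app --port 8000 (assumed invocation), then run the client to watch the tokens arrive incrementally.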