DauroCamilo commited on
Commit
617198c
·
verified ·
1 Parent(s): 39e93b7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +96 -45
main.py CHANGED
@@ -1,53 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
- from fastapi.responses import StreamingResponse
5
- import torch
6
- import threading
7
 
8
  app = FastAPI()
9
 
10
- # Cargar modelo y tokenizer de Phi-2 (usa el modelo de Hugging Face Hub)
11
- model_id = "HuggingFaceTB/SmolLM2-135M"
12
- tokenizer = AutoTokenizer.from_pretrained(model_id)
13
- model = AutoModelForCausalLM.from_pretrained(model_id)
14
 
15
  # Modelo de entrada
16
- class ChatRequest(BaseModel):
17
- message: str
18
-
19
- @app.post("/chat/stream")
20
- async def chat_stream(request: ChatRequest):
21
- prompt = f"""Responde en español de forma clara y breve como un asistente IA.
22
- Usuario: {request.message}
23
- IA:"""
24
-
25
- # Tokenizar entrada
26
- inputs = tokenizer(prompt, return_tensors="pt")
27
- input_ids = inputs["input_ids"]
28
- attention_mask = inputs["attention_mask"]
29
-
30
- # Streamer para obtener tokens generados poco a poco
31
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
32
-
33
- # Iniciar la generación en un hilo aparte
34
- generation_kwargs = dict(
35
- input_ids=input_ids,
36
- attention_mask=attention_mask,
37
- max_new_tokens=48, # Puedes ajustar este valor para más/menos tokens
38
- temperature=0.7,
39
- top_p=0.9,
40
- do_sample=True,
41
- streamer=streamer,
42
- pad_token_id=tokenizer.eos_token_id
43
- )
44
- thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
45
- thread.start()
46
-
47
- # StreamingResponse espera un generador que devuelva texto
48
- async def event_generator():
49
- for new_text in streamer:
50
- yield new_text
51
-
52
- return StreamingResponse(event_generator(), media_type="text/plain")
53
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import FastAPI
2
+ # from pydantic import BaseModel
3
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
4
+ # import torch
5
+
6
+ # app = FastAPI()
7
+
8
+ # model_id = "HuggingFaceTB/SmolLM2-360M"
9
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
10
+ # model = AutoModelForCausalLM.from_pretrained(model_id)
11
+
12
+ # class ChatRequest(BaseModel):
13
+ # context: str # Historial de la conversación, como texto
14
+
15
+ # class NewlineStoppingCriteria(StoppingCriteria):
16
+ # def __init__(self, prompt_len, tokenizer):
17
+ # super().__init__()
18
+ # self.prompt_len = prompt_len
19
+ # self.tokenizer = tokenizer
20
+
21
+ # def __call__(self, input_ids, scores, **kwargs):
22
+ # # Chequea si después del prompt hay un token de salto de línea
23
+ # gen_tokens = input_ids[0][self.prompt_len:]
24
+ # gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
25
+ # return '\n' in gen_text
26
+
27
+ # @app.post("/chat/demo_base")
28
+ # async def chat_demo_base(request: ChatRequest):
29
+ # prompt = (
30
+ # "Conversacion 1:\n"
31
+ # "-Dauro: -Hola Juanjo.\n"
32
+ # "-Juanjo: -¿Qué tal?\n"
33
+ # "-Dauro: -Bien, ¿y tú?\n\n"
34
+ # "Conversacion 2:\n"
35
+ # "-Juanjo: -Oye Asistente, ¿puedes mirar esto?\n"
36
+ # "-Asistente: -Por supuesto, dime.\n\n"
37
+ # f"Conversacion 3:\n{request.context}\n"
38
+ # )
39
+
40
+ # inputs = tokenizer(prompt, return_tensors="pt")
41
+ # input_ids = inputs["input_ids"]
42
+ # attention_mask = inputs["attention_mask"]
43
+
44
+ # stopping_criteria = StoppingCriteriaList([
45
+ # NewlineStoppingCriteria(prompt_len=input_ids.shape[1], tokenizer=tokenizer)
46
+ # ])
47
+
48
+ # output = model.generate(
49
+ # input_ids=input_ids,
50
+ # attention_mask=attention_mask,
51
+ # max_new_tokens=15,
52
+ # temperature=0.9,
53
+ # top_p=0.8,
54
+ # do_sample=True,
55
+ # pad_token_id=tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else None,
56
+ # stopping_criteria=stopping_criteria,
57
+ # )
58
+
59
+ # generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
60
+ # # Solo el fragmento después del prompt
61
+ # continuation = generated_text[len(prompt):].split('\n')[0]
62
+
63
+ # return {"generated_text": generated_text}
64
  from fastapi import FastAPI
65
  from pydantic import BaseModel
66
+ from typing import List, Optional
 
 
 
67
 
68
  app = FastAPI()
69
 
70
+ # Almacenamiento en memoria temporal
71
+ registro_actual = {}
 
 
72
 
73
  # Modelo de entrada
74
+ class DialogoEntrada(BaseModel):
75
+ enunciado: str
76
+ personajes: List[str] # lista de 3 personajes
77
+ relato_inicial: str
78
+ final_1: str
79
+ final_2: str
80
+ final_3: str
81
+
82
+ # Modelo de salida
83
+ class DialogoSalida(BaseModel):
84
+ enunciado: str
85
+ personajes: List[str]
86
+ relato_inicial: str
87
+ final_1: str
88
+ final_2: str
89
+ final_3: str
90
+
91
+ @app.post("/entrada")
92
+ async def registrar_dialogo(dialogo: DialogoEntrada):
93
+ global registro_actual
94
+ registro_actual = dialogo.dict() # Sobrescribe el contenido anterior
95
+ return {"status": "registro guardado"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ @app.get("/salida", response_model=Optional[DialogoSalida])
98
+ async def obtener_y_limpiar():
99
+ global registro_actual
100
+ if not registro_actual:
101
+ return None
102
+ salida = registro_actual
103
+ registro_actual = {} # Limpia después de devolver
104
+ return salida