gouravbhadraDev commited on
Commit
6e95583
·
verified ·
1 Parent(s): c11a3a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -36,15 +36,14 @@ def scrape(url: str = Query(...)):
36
  return ThreadResponse(question=question, replies=replies)
37
  return ThreadResponse(question="", replies=[])
38
 
39
- MODEL_NAME = "google/flan-t5-small"
40
 
41
  # Load the pipeline once at startup with device auto-mapping
42
  text_generator = pipeline(
43
- "text2text-generation",
44
  model=MODEL_NAME,
 
45
  device=0 if torch.cuda.is_available() else -1,
46
- max_new_tokens=512,
47
- temperature=0.5
48
  )
49
 
50
  class PromptRequest(BaseModel):
@@ -52,18 +51,24 @@ class PromptRequest(BaseModel):
52
 
53
  @app.post("/generate")
54
  async def generate_text(request: PromptRequest):
55
- # Use the pipeline to generate text directly
56
- output = text_generator(request.prompt)[0]['generated_text']
 
 
 
 
 
 
57
 
58
- # Extract reasoning and content parts if thinking tags are present
59
- if "</think>" in output:
60
- reasoning_content = output.split("</think>")[0].strip()
61
- content = output.split("</think>")[1].strip().rstrip("</s>")
62
  else:
63
  reasoning_content = ""
64
- content = output.strip().rstrip("</s>")
65
 
66
  return {
67
  "reasoning_content": reasoning_content,
68
  "generated_text": content
69
- }
 
36
  return ThreadResponse(question=question, replies=replies)
37
  return ThreadResponse(question="", replies=[])
38
 
39
+ MODEL_NAME = "deepseek-ai/DeepSeek-R1"
40
 
41
  # Load the pipeline once at startup with device auto-mapping
42
  text_generator = pipeline(
43
+ "text-generation",
44
  model=MODEL_NAME,
45
+ trust_remote_code=True,
46
  device=0 if torch.cuda.is_available() else -1,
 
 
47
  )
48
 
49
  class PromptRequest(BaseModel):
 
51
 
52
  @app.post("/generate")
53
  async def generate_text(request: PromptRequest):
54
+ # Prepare messages as expected by the model pipeline
55
+ messages = [{"role": "user", "content": request.prompt}]
56
+
57
+ # Call the pipeline with messages
58
+ outputs = text_generator(messages)
59
+
60
+ # The pipeline returns a list of dicts with 'generated_text'
61
+ generated_text = outputs[0]['generated_text']
62
 
63
+ # Optional: parse reasoning and content if your model uses special tags like </think>
64
+ if "</think>" in generated_text:
65
+ reasoning_content = generated_text.split("</think>")[0].strip()
66
+ content = generated_text.split("</think>")[1].strip()
67
  else:
68
  reasoning_content = ""
69
+ content = generated_text.strip()
70
 
71
  return {
72
  "reasoning_content": reasoning_content,
73
  "generated_text": content
74
+ }