gouravbhadraDev committed
Commit 8b03c54 · verified · 1 Parent(s): 9f70441

Update app.py

Files changed (1):
  1. app.py (+15 -22)
app.py CHANGED
@@ -2,11 +2,9 @@ from fastapi import FastAPI, Query
 from pydantic import BaseModel
 import cloudscraper
 from bs4 import BeautifulSoup
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline
 import torch
 import re
-import os
-
 
 app = FastAPI()
 
@@ -34,39 +32,34 @@ def scrape(url: str = Query(...)):
         return ThreadResponse(question=question, replies=replies)
     return ThreadResponse(question="", replies=[])
 
-
 MODEL_NAME = "google/flan-t5-small"
 
-# Load tokenizer and model once at startup, with device auto-mapping
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto", device_map="auto")
-model.eval()
+# Load the pipeline once at startup with device auto-mapping
+text_generator = pipeline(
+    "text2text-generation",
+    model=MODEL_NAME,
+    device=0 if torch.cuda.is_available() else -1,
+    max_new_tokens=512,
+    temperature=0.5
+)
 
 class PromptRequest(BaseModel):
     prompt: str
 
 @app.post("/generate")
 async def generate_text(request: PromptRequest):
-    # Prepare chat-style input with thinking mode enabled
-    messages = [{"role": "user", "content": request.prompt}]
-    text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
-
-    inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.5)
-    output_ids = generated_ids[:, inputs.input_ids.shape[-1]:].tolist()[0]
-    output_text = tokenizer.decode(output_ids)
+    # Use the pipeline to generate text directly
+    output = text_generator(request.prompt)[0]['generated_text']
 
     # Extract reasoning and content parts if thinking tags are present
-    if "</think>" in output_text:
-        reasoning_content = output_text.split("</think>")[0].strip()
-        content = output_text.split("</think>")[1].strip().rstrip("</s>")
+    if "</think>" in output:
+        reasoning_content = output.split("</think>")[0].strip()
+        content = output.split("</think>")[1].strip().rstrip("</s>")
     else:
         reasoning_content = ""
-        content = output_text.strip().rstrip("</s>")
+        content = output.strip().rstrip("</s>")
 
     return {
         "reasoning_content": reasoning_content,
         "generated_text": content
     }
-
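
One caveat on the new loading code: in transformers, temperature only takes effect when sampling is enabled, so the pipeline as committed decodes greedily and temperature=0.5 is ignored (the library warns about this at generation time). Below is a minimal sketch of the loading step with sampling switched on; do_sample=True is an assumption about the intended behavior, not part of the commit.

import torch
from transformers import pipeline

MODEL_NAME = "google/flan-t5-small"

# temperature only affects generation when do_sample=True;
# without it, decoding is greedy and the value is silently ignored.
text_generator = pipeline(
    "text2text-generation",
    model=MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,  # GPU if available, else CPU
    max_new_tokens=512,
    do_sample=True,   # assumption: sampled output was intended, given temperature=0.5
    temperature=0.5,
)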
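
For a quick smoke test of the reworked endpoint, something like the following should work once the app is being served. This is a minimal sketch assuming a local uvicorn server on port 8000 and the requests package; the address and prompt are illustrative, not part of the commit.

import requests

# POST a prompt and print both fields returned by /generate.
resp = requests.post(
    "http://localhost:8000/generate",  # assumed local dev address
    json={"prompt": "Summarize: FastAPI is a web framework for building APIs with Python."},
)
resp.raise_for_status()
body = resp.json()
print(body["reasoning_content"])  # empty unless the model emits <think> tags
print(body["generated_text"])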