gouravbhadraDev committed
Commit d61b7ff · verified · 1 Parent(s): 6c63c6d

Update app.py

Files changed (1): app.py (+54 -36)
app.py CHANGED
@@ -114,13 +114,13 @@ def clean_text(text: str) -> str:
 # --- Scraping Endpoint ---
 
 @app.get("/scrape", response_model=ThreadResponse)
-def scrape(url: str = Query(...)):
+def scrape(url: str):
     scraper = cloudscraper.create_scraper()
     response = scraper.get(url)
 
     if response.status_code == 200:
-        soup = BeautifulSoup(response.content, 'html.parser')
-        comment_containers = soup.find_all('div', class_='post__content')
+        soup = BeautifulSoup(response.content, "html.parser")
+        comment_containers = soup.find_all("div", class_="post__content")
 
         if comment_containers:
             question = clean_text(comment_containers[0].get_text(strip=True, separator="\n"))
@@ -129,27 +129,46 @@ def scrape(url: str = Query(...)):
     return ThreadResponse(question="", replies=[])
 
 
-# --- Load T5-Small Model and Tokenizer ---
+# --- Load DeepSeek-R1-Distill-Qwen-1.5B Model & Tokenizer ---
 
-tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-large")
+deepseek_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_model_name)
+deepseek_model = AutoModelForCausalLM.from_pretrained(deepseek_model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = model.to(device)
+deepseek_model = deepseek_model.to(device)
+
+
+# --- Load T5-Large Model & Tokenizer ---
 
+t5_model_name = "google-t5/t5-large"
+t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
+t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)
+t5_model = t5_model.to(device)
+
+
+# --- Generation Functions ---
+
+def generate_deepseek(prompt: str) -> (str, str):
+    inputs = deepseek_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
+    outputs = deepseek_model.generate(
+        **inputs,
+        max_length=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True,
+        num_return_sequences=1,
+        pad_token_id=deepseek_tokenizer.eos_token_id,
+    )
+    generated_text = deepseek_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# --- Core Generation Function Using T5 Prompting ---
+    # DeepSeek models usually do not have a special reasoning delimiter, so return empty reasoning
+    return "", generated_text.strip()
 
-def generate_text_with_t5(prompt: str) -> (str, str):
-    """
-    Accepts a prompt string that includes the T5 task prefix (e.g. "summarize: ..."),
-    generates output text, and optionally extracts reasoning if present.
-    Returns a tuple (reasoning_content, generated_text).
-    """
-    # Tokenize input prompt with truncation to max 512 tokens
-    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
 
-    # Generate output tokens with beam search for quality
-    outputs = model.generate(
+def generate_t5(prompt: str) -> (str, str):
+    # T5 expects prompt with task prefix, e.g. "summarize: ..."
+    inputs = t5_tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
+    outputs = t5_model.generate(
         inputs,
         max_length=512,
         num_beams=4,
@@ -157,29 +176,28 @@ def generate_text_with_t5(prompt: str) -> (str, str):
         length_penalty=1.0,
         early_stopping=True,
     )
+    generated_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Optional: parse reasoning if your prompt/model uses a special separator like </think>
+    # Optional reasoning parsing if </think> is used
     if "</think>" in generated_text:
         reasoning_content, content = generated_text.split("</think>", 1)
-        reasoning_content = reasoning_content.strip()
-        content = content.strip()
+        return reasoning_content.strip(), content.strip()
     else:
-        reasoning_content = ""
-        content = generated_text.strip()
+        return "", generated_text.strip()
 
-    return reasoning_content, content
 
+# --- API Endpoints ---
 
-# --- /generate Endpoint Using T5 Prompting ---
-
-@app.post("/generate", response_model=GenerateResponse)
-async def generate(request: PromptRequest):
-    """
-    Accepts a prompt string from frontend, which should include the T5 task prefix,
-    e.g. "summarize: {text to summarize}" or "translate English to German: {text}".
-    Returns generated text and optional reasoning content.
-    """
-    reasoning_content, generated_text = generate_text_with_t5(request.prompt)
-    return GenerateResponse(reasoning_content=reasoning_content, generated_text=generated_text)
+@app.post("/generate/{model_name}", response_model=GenerateResponse)
+async def generate(
+    request: PromptRequest,
+    model_name: str = Path(..., description="Model to use: 'deepseekr1-qwen' or 't5-large'")
+):
+    if model_name == "deepseekr1-qwen":
+        reasoning, text = generate_deepseek(request.prompt)
+    elif model_name == "t5-large":
+        reasoning, text = generate_t5(request.prompt)
+    else:
+        return {"reasoning_content": "", "generated_text": f"Error: Unknown model '{model_name}'."}
+
+    return GenerateResponse(reasoning_content=reasoning, generated_text=text)
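
Note: none of the hunks above touch the import section of app.py, so Path (FastAPI) and AutoTokenizer / AutoModelForCausalLM (transformers) are assumed to already be imported earlier in the file; if they are not, imports along these lines would be needed:

from fastapi import FastAPI, Path, Query
from transformers import AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration

For reference, a minimal client-side sketch of how the updated endpoints could be called. This assumes the app is served locally with uvicorn on port 8000 (host/port are placeholders, not part of this commit), that PromptRequest carries a single prompt field as used above, and that the thread URL is a stand-in example:

import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn host/port

# /scrape takes the thread URL as a query parameter and returns question + replies.
thread = requests.get(f"{BASE_URL}/scrape", params={"url": "https://example.com/some-thread"}).json()

# /generate/{model_name} selects the DeepSeek distill or T5 path via the path parameter.
deepseek = requests.post(
    f"{BASE_URL}/generate/deepseekr1-qwen",
    json={"prompt": "Summarize this question: " + thread["question"]},
).json()

# T5 expects the task prefix inside the prompt itself.
t5 = requests.post(
    f"{BASE_URL}/generate/t5-large",
    json={"prompt": "summarize: " + thread["question"]},
).json()

print(deepseek["generated_text"])
print(t5["reasoning_content"], t5["generated_text"])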