gouravbhadraDev committed on
Commit
a0b62ab
·
verified ·
1 Parent(s): e9f3a9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -149,6 +149,10 @@ pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_model_name)
149
  pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_model_name)
150
  pegasus_model = pegasus_model.to(device)
151
 
 
 
 
 
152
 
153
 
154
  # --- Generation Functions ---
@@ -213,11 +217,37 @@ def generate_pegasus(prompt: str) -> (str, str):
213
  # Pegasus does not use <think> tags, so no reasoning extraction
214
  return "", generated_text.strip()
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  @app.post("/generate/{model_name}", response_model=GenerateResponse)
218
  async def generate(
219
  request: PromptRequest,
220
- model_name: str = Path(..., description="Model to use: 'deepseekr1-qwen', 't5-large' or 'pegasus-large'")
221
  ):
222
  if model_name == "deepseekr1-qwen":
223
  reasoning, text = generate_deepseek(request.prompt)
@@ -225,12 +255,15 @@ async def generate(
225
  reasoning, text = generate_t5(request.prompt)
226
  elif model_name == "pegasus-large":
227
  reasoning, text = generate_pegasus(request.prompt)
 
 
228
  else:
229
  return GenerateResponse(reasoning_content="", generated_text=f"Error: Unknown model '{model_name}'.")
230
 
231
  return GenerateResponse(reasoning_content=reasoning, generated_text=text)
232
 
233
 
 
234
  # --- Global Exception Handler ---
235
 
236
  @app.exception_handler(Exception)
 
149
  pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_model_name)
150
  pegasus_model = pegasus_model.to(device)
151
 
152
# Qwen3 checkpoint: the tokenizer and causal-LM weights are loaded once at
# import time, and the model is moved to the shared `device` as it is loaded.
qwen3_model_name = "Qwen/Qwen3-0.6B"
qwen3_tokenizer = AutoTokenizer.from_pretrained(qwen3_model_name)
qwen3_model = AutoModelForCausalLM.from_pretrained(qwen3_model_name).to(device)
156
 
157
 
158
  # --- Generation Functions ---
 
217
  # Pegasus does not use <think> tags, so no reasoning extraction
218
  return "", generated_text.strip()
219
 
220
def generate_qwen3(prompt: str) -> (str, str):
    """Generate a completion for *prompt* with the Qwen3 model.

    Args:
        prompt: Raw text fed to the tokenizer; it is truncated to 1024 tokens.

    Returns:
        A ``(reasoning_content, generated_text)`` pair. When the completion
        contains a ``</think>`` marker, everything before it (minus the
        opening ``<think>`` tag) is returned as the reasoning and everything
        after it as the answer. Otherwise the reasoning is empty and the
        whole completion is returned as the answer.
    """
    inputs = qwen3_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(device)

    outputs = qwen3_model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=qwen3_tokenizer.eos_token_id,
    )

    # A causal LM returns the prompt tokens followed by the continuation.
    # Decode only the newly generated part so the prompt is not echoed back
    # to the caller (or mistaken for reasoning text).
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = outputs[0][prompt_length:]
    generated_text = qwen3_tokenizer.decode(completion_ids, skip_special_tokens=True)

    reasoning, closing_tag, answer = generated_text.partition("</think>")
    if not closing_tag:
        # No reasoning section was produced.
        return "", generated_text.strip()

    reasoning = reasoning.strip()
    if reasoning.startswith("<think>"):
        # Drop the opening tag so only the reasoning text itself is returned.
        reasoning = reasoning[len("<think>"):].strip()
    return reasoning, answer.strip()
245
+
246
 
247
  @app.post("/generate/{model_name}", response_model=GenerateResponse)
248
  async def generate(
249
  request: PromptRequest,
250
+ model_name: str = Path(..., description="Model to use: 'deepseekr1-qwen', 't5-large', 'pegasus-large', or 'qwen3-0.6b'")
251
  ):
252
  if model_name == "deepseekr1-qwen":
253
  reasoning, text = generate_deepseek(request.prompt)
 
255
  reasoning, text = generate_t5(request.prompt)
256
  elif model_name == "pegasus-large":
257
  reasoning, text = generate_pegasus(request.prompt)
258
+ elif model_name == "qwen3-0.6b":
259
+ reasoning, text = generate_qwen3(request.prompt)
260
  else:
261
  return GenerateResponse(reasoning_content="", generated_text=f"Error: Unknown model '{model_name}'.")
262
 
263
  return GenerateResponse(reasoning_content=reasoning, generated_text=text)
264
 
265
 
266
+
267
  # --- Global Exception Handler ---
268
 
269
  @app.exception_handler(Exception)