gouravbhadraDev committed on
Commit
e9f3a9a
·
verified ·
1 Parent(s): ad10382

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -2
app.py CHANGED
@@ -83,7 +83,7 @@ from fastapi import FastAPI, Query, Path
83
  from pydantic import BaseModel
84
  import cloudscraper
85
  from bs4 import BeautifulSoup
86
- from transformers import AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration
87
  import torch
88
  import re
89
  from fastapi.responses import JSONResponse
@@ -144,6 +144,13 @@ t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
144
  t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)
145
  t5_model = t5_model.to(device)
146
 
 
 
 
 
 
 
 
147
  # --- Generation Functions ---
148
 
149
  def generate_deepseek(prompt: str) -> (str, str):
@@ -185,20 +192,45 @@ def generate_t5(prompt: str) -> (str, str):
185
 
186
  # --- API Endpoints ---
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  @app.post("/generate/{model_name}", response_model=GenerateResponse)
189
  async def generate(
190
  request: PromptRequest,
191
- model_name: str = Path(..., description="Model to use: 'deepseekr1-qwen' or 't5-large'")
192
  ):
193
  if model_name == "deepseekr1-qwen":
194
  reasoning, text = generate_deepseek(request.prompt)
195
  elif model_name == "t5-large":
196
  reasoning, text = generate_t5(request.prompt)
 
 
197
  else:
198
  return GenerateResponse(reasoning_content="", generated_text=f"Error: Unknown model '{model_name}'.")
199
 
200
  return GenerateResponse(reasoning_content=reasoning, generated_text=text)
201
 
 
202
  # --- Global Exception Handler ---
203
 
204
  @app.exception_handler(Exception)
 
83
  from pydantic import BaseModel
84
  import cloudscraper
85
  from bs4 import BeautifulSoup
86
+ from transformers import AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration, PegasusTokenizer, PegasusForConditionalGeneration
87
  import torch
88
  import re
89
  from fastapi.responses import JSONResponse
 
144
  t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)
145
  t5_model = t5_model.to(device)
146
 
147
+ pegasus_model_name = "google/pegasus-large"
148
+ pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_model_name)
149
+ pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_model_name)
150
+ pegasus_model = pegasus_model.to(device)
151
+
152
+
153
+
154
  # --- Generation Functions ---
155
 
156
  def generate_deepseek(prompt: str) -> (str, str):
 
192
 
193
  # --- API Endpoints ---
194
 
195
def generate_pegasus(prompt: str) -> tuple[str, str]:
    """Generate text from ``prompt`` with the module-level Pegasus model.

    Args:
        prompt: Raw input text. It is handed to the tokenizer unchanged; no
            task prefix is prepended.

    Returns:
        A ``(reasoning, text)`` pair. ``reasoning`` is always ``""`` because
        this function performs no ``<think>`` extraction. ``text`` is the
        decoded first returned sequence with special tokens removed and
        surrounding whitespace stripped, or ``""`` when ``prompt`` is empty
        or whitespace-only.
    """
    # Nothing to condense: skip tokenization and beam search instead of
    # letting the model produce output for an empty source.
    if not prompt or not prompt.strip():
        return "", ""

    inputs = pegasus_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,  # anything past max_length tokens is dropped
        max_length=1024,
    ).to(device)

    outputs = pegasus_model.generate(
        **inputs,
        max_new_tokens=150,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    generated_text = pegasus_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep the (reasoning, text) shape the /generate endpoint unpacks.
    return "", generated_text.strip()
215
+
216
+
217
  @app.post("/generate/{model_name}", response_model=GenerateResponse)
218
  async def generate(
219
  request: PromptRequest,
220
+ model_name: str = Path(..., description="Model to use: 'deepseekr1-qwen', 't5-large' or 'pegasus-large'")
221
  ):
222
  if model_name == "deepseekr1-qwen":
223
  reasoning, text = generate_deepseek(request.prompt)
224
  elif model_name == "t5-large":
225
  reasoning, text = generate_t5(request.prompt)
226
+ elif model_name == "pegasus-large":
227
+ reasoning, text = generate_pegasus(request.prompt)
228
  else:
229
  return GenerateResponse(reasoning_content="", generated_text=f"Error: Unknown model '{model_name}'.")
230
 
231
  return GenerateResponse(reasoning_content=reasoning, generated_text=text)
232
 
233
+
234
  # --- Global Exception Handler ---
235
 
236
  @app.exception_handler(Exception)