gouravbhadraDev commited on
Commit
d15392d
·
verified ·
1 Parent(s): 409ae10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from fastapi import FastAPI, Query
2
  from pydantic import BaseModel
3
  import cloudscraper
@@ -74,4 +75,111 @@ async def generate_text(request: PromptRequest):
74
  return {
75
  "reasoning_content": reasoning_content,
76
  "generated_text": content
77
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
  from fastapi import FastAPI, Query
3
  from pydantic import BaseModel
4
  import cloudscraper
 
75
  return {
76
  "reasoning_content": reasoning_content,
77
  "generated_text": content
78
+ }
79
+
80
+ '''
81
+
82
+ from fastapi import FastAPI, Query
83
+ from pydantic import BaseModel
84
+ import cloudscraper
85
+ from bs4 import BeautifulSoup
86
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
87
+ import torch
88
+ import re
89
+
90
+ app = FastAPI()
91
+
92
+ # --- Data Models ---
93
+
94
+ class ThreadResponse(BaseModel):
95
+ question: str
96
+ replies: list[str]
97
+
98
+ class PromptRequest(BaseModel):
99
+ prompt: str
100
+
101
+ class GenerateResponse(BaseModel):
102
+ reasoning_content: str
103
+ generated_text: str
104
+
105
+
106
+ # --- Utility Functions ---
107
+
108
+ def clean_text(text: str) -> str:
109
+ text = text.strip()
110
+ text = re.sub(r"\b\d+\s*likes?,?\s*\d*\s*replies?$", "", text, flags=re.IGNORECASE).strip()
111
+ return text
112
+
113
+
114
+ # --- Scraping Endpoint ---
115
+
116
+ @app.get("/scrape", response_model=ThreadResponse)
117
+ def scrape(url: str = Query(...)):
118
+ scraper = cloudscraper.create_scraper()
119
+ response = scraper.get(url)
120
+
121
+ if response.status_code == 200:
122
+ soup = BeautifulSoup(response.content, 'html.parser')
123
+ comment_containers = soup.find_all('div', class_='post__content')
124
+
125
+ if comment_containers:
126
+ question = clean_text(comment_containers[0].get_text(strip=True, separator="\n"))
127
+ replies = [clean_text(comment.get_text(strip=True, separator="\n")) for comment in comment_containers[1:]]
128
+ return ThreadResponse(question=question, replies=replies)
129
+ return ThreadResponse(question="", replies=[])
130
+
131
+
132
+ # --- Load T5-Small Model and Tokenizer ---
133
+
134
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-small")
135
+ model = T5ForConditionalGeneration.from_pretrained("google/t5-small")
136
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137
+ model = model.to(device)
138
+
139
+
140
+ # --- Core Generation Function Using T5 Prompting ---
141
+
142
+ def generate_text_with_t5(prompt: str) -> (str, str):
143
+ """
144
+ Accepts a prompt string that includes the T5 task prefix (e.g. "summarize: ..."),
145
+ generates output text, and optionally extracts reasoning if present.
146
+ Returns a tuple (reasoning_content, generated_text).
147
+ """
148
+ # Tokenize input prompt with truncation to max 512 tokens
149
+ inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
150
+
151
+ # Generate output tokens with beam search for quality
152
+ outputs = model.generate(
153
+ inputs,
154
+ max_length=512,
155
+ num_beams=4,
156
+ repetition_penalty=2.5,
157
+ length_penalty=1.0,
158
+ early_stopping=True,
159
+ )
160
+
161
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
162
+
163
+ # Optional: parse reasoning if your prompt/model uses a special separator like </think>
164
+ if "</think>" in generated_text:
165
+ reasoning_content, content = generated_text.split("</think>", 1)
166
+ reasoning_content = reasoning_content.strip()
167
+ content = content.strip()
168
+ else:
169
+ reasoning_content = ""
170
+ content = generated_text.strip()
171
+
172
+ return reasoning_content, content
173
+
174
+
175
+ # --- /generate Endpoint Using T5 Prompting ---
176
+
177
+ @app.post("/generate", response_model=GenerateResponse)
178
+ async def generate(request: PromptRequest):
179
+ """
180
+ Accepts a prompt string from frontend, which should include the T5 task prefix,
181
+ e.g. "summarize: {text to summarize}" or "translate English to German: {text}".
182
+ Returns generated text and optional reasoning content.
183
+ """
184
+ reasoning_content, generated_text = generate_text_with_t5(request.prompt)
185
+ return GenerateResponse(reasoning_content=reasoning_content, generated_text=generated_text)