sarthak501 committed
Commit 52808f5 · verified · 1 Parent(s): 4417053

Update app.py

Files changed (1):
  1. app.py +17 -4
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
 
 app = FastAPI()
 
+# Enable CORS for frontend fetch requests
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -14,8 +15,8 @@ app.add_middleware(
     allow_headers=["*"]
 )
 
-# Load FLAN-T5 model and tokenizer
-model_name = "google/flan-t5-base" # or use "flan-t5-large" if space/resources allow
+# Load FLAN-T5 model
+model_name = "google/flan-t5-base"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
@@ -32,8 +33,20 @@ async def generate_response(req: QueryRequest):
     if req.echo:
         return {"response": query}
 
+    # Encode input
     inputs = tokenizer(query, return_tensors="pt", truncation=True)
-    outputs = model.generate(**inputs, max_new_tokens=200)
-    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+    # Generate response with better decoding
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=150,
+        temperature=0.9,
+        top_p=0.95,
+        repetition_penalty=1.2,
+        do_sample=True,
+        num_return_sequences=1
+    )
+
+    # Decode output
+    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"response": generated}
 