sarthak501 committed
Commit 1c18d40 · verified · 1 Parent(s): 54e8c97

Update app.py

Files changed (1)
  1. app.py +19 -10
app.py CHANGED
@@ -1,30 +1,39 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-import random
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 
 app = FastAPI()
 
-# Enable CORS for JavaScript fetch usage across origins
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=["*"],
     allow_credentials=False,
     allow_methods=["*"],
     allow_headers=["*"]
 )
 
+# Load FLAN-T5 model and tokenizer
+model_name = "google/flan-t5-base"  # or use "flan-t5-large" if space/resources allow
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
 class QueryRequest(BaseModel):
     query: str
     echo: bool = False
 
 @app.post("/api/query")
-async def general_query(req: QueryRequest):
-    text = req.query.strip()
-    if not text:
+async def generate_response(req: QueryRequest):
+    query = req.query.strip()
+    if not query:
         raise HTTPException(status_code=400, detail="Query must not be empty")
+
     if req.echo:
-        return {"response": text}
-    words = text.split()
-    random.shuffle(words)
-    return {"response": " ".join(words)}
+        return {"response": query}
+
+    inputs = tokenizer(query, return_tensors="pt", truncation=True)
+    outputs = model.generate(**inputs, max_new_tokens=200)
+    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return {"response": generated}