# app.py: FastAPI text-generation service built on google/flan-t5-base
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI()

# Enable CORS so the frontend can make fetch requests from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load FLAN-T5 model
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

class QueryRequest(BaseModel):
    query: str
    echo: bool = False
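
# Example request body (illustrative values, not taken from the original code):
#   {"query": "Translate English to French: Good morning", "echo": false}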

@app.post("/api/query")
async def generate_response(req: QueryRequest):
    query = req.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query must not be empty")

    # Echo mode: return the input unchanged (handy for connectivity checks)
    if req.echo:
        return {"response": query}

    # Tokenize the input, truncating to the model's maximum input length
    inputs = tokenizer(query, return_tensors="pt", truncation=True)

    # Generate with sampling for more varied, less repetitive output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.9,
            top_p=0.95,
            repetition_penalty=1.2,
            do_sample=True,
            num_return_sequences=1,
        )

    # Decode the generated tokens back to text
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": generated}