from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load FLAN-T5 model and tokenizer
model_name = "google/flan-t5-base" # or use "flan-t5-large" if space/resources allow
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
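
# Optional sketch (not active here): if the Space had a GPU, the model could be
# moved to it. This is an assumption about the deployment hardware; the tokenizer
# outputs in the endpoint below would also need to be moved with .to(device).
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)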
class QueryRequest(BaseModel):
    """Request body for the /api/query endpoint."""
    query: str
    echo: bool = False  # when True, return the query unchanged instead of generating
@app.post("/api/query")
async def generate_response(req: QueryRequest):
query = req.query.strip()
if not query:
raise HTTPException(status_code=400, detail="Query must not be empty")
if req.echo:
return {"response": query}
inputs = tokenizer(query, return_tensors="pt", truncation=True)
outputs = model.generate(**inputs, max_new_tokens=200)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": generated}