# general-api / app.py
# sarthak501's picture
# Update app.py
# 801e389 verified
import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# Weight offloading needs a writable scratch directory; /tmp stays
# writable even when the container filesystem is read-only.
offload_dir = os.path.join("/tmp", "offload")
os.makedirs(offload_dir, exist_ok=True)
app = FastAPI()

# Permissive CORS policy: any origin/method/header may call the API.
# Credentials are disabled, so the wildcard origin is acceptable.
cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_policy)
# A compact causal LM chosen to keep load time and memory low.
model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Let transformers pick dtype/device placement automatically and spill
# shards that do not fit in memory to the writable offload folder.
load_kwargs = {
    "torch_dtype": "auto",
    "device_map": "auto",
    "low_cpu_mem_usage": True,
    "offload_folder": offload_dir,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
class PromptRequest(BaseModel):
    """Request body for the story-generation endpoint."""
    # Free-form user text used as the generation prompt; validated
    # non-empty (after stripping) by the endpoint, not by the model.
    prompt: str
@app.post("/api/generate-story")
def generate_story(req: PromptRequest):
    """Generate a sampled story continuation for the given prompt.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs the
    blocking model inference in its threadpool instead of stalling the
    event loop for every concurrent request.

    Args:
        req: Request body carrying the user prompt.

    Returns:
        dict: ``{"story": <decoded text>}`` — the full decoded sequence
        (prompt plus continuation), special tokens stripped.

    Raises:
        HTTPException: 400 when the prompt is empty or whitespace-only.
    """
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")
    # Truncate to the tokenizer's model_max_length and move tensors to
    # wherever device_map placed the model.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # Inference only: disable autograd to cut memory use during generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
        )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}