"""FastAPI service that generates short stories with a causal language model."""

import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use a writable folder for offloading weights that do not fit in memory.
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)

app = FastAPI()

# CORS setup: wide-open origins. Credentials stay disabled, which is the
# only safe combination with a wildcard origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Smaller & faster model
model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
    offload_folder=offload_dir,
)

# Falcon's tokenizer ships without a pad token; generate() warns (and can
# misbehave on padded input) unless one is defined, so reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


class PromptRequest(BaseModel):
    # User-supplied text the model should continue.
    prompt: str


@app.post("/api/generate-story")
def generate_story(req: PromptRequest):
    """Generate a story continuation for the given prompt.

    Declared as a plain (sync) endpoint so FastAPI runs it in its worker
    thread pool; the previous `async def` blocked the event loop for the
    full duration of model.generate(), stalling every concurrent request.

    Returns:
        dict: {"story": <prompt + generated continuation>}

    Raises:
        HTTPException: 400 if the prompt is empty or whitespace-only.
    """
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    # Truncate to the model's max context length; move tensors to wherever
    # device_map placed the model.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}