from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# Use a writable folder for offloading weights
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)
app = FastAPI()
# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"]
)
# Smaller & faster model
model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
    offload_folder=offload_dir
)
class PromptRequest(BaseModel):
    prompt: str
@app.post("/api/generate-story")
async def generate_story(req: PromptRequest):
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")
    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # Sample a continuation from the model
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id  # avoid the missing-pad-token warning
    )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}
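
# A minimal sketch of how the endpoint can be exercised once the server is
# running. The host and port below are assumptions (7860 is the usual
# Hugging Face Spaces port); on a Space the platform starts the server itself.
#
#   curl -X POST http://localhost:7860/api/generate-story \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time"}'
#
# Local-run sketch, assuming uvicorn is installed and this file is app.py:
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)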