from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Use a writable folder for offloading weights
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)

app = FastAPI()

# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"]
)
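
# NOTE: wildcard origins are fine for local development; a deployed service
# would normally list the frontend's origin(s) explicitly here.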

# Smaller & faster model
model_name = "tiiuae/falcon-rw-1b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
    offload_folder=offload_dir
)
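# `device_map="auto"` requires the `accelerate` package; any layers that do
# not fit in memory are offloaded to `offload_folder` on disk.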

class PromptRequest(BaseModel):
    prompt: str

@app.post("/api/generate-story")
async def generate_story(req: PromptRequest):
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.9,
        top_p=0.9
    )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}