Update app.py
app.py CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
 
-# Use a writable folder for offloading weights
+# Use a writable folder for offloading weights
 offload_dir = "/tmp/offload"
 os.makedirs(offload_dir, exist_ok=True)
 
@@ -13,28 +13,24 @@ app = FastAPI()
 # CORS setup
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=["*"],
     allow_credentials=False,
     allow_methods=["*"],
     allow_headers=["*"]
 )
 
-#
-model_name = "
+# Smaller & faster model
+model_name = "tiiuae/falcon-rw-1b"
 
-# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# Load model with /tmp offload folder
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    torch_dtype="
+    torch_dtype="auto",
     device_map="auto",
     low_cpu_mem_usage=True,
     offload_folder=offload_dir
 )
 
-# Request body schema
 class PromptRequest(BaseModel):
     prompt: str
 
@@ -44,19 +40,13 @@ async def generate_story(req: PromptRequest):
     if not prompt:
         raise HTTPException(status_code=400, detail="Prompt must not be empty")
 
-    # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
-
-    # Generate story
     outputs = model.generate(
         **inputs,
-        max_new_tokens=
+        max_new_tokens=150,
        do_sample=True,
        temperature=0.9,
-        top_p=0.9,
-        repetition_penalty=1.2
+        top_p=0.9
     )
-
-    # Decode and return
     story = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"story": story}
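For context, this is a sketch of how the updated app.py plausibly reads as a whole after this commit. Only the three hunks above are part of the diff, so the FastAPI and CORS imports, the route decorator, and the line that pulls the prompt out of the request body are assumptions added for illustration; the model name, offload folder, and generation settings come from the new (+) lines.

# Sketch of the full updated app.py; lines marked "assumed" are not visible in the diff.
from fastapi import FastAPI, HTTPException           # assumed import
from fastapi.middleware.cors import CORSMiddleware   # assumed import
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Use a writable folder for offloading weights
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)

app = FastAPI()

# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"]
)

# Smaller & faster model
model_name = "tiiuae/falcon-rw-1b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
    offload_folder=offload_dir
)

class PromptRequest(BaseModel):
    prompt: str

@app.post("/generate")  # assumed route path; the decorator sits outside the hunks
async def generate_story(req: PromptRequest):
    prompt = req.prompt.strip()  # assumed; the hunk begins at the empty-prompt check
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.9,
        top_p=0.9
    )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}

With device_map="auto" and an offload_folder, weights that do not fit in memory are spilled to /tmp/offload, which is writable inside the Space container; this path requires the accelerate package to be installed alongside transformers.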
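Assuming the handler is mounted at a POST route (the path is not visible in this diff; /generate and the default Spaces port 7860 are used below only as placeholders), a minimal client call could look like this:

# Minimal client sketch; URL and route are assumptions, not taken from the diff.
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"prompt": "Write a short story about a lighthouse keeper."},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["story"])

Because the CORS middleware allows every origin with allow_credentials=False, such a request can also be issued directly from browser JavaScript on any site.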