sarthak501 committed (verified)
Commit 801e389 · Parent(s): ec50ee8

Update app.py
Files changed (1): app.py (+7 -17)
app.py CHANGED
```diff
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
 
-# Use a writable folder for offloading weights (Hugging Face Spaces restricts /app)
+# Use a writable folder for offloading weights
 offload_dir = "/tmp/offload"
 os.makedirs(offload_dir, exist_ok=True)
 
@@ -13,28 +13,24 @@ app = FastAPI()
 # CORS setup
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Allow all origins for testing
+    allow_origins=["*"],
     allow_credentials=False,
     allow_methods=["*"],
     allow_headers=["*"]
 )
 
-# Model name (7B model - large, will offload to /tmp)
-model_name = "ethzanalytics/RedPajama-INCITE-7B-Base-sharded-bf16"
+# Smaller & faster model
+model_name = "tiiuae/falcon-rw-1b"
 
-# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# Load model with /tmp offload folder
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    torch_dtype="bfloat16",
+    torch_dtype="auto",
    device_map="auto",
     low_cpu_mem_usage=True,
     offload_folder=offload_dir
 )
 
-# Request body schema
 class PromptRequest(BaseModel):
     prompt: str
 
@@ -44,19 +40,13 @@ async def generate_story(req: PromptRequest):
     if not prompt:
         raise HTTPException(status_code=400, detail="Prompt must not be empty")
 
-    # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
-
-    # Generate story
     outputs = model.generate(
         **inputs,
-        max_new_tokens=200,
+        max_new_tokens=150,
         do_sample=True,
         temperature=0.9,
-        top_p=0.9,
-        repetition_penalty=1.2
+        top_p=0.9
     )
-
-    # Decode and return
     story = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"story": story}
```