sarthak501 committed
Commit 06534f9 · verified · 1 Parent(s): 5d54b97

Update app.py
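
The previous revision loaded NeuralNovel/Mistral-7B-Instruct-v0.2-Neural-Story through AutoModelForSeq2SeqLM, which fails for Mistral's decoder-only architecture (transformers has no seq2seq mapping for it). This commit switches to tiiuae/falcon-7b-instruct via AutoModelForCausalLM, drops the trust_remote_code flags, strips the prompt once and reuses the stripped value, and tightens the sampling parameters (max_new_tokens 250 → 200, temperature 0.9 → 0.85, top_p 0.95 → 0.9).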

Files changed (1):
  1. app.py (+13, -10)
app.py CHANGED
@@ -1,31 +1,34 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 app = FastAPI()
 app.add_middleware(
-    CORSMiddleware, allow_origins=["*"], allow_credentials=False,
+    CORSMiddleware,
+    allow_origins=["*"], allow_credentials=False,
     allow_methods=["*"], allow_headers=["*"]
 )
 
-model_name = "NeuralNovel/Mistral-7B-Instruct-v0.2-Neural-Story"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+model_name = "tiiuae/falcon-7b-instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
 
 class PromptRequest(BaseModel):
     prompt: str
 
 @app.post("/api/generate-story")
 async def generate_story(req: PromptRequest):
-    if not req.prompt.strip():
+    prompt = req.prompt.strip()
+    if not prompt:
         raise HTTPException(status_code=400, detail="Prompt must not be empty")
-    inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True)
+
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
     outputs = model.generate(
         **inputs,
-        max_new_tokens=250,
-        temperature=0.9,
-        top_p=0.95,
+        max_new_tokens=200,
+        temperature=0.85,
+        top_p=0.9,
         repetition_penalty=1.2,
         do_sample=True
    )
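
The hunk stops three context lines after the last change, so the tail of generate_story (decoding the generated ids and building the response) sits outside the diff. A minimal sketch of what that tail typically looks like for a causal LM; the prompt-slicing step and the "story" response field are assumptions, not part of this commit:

    # Assumed continuation of generate_story (not shown in the hunk above).
    # Causal LMs return prompt + completion in one sequence, so slice off
    # the prompt tokens before decoding to keep only the generated story.
    prompt_len = inputs["input_ids"].shape[1]
    story = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return {"story": story}  # field name "story" is an assumption

To smoke-test the endpoint once the app is running (assuming e.g. uvicorn app:app on the default 127.0.0.1:8000; only the path /api/generate-story comes from the diff):

    import requests

    # Hypothetical local check of the story endpoint.
    resp = requests.post(
        "http://127.0.0.1:8000/api/generate-story",
        json={"prompt": "A lighthouse keeper finds a message in a bottle."},
        timeout=300,  # a 7B model on CPU can take minutes per request
    )
    print(resp.status_code, resp.json())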