sarthak501 committed on
Commit
b87d488
·
verified ·
1 Parent(s): 360a4d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -5,28 +5,29 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
  app = FastAPI()
7
  app.add_middleware(
8
- CORSMiddleware, allow_origins=["*"], allow_credentials=False,
9
- allow_methods=["*"], allow_headers=["*"]
 
10
  )
11
 
12
  model_name = "NeuralNovel/Mistral-7B-Instruct-v0.2-Neural-Story"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
14
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
15
 
16
- class QueryRequest(BaseModel):
17
- prompt: str
18
 
19
  @app.post("/api/generate-story")
20
- def generate_story(req: QueryRequest):
21
- if not req.prompt.strip():
22
- raise HTTPException(status_code=400, detail="Prompt must not be empty")
23
- inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True)
24
- outputs = model.generate(
25
- **inputs,
26
- max_new_tokens=200,
27
- temperature=0.9,
28
- top_p=0.95,
29
- repetition_penalty=1.2,
30
- do_sample=True
31
- )
32
- return {"story": tokenizer.decode(outputs[0], skip_special_tokens=True)}
 
5
 
6
  app = FastAPI()
7
  app.add_middleware(
8
+ CORSMiddleware,
9
+ allow_origins=["*"], allow_credentials=False,
10
+ allow_methods=["*"], allow_headers=["*"]
11
  )
12
 
13
  model_name = "NeuralNovel/Mistral-7B-Instruct-v0.2-Neural-Story"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
16
 
17
+ class PromptRequest(BaseModel):
18
+ prompt: str
19
 
20
  @app.post("/api/generate-story")
21
+ async def generate_story(req: PromptRequest):
22
+ if not req.prompt.strip():
23
+ raise HTTPException(status_code=400, detail="Prompt must not be empty")
24
+ inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True)
25
+ outputs = model.generate(
26
+ **inputs,
27
+ max_new_tokens=250,
28
+ temperature=0.9,
29
+ top_p=0.95,
30
+ repetition_penalty=1.2,
31
+ do_sample=True
32
+ )
33
+ return {"story": tokenizer.decode(outputs[0], skip_special_tokens=True)}