sarthak501 committed
Commit 360a4d3 · verified · 1 Parent(s): 52808f5

Update app.py

Files changed (1)
  1. app.py +20 -40
app.py CHANGED
@@ -2,51 +2,31 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-import torch
 
 app = FastAPI()
-
-# Enable CORS for frontend fetch requests
 app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=False,
-    allow_methods=["*"],
-    allow_headers=["*"]
+    CORSMiddleware, allow_origins=["*"], allow_credentials=False,
+    allow_methods=["*"], allow_headers=["*"]
 )
 
-# Load FLAN-T5 model
-model_name = "google/flan-t5-base"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+model_name = "NeuralNovel/Mistral-7B-Instruct-v0.2-Neural-Story"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
 
 class QueryRequest(BaseModel):
-    query: str
-    echo: bool = False
+    prompt: str
 
-@app.post("/api/query")
-async def generate_response(req: QueryRequest):
-    query = req.query.strip()
-    if not query:
-        raise HTTPException(status_code=400, detail="Query must not be empty")
-
-    if req.echo:
-        return {"response": query}
-
-    # Encode input
-    inputs = tokenizer(query, return_tensors="pt", truncation=True)
-
-    # Generate response with better decoding
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=150,
-        temperature=0.9,
-        top_p=0.95,
-        repetition_penalty=1.2,
-        do_sample=True,
-        num_return_sequences=1
-    )
-
-    # Decode output
-    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"response": generated}
+@app.post("/api/generate-story")
+def generate_story(req: QueryRequest):
+    if not req.prompt.strip():
+        raise HTTPException(status_code=400, detail="Prompt must not be empty")
+    inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=200,
+        temperature=0.9,
+        top_p=0.95,
+        repetition_penalty=1.2,
+        do_sample=True
+    )
+    return {"story": tokenizer.decode(outputs[0], skip_special_tokens=True)}