sarthak501 commited on
Commit
5dd83cb
·
verified ·
1 Parent(s): b947ef0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -2,7 +2,6 @@ from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
- import torch
6
 
7
  app = FastAPI()
8
  app.add_middleware(
@@ -13,9 +12,13 @@ app.add_middleware(
13
  allow_headers=["*"]
14
  )
15
 
16
- model_name = "togethercomputer/RedPajama-INCITE-7B-Base"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
19
 
20
  class PromptRequest(BaseModel):
21
  prompt: str
@@ -26,13 +29,13 @@ async def generate_story(req: PromptRequest):
26
  if not prompt:
27
  raise HTTPException(status_code=400, detail="Prompt must not be empty")
28
 
29
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
30
  outputs = model.generate(
31
  **inputs,
32
  max_new_tokens=200,
33
  do_sample=True,
 
34
  top_p=0.9,
35
- temperature=0.85,
36
  repetition_penalty=1.2
37
  )
38
  story = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
  app = FastAPI()
7
  app.add_middleware(
 
12
  allow_headers=["*"]
13
  )
14
 
15
+ model_name = "ethzanalytics/RedPajama-INCITE-7B-Base-sharded-bf16"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ model_name,
19
+ torch_dtype="bfloat16",
20
+ device_map="auto"
21
+ )
22
 
23
  class PromptRequest(BaseModel):
24
  prompt: str
 
29
  if not prompt:
30
  raise HTTPException(status_code=400, detail="Prompt must not be empty")
31
 
32
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
33
  outputs = model.generate(
34
  **inputs,
35
  max_new_tokens=200,
36
  do_sample=True,
37
+ temperature=0.9,
38
  top_p=0.9,
 
39
  repetition_penalty=1.2
40
  )
41
  story = tokenizer.decode(outputs[0], skip_special_tokens=True)