DauroCamilo commited on
Commit
543077e
·
verified ·
1 Parent(s): 7dc77a4
Files changed (1) hide show
  1. main.py +7 -4
main.py CHANGED
@@ -4,15 +4,18 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
4
 
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
8
  from fastapi.responses import StreamingResponse
9
  import threading
10
 
11
  app = FastAPI()
12
 
13
- model_id = "GEB-AGI/geb-1.3b"
14
- tokenizer = AutoTokenizer.from_pretrained(model_id)
15
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 
 
 
16
 
17
  class ChatRequest(BaseModel):
18
  message: str
 
4
 
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, AutoModel
8
  from fastapi.responses import StreamingResponse
9
  import threading
10
 
11
  app = FastAPI()
12
 
13
+ # model_id = "GEB-AGI/geb-1.3b"
14
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
15
+ # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
16
+ model = AutoModel.from_pretrained("GEB-AGI/geb-1.3b", trust_remote_code=True).bfloat16().cuda()
17
+ tokenizer = AutoTokenizer.from_pretrained("GEB-AGI/geb-1.3b", trust_remote_code=True)
18
+
19
 
20
  class ChatRequest(BaseModel):
21
  message: str