surkovvv committed
Commit dc8887b · 1 Parent(s): b261ed6

some fixes in generating kwargs

Files changed (1)
  1. app.py +7 -23
app.py CHANGED
@@ -1,40 +1,23 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+from functools import partial
 
 
 tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_llama3_8b")
 model = AutoModelForCausalLM.from_pretrained("IlyaGusev/saiga_llama3_8b", torch_dtype=torch.bfloat16)
-model = model #.to('cuda')
-
-
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
-        for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id:
-                return True
-        return False
+model = model
 
 
 def predict(message, history):
     print(history)
     history_transformer_format = history + [{"role": "user", "content": message},
                                             {"role": "assistant", "content": ""}]
-    stop = StopOnTokens()
 
-    # messages = "".join(["".join(["<|start_header_id|>user<|end_header_id|>\n"+item[0],
-    #                              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"+item[1]])
-    #                     for item in history_transformer_format])
-    # messages = [{"role": "user", item[0], "content": item[1]} for item in history_transformer_format]
-    #print(messages)
-
-    # model_inputs = tokenizer([messages], return_tensors="pt") # .to("cuda")
     model_inputs = tokenizer.apply_chat_template(history_transformer_format, return_tensors="pt")
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        model_inputs,
         streamer=streamer,
         max_new_tokens=1024,
         do_sample=True,
@@ -42,9 +25,9 @@ def predict(message, history):
         top_k=1000,
         temperature=1.0,
         num_beams=1,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    )
+    generating_func = partial(model.generate, model_inputs)
+    t = Thread(target=generating_func, kwargs=generate_kwargs)
     t.start()
 
     partial_message = ""
@@ -53,4 +36,5 @@ def predict(message, history):
         partial_message += new_token
         yield partial_message
 
+
 gr.ChatInterface(predict).launch(share=True)
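
Context on the kwargs fix: the old code built generate_kwargs with dict(model_inputs, ...), passing the tokenized prompt as dict()'s first positional argument; since apply_chat_template with return_tensors="pt" returns a tensor rather than a mapping, that call would likely raise a TypeError. The new code instead binds the prompt to model.generate with functools.partial and lets Thread(kwargs=...) carry only keyword arguments. Below is a minimal, runnable sketch of that pattern; fake_generate is a hypothetical stand-in so the example runs without loading the model.

# Sketch of the partial + Thread pattern used in this commit.
from functools import partial
from threading import Thread


def fake_generate(input_ids, streamer=None, max_new_tokens=16, **kwargs):
    # Stand-in for model.generate: the prompt arrives as the first positional
    # argument, everything else as keyword arguments.
    print(f"generating from {input_ids} with max_new_tokens={max_new_tokens}")


input_ids = [[1, 2, 3]]                    # placeholder for the tokenized prompt
generate_kwargs = dict(max_new_tokens=32)  # keyword arguments only

# Old approach: dict(input_ids, max_new_tokens=32) raises a TypeError, because
# dict() cannot take a plain tensor/list as its first positional argument.
# New approach: bind the positional prompt with partial, keep kwargs separate.
generating_func = partial(fake_generate, input_ids)
t = Thread(target=generating_func, kwargs=generate_kwargs)
t.start()
t.join()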