雷娃 committed on
Commit 2493f19 · 1 Parent(s): 1d199f5

add interactive mode

Files changed (1)
  1. app.py +16 -4
app.py CHANGED
@@ -2,6 +2,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from threading import Thread
 import gradio as gr
+import re
 import torch
 
 # load model and tokenizer
@@ -29,20 +30,31 @@ def chat(user_input, max_new_tokens=512):
     #create streamer
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
+    def get_start_idx(response, input):
+        match = re.search(re.escape(response), input)
+        if not match:
+            return -1
+        return match.end()
+
     def generate():
         model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
 
     thread = Thread(target=generate)
     thread.start()
 
-    prompt_len = len(prompt)
-    print(prompt)
+    start_idx = -1
     generated_text = ""
     for new_text in streamer:
         generated_text += new_text
-        print(generated_text)
+
+        if (start_idx == -1):
+            start_idx = get_start_idx(generated_text, user_input)
+            if (start_idx != -1):
+                start_idx += len("ASSISTANT")
+        #print(generated_text)
         #yield generated_text
-        yield generated_text[prompt_len:]
+        if (start_idx > 0):
+            yield generated_text[start_idx:]
 
     thread.join()
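
For context, the pattern this diff builds on is transformers' TextIteratorStreamer: model.generate() runs in a background thread and pushes decoded text into the streamer, and the handler iterates over the streamer to yield partial output as it arrives. The sketch below shows that pattern in isolation; the "gpt2" model name and the USER/ASSISTANT prompt template are placeholder assumptions for illustration, not the app's actual configuration.

# Minimal sketch of the thread + TextIteratorStreamer pattern used above.
# "gpt2" and the USER/ASSISTANT template are assumed stand-ins, not app.py's setup.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "USER: hello ASSISTANT:"  # assumed template matching the diff's marker
inputs = tokenizer(prompt, return_tensors="pt")

# The streamer decodes tokens as generate() produces them in the worker thread.
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer),
)
thread.start()

generated_text = ""
for new_text in streamer:  # blocks until the next decoded chunk is available
    generated_text += new_text
thread.join()
print(generated_text)

Note that TextIteratorStreamer also accepts skip_prompt=True, which drops the echoed prompt at the source; the diff instead trims it manually by matching the streamed text against the user input and skipping past the "ASSISTANT" marker.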