雷娃 commited on
Commit
9423469
·
1 Parent(s): 26ca9d4

add stream output

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # app.py
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
  import gradio as gr
4
  import torch
5
 
@@ -24,16 +25,32 @@ def chat(user_input, max_new_tokens=512):
24
 
25
  # encode the input prompt
26
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # generate response
29
- with torch.no_grad():
30
- outputs = model.generate(
31
- **inputs,
32
- max_new_tokens=max_new_tokens,
33
- pad_token_id=tokenizer.eos_token_id
34
- )
35
- response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
36
- return response
37
 
38
  # Construct Gradio Interface
39
  interface = gr.Interface(
@@ -42,7 +59,7 @@ interface = gr.Interface(
42
  gr.Textbox(lines=5, label="输入你的问题"),
43
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
44
  ],
45
- outputs=gr.Textbox(label="模型回复"),
46
  title="Ling-lite-1.5 MoE 模型 Demo",
47
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
48
  examples=[
 
1
  # app.py
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
3
+ from threading import Thread
4
  import gradio as gr
5
  import torch
6
 
 
25
 
26
  # encode the input prompt
27
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
28
+
29
+ #create streamer
30
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
31
+
32
+ def generate():
33
+ model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
34
+
35
+ thread = Thread(target=generate)
36
+ thread.start()
37
+
38
+ generated_text = ""
39
+ for new_text in streamer:
40
+ generated_text += new_text
41
+ yield generated_text
42
+
43
+ thread.join()
44
 
45
  # generate response
46
+ #with torch.no_grad():
47
+ # outputs = model.generate(
48
+ # **inputs,
49
+ # max_new_tokens=max_new_tokens,
50
+ # pad_token_id=tokenizer.eos_token_id
51
+ # )
52
+ #response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
53
+ #return response
54
 
55
  # Construct Gradio Interface
56
  interface = gr.Interface(
 
59
  gr.Textbox(lines=5, label="输入你的问题"),
60
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
61
  ],
62
+ outputs=gr.Textbox(label="模型回复", stream=True),
63
  title="Ling-lite-1.5 MoE 模型 Demo",
64
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
65
  examples=[