oggata commited on
Commit
d228260
·
verified ·
1 Parent(s): 173f32e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -45
app.py CHANGED
@@ -22,21 +22,23 @@ model = AutoModelForCausalLM.from_pretrained(
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
- # グローバル設定変数
26
- SYSTEM_MESSAGE = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
27
- MAX_TOKENS = 512
28
- TEMPERATURE = 0.7
29
- TOP_P = 0.95
30
-
31
- def respond(message, history):
 
32
  """
33
  チャットボットの応答を生成する関数
 
34
  """
35
  try:
36
  # システムメッセージと会話履歴を含むプロンプトを構築
37
  conversation = ""
38
- if SYSTEM_MESSAGE.strip():
39
- conversation += f"システム: {SYSTEM_MESSAGE}\n"
40
 
41
  # 会話履歴を追加
42
  for user_msg, bot_msg in history:
@@ -55,13 +57,15 @@ def respond(message, history):
55
  if torch.cuda.is_available():
56
  inputs = inputs.cuda()
57
 
58
- # 応答生成
 
59
  with torch.no_grad():
 
60
  outputs = model.generate(
61
  inputs,
62
- max_new_tokens=MAX_TOKENS,
63
- temperature=TEMPERATURE,
64
- top_p=TOP_P,
65
  do_sample=True,
66
  pad_token_id=tokenizer.eos_token_id,
67
  eos_token_id=tokenizer.eos_token_id,
@@ -78,46 +82,65 @@ def respond(message, history):
78
  if "ユーザー:" in full_response:
79
  full_response = full_response.split("ユーザー:")[0].strip()
80
 
81
- return full_response
 
 
 
82
 
83
  except Exception as e:
84
- return f"エラーが発生しました: {str(e)}"
85
 
86
- # Gradio Blocksを使用したチャットインターフェース
87
- with gr.Blocks(
 
 
 
 
88
  title="🤖 Sarashina Chatbot",
89
- theme=gr.themes.Soft()
90
- ) as demo:
91
-
92
- gr.Markdown("# 🤖 Sarashina Chatbot")
93
- gr.Markdown("Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。")
94
-
95
- chatbot = gr.Chatbot(height=500)
96
- msg = gr.Textbox(
97
- label="メッセージを入力してください",
98
- placeholder="こんにちは!何かお手伝いできることはありますか?",
99
- lines=2
100
- )
101
- clear = gr.Button("会話をクリア")
102
-
103
- def user(message, history):
104
- return "", history + [[message, None]]
105
-
106
- def bot(history):
107
- history[-1][1] = respond(history[-1][0], history[:-1])
108
- return history
109
-
110
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
111
- bot, chatbot, chatbot
112
- )
113
-
114
- clear.click(lambda: None, None, chatbot, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  demo.launch(
118
  server_name="0.0.0.0",
119
  server_port=7860,
120
- share=True, # パブリックリンクを作成
121
- show_api=True,
122
  debug=True
123
  )
 
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
+ def respond(
26
+ message,
27
+ history: list[tuple[str, str]],
28
+ system_message,
29
+ max_tokens,
30
+ temperature,
31
+ top_p,
32
+ ):
33
  """
34
  チャットボットの応答を生成する関数
35
+ Gradio ChatInterfaceの標準形式に対応
36
  """
37
  try:
38
  # システムメッセージと会話履歴を含むプロンプトを構築
39
  conversation = ""
40
+ if system_message.strip():
41
+ conversation += f"システム: {system_message}\n"
42
 
43
  # 会話履歴を追加
44
  for user_msg, bot_msg in history:
 
57
  if torch.cuda.is_available():
58
  inputs = inputs.cuda()
59
 
60
+ # 応答生成(ストリーミング対応)
61
+ response = ""
62
  with torch.no_grad():
63
+ # 一度に生成してからストリーミング風に出力
64
  outputs = model.generate(
65
  inputs,
66
+ max_new_tokens=max_tokens,
67
+ temperature=temperature,
68
+ top_p=top_p,
69
  do_sample=True,
70
  pad_token_id=tokenizer.eos_token_id,
71
  eos_token_id=tokenizer.eos_token_id,
 
82
  if "ユーザー:" in full_response:
83
  full_response = full_response.split("ユーザー:")[0].strip()
84
 
85
+ # ストリーミング風の出力
86
+ for i in range(len(full_response)):
87
+ response = full_response[:i+1]
88
+ yield response
89
 
90
  except Exception as e:
91
+ yield f"エラーが発生しました: {str(e)}"
92
 
93
+ """
94
+ Gradio ChatInterfaceを使用したシンプルなチャットボット
95
+ カスタマイズ可能なパラメータを含む
96
+ """
97
+ demo = gr.ChatInterface(
98
+ respond,
99
  title="🤖 Sarashina Chatbot",
100
+ description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
101
+ additional_inputs=[
102
+ gr.Textbox(
103
+ value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。",
104
+ label="システムメッセージ",
105
+ lines=3
106
+ ),
107
+ gr.Slider(
108
+ minimum=1,
109
+ maximum=1024,
110
+ value=512,
111
+ step=1,
112
+ label="最大新規トークン数"
113
+ ),
114
+ gr.Slider(
115
+ minimum=0.1,
116
+ maximum=2.0,
117
+ value=0.7,
118
+ step=0.1,
119
+ label="Temperature (創造性)"
120
+ ),
121
+ gr.Slider(
122
+ minimum=0.1,
123
+ maximum=1.0,
124
+ value=0.95,
125
+ step=0.05,
126
+ label="Top-p (多様性制御)",
127
+ ),
128
+ ],
129
+ theme=gr.themes.Soft(),
130
+ examples=[
131
+ ["こんにちは!今日はどんなことを話しましょうか?"],
132
+ ["日本の文化について教えてください。"],
133
+ ["簡単なレシピを教えてもらえますか?"],
134
+ ["プログラミングについて質問があります。"],
135
+ ],
136
+ cache_examples=False,
137
+ )
138
 
139
  if __name__ == "__main__":
140
  demo.launch(
141
  server_name="0.0.0.0",
142
  server_port=7860,
143
+ share=False,
144
+ show_api=True, # API documentation を表示
145
  debug=True
146
  )