oggata commited on
Commit
f333515
·
verified ·
1 Parent(s): f95231e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -33
app.py CHANGED
@@ -5,8 +5,8 @@ import warnings
5
  warnings.filterwarnings("ignore")
6
 
7
  """
8
- デバッグ用のシンプルなSarashinaチャットボット
9
- additional_inputsなしでテスト
10
  """
11
 
12
  # モデルとトークナイザーの初期化
@@ -22,31 +22,30 @@ model = AutoModelForCausalLM.from_pretrained(
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
- def respond(message, history):
 
 
 
 
 
 
 
26
  """
27
- シンプルなチャットボット応答関数
28
- additional_inputsなし
29
  """
30
  try:
31
- # デバッグ情報を出力
32
- print(f"DEBUG - message: {message} (type: {type(message)})")
33
- print(f"DEBUG - history: {history} (type: {type(history)})")
34
-
35
- # システムメッセージ(固定)
36
- system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
37
-
38
- # 会話履歴を含むプロンプトを構築
39
- conversation = f"システム: {system_message}\n"
40
 
41
  # 会話履歴を追加
42
- if history and isinstance(history, list):
43
- for item in history:
44
- if isinstance(item, (list, tuple)) and len(item) >= 2:
45
- user_msg, bot_msg = item[0], item[1]
46
- if user_msg:
47
- conversation += f"ユーザー: {user_msg}\n"
48
- if bot_msg:
49
- conversation += f"アシスタント: {bot_msg}\n"
50
 
51
  # 現在のメッセージを追加
52
  conversation += f"ユーザー: {message}\nアシスタント: "
@@ -58,13 +57,15 @@ def respond(message, history):
58
  if torch.cuda.is_available():
59
  inputs = inputs.cuda()
60
 
61
- # 応答生成
 
62
  with torch.no_grad():
 
63
  outputs = model.generate(
64
  inputs,
65
- max_new_tokens=512,
66
- temperature=0.7,
67
- top_p=0.95,
68
  do_sample=True,
69
  pad_token_id=tokenizer.eos_token_id,
70
  eos_token_id=tokenizer.eos_token_id,
@@ -90,17 +91,47 @@ def respond(message, history):
90
  yield f"エラーが発生しました: {str(e)}"
91
 
92
  """
93
- シンプルなChatInterface(additional_inputsなし)
 
94
  """
95
  demo = gr.ChatInterface(
96
  respond,
97
- title="🤖 Sarashina Chatbot (Simple)",
98
- description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。(デバッグ用)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  theme=gr.themes.Soft(),
100
  examples=[
101
- "こんにちは!",
102
- "日本について教えて",
103
- "プログラミングの質問があります",
 
104
  ],
105
  cache_examples=False,
106
  )
@@ -110,6 +141,6 @@ if __name__ == "__main__":
110
  server_name="0.0.0.0",
111
  server_port=7860,
112
  share=False,
113
- show_api=True,
114
  debug=True
115
  )
 
5
  warnings.filterwarnings("ignore")
6
 
7
  """
8
+ Sarashinaモデルを使用したGradioチャットボット
9
+ Hugging Face Transformersライブラリを使用してローカルでモデルを実行
10
  """
11
 
12
  # モデルとトークナイザーの初期化
 
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
+ def respond(
26
+ message,
27
+ history: list[tuple[str, str]],
28
+ system_message,
29
+ max_tokens,
30
+ temperature,
31
+ top_p,
32
+ ):
33
  """
34
+ チャットボットの応答を生成する関数
35
+ Gradio ChatInterfaceの標準形式に対応
36
  """
37
  try:
38
+ # システムメッセージと会話履歴を含むプロンプトを構築
39
+ conversation = ""
40
+ if system_message.strip():
41
+ conversation += f"システム: {system_message}\n"
 
 
 
 
 
42
 
43
  # 会話履歴を追加
44
+ for user_msg, bot_msg in history:
45
+ if user_msg:
46
+ conversation += f"ユーザー: {user_msg}\n"
47
+ if bot_msg:
48
+ conversation += f"アシスタント: {bot_msg}\n"
 
 
 
49
 
50
  # 現在のメッセージを追加
51
  conversation += f"ユーザー: {message}\nアシスタント: "
 
57
  if torch.cuda.is_available():
58
  inputs = inputs.cuda()
59
 
60
+ # 応答生成(ストリーミング対応)
61
+ response = ""
62
  with torch.no_grad():
63
+ # 一度に生成してからストリーミング風に出力
64
  outputs = model.generate(
65
  inputs,
66
+ max_new_tokens=max_tokens,
67
+ temperature=temperature,
68
+ top_p=top_p,
69
  do_sample=True,
70
  pad_token_id=tokenizer.eos_token_id,
71
  eos_token_id=tokenizer.eos_token_id,
 
91
  yield f"エラーが発生しました: {str(e)}"
92
 
93
  """
94
+ Gradio ChatInterfaceを使用したシンプルなチャットボット
95
+ カスタマイズ可能なパラメータを含む
96
  """
97
  demo = gr.ChatInterface(
98
  respond,
99
+ title="🤖 Sarashina Chatbot",
100
+ description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
101
+ additional_inputs=[
102
+ gr.Textbox(
103
+ value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。",
104
+ label="システムメッセージ",
105
+ lines=3
106
+ ),
107
+ gr.Slider(
108
+ minimum=1,
109
+ maximum=1024,
110
+ value=512,
111
+ step=1,
112
+ label="最大新規トークン数"
113
+ ),
114
+ gr.Slider(
115
+ minimum=0.1,
116
+ maximum=2.0,
117
+ value=0.7,
118
+ step=0.1,
119
+ label="Temperature (創造性)"
120
+ ),
121
+ gr.Slider(
122
+ minimum=0.1,
123
+ maximum=1.0,
124
+ value=0.95,
125
+ step=0.05,
126
+ label="Top-p (多様性制御)",
127
+ ),
128
+ ],
129
  theme=gr.themes.Soft(),
130
  examples=[
131
+ ["こんにちは!今日はどんなことを話しましょうか?"],
132
+ ["日本の文化について教えてください。"],
133
+ ["簡単なレシピを教えてもらえますか?"],
134
+ ["プログラミングについて質問があります。"],
135
  ],
136
  cache_examples=False,
137
  )
 
141
  server_name="0.0.0.0",
142
  server_port=7860,
143
  share=False,
144
+ show_api=True, # API documentation を表示
145
  debug=True
146
  )