oggata commited on
Commit
f95231e
·
verified ·
1 Parent(s): 691edc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -63
app.py CHANGED
@@ -5,8 +5,8 @@ import warnings
5
  warnings.filterwarnings("ignore")
6
 
7
  """
8
- Sarashinaモデルを使用したGradioチャットボット
9
- Hugging Face Transformersライブラリを使用してローカルでモデルを実行
10
  """
11
 
12
  # モデルとトークナイザーの初期化
@@ -22,30 +22,31 @@ model = AutoModelForCausalLM.from_pretrained(
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
- def respond(
26
- message,
27
- history,
28
- system_message,
29
- max_tokens,
30
- temperature,
31
- top_p,
32
- ):
33
  """
34
- チャットボットの応答を生成する関数
35
- 元のコードと同じシグネチャを維持
36
  """
37
  try:
 
 
 
 
 
 
 
38
  # 会話履歴を含むプロンプトを構築
39
- conversation = ""
40
- if system_message.strip():
41
- conversation += f"システム: {system_message}\n"
42
 
43
  # 会話履歴を追加
44
- for user_msg, bot_msg in history:
45
- if user_msg:
46
- conversation += f"ユーザー: {user_msg}\n"
47
- if bot_msg:
48
- conversation += f"アシスタント: {bot_msg}\n"
 
 
 
49
 
50
  # 現在のメッセージを追加
51
  conversation += f"ユーザー: {message}\nアシスタント: "
@@ -57,15 +58,13 @@ def respond(
57
  if torch.cuda.is_available():
58
  inputs = inputs.cuda()
59
 
60
- # 応答生成(ストリーミング対応)
61
- response = ""
62
  with torch.no_grad():
63
- # 一度に生成してからストリーミング風に出力
64
  outputs = model.generate(
65
  inputs,
66
- max_new_tokens=max_tokens,
67
- temperature=temperature,
68
- top_p=top_p,
69
  do_sample=True,
70
  pad_token_id=tokenizer.eos_token_id,
71
  eos_token_id=tokenizer.eos_token_id,
@@ -91,47 +90,17 @@ def respond(
91
  yield f"エラーが発生しました: {str(e)}"
92
 
93
  """
94
- Gradio ChatInterfaceを使用したシンプルなチャットボット
95
- 元のコードと同じadditional_inputsを維持
96
  """
97
  demo = gr.ChatInterface(
98
  respond,
99
- title="🤖 Sarashina Chatbot",
100
- description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
101
- additional_inputs=[
102
- gr.Textbox(
103
- value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。",
104
- label="システムメッセージ",
105
- lines=3
106
- ),
107
- gr.Slider(
108
- minimum=1,
109
- maximum=1024,
110
- value=512,
111
- step=1,
112
- label="最大新規トークン数"
113
- ),
114
- gr.Slider(
115
- minimum=0.1,
116
- maximum=2.0,
117
- value=0.7,
118
- step=0.1,
119
- label="Temperature (創造性)"
120
- ),
121
- gr.Slider(
122
- minimum=0.1,
123
- maximum=1.0,
124
- value=0.95,
125
- step=0.05,
126
- label="Top-p (多様性制御)",
127
- ),
128
- ],
129
  theme=gr.themes.Soft(),
130
  examples=[
131
- ["こんにちは!今日はどんなことを話しましょうか?"],
132
- ["日本の文化について教えてください。"],
133
- ["簡単なレシピを教えてもらえますか?"],
134
- ["プログラミングについて質問があります。"],
135
  ],
136
  cache_examples=False,
137
  )
@@ -141,6 +110,6 @@ if __name__ == "__main__":
141
  server_name="0.0.0.0",
142
  server_port=7860,
143
  share=False,
144
- show_api=True, # API documentation を表示
145
  debug=True
146
  )
 
5
  warnings.filterwarnings("ignore")
6
 
7
  """
8
+ デバッグ用のシンプルなSarashinaチャットボット
9
+ additional_inputsなしでテスト
10
  """
11
 
12
  # モデルとトークナイザーの初期化
 
22
  )
23
  print("モデルの読み込みが完了しました。")
24
 
25
+ def respond(message, history):
 
 
 
 
 
 
 
26
  """
27
+ シンプルなチャットボット応答関数
28
+ additional_inputsなし
29
  """
30
  try:
31
+ # デバッグ情報を出力
32
+ print(f"DEBUG - message: {message} (type: {type(message)})")
33
+ print(f"DEBUG - history: {history} (type: {type(history)})")
34
+
35
+ # システムメッセージ(固定)
36
+ system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
37
+
38
  # 会話履歴を含むプロンプトを構築
39
+ conversation = f"システム: {system_message}\n"
 
 
40
 
41
  # 会話履歴を追加
42
+ if history and isinstance(history, list):
43
+ for item in history:
44
+ if isinstance(item, (list, tuple)) and len(item) >= 2:
45
+ user_msg, bot_msg = item[0], item[1]
46
+ if user_msg:
47
+ conversation += f"ユーザー: {user_msg}\n"
48
+ if bot_msg:
49
+ conversation += f"アシスタント: {bot_msg}\n"
50
 
51
  # 現在のメッセージを追加
52
  conversation += f"ユーザー: {message}\nアシスタント: "
 
58
  if torch.cuda.is_available():
59
  inputs = inputs.cuda()
60
 
61
+ # 応答生成
 
62
  with torch.no_grad():
 
63
  outputs = model.generate(
64
  inputs,
65
+ max_new_tokens=512,
66
+ temperature=0.7,
67
+ top_p=0.95,
68
  do_sample=True,
69
  pad_token_id=tokenizer.eos_token_id,
70
  eos_token_id=tokenizer.eos_token_id,
 
90
  yield f"エラーが発生しました: {str(e)}"
91
 
92
  """
93
+ シンプルなChatInterface(additional_inputsなし)
 
94
  """
95
  demo = gr.ChatInterface(
96
  respond,
97
+ title="🤖 Sarashina Chatbot (Simple)",
98
+ description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。(デバッグ用)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  theme=gr.themes.Soft(),
100
  examples=[
101
+ "こんにちは!",
102
+ "日本について教えて",
103
+ "プログラミングの質問があります",
 
104
  ],
105
  cache_examples=False,
106
  )
 
110
  server_name="0.0.0.0",
111
  server_port=7860,
112
  share=False,
113
+ show_api=True,
114
  debug=True
115
  )