doropiza commited on
Commit
75bd3ec
·
1 Parent(s): e9972e5
Files changed (2) hide show
  1. app.py +118 -44
  2. requirements.txt +2 -1
app.py CHANGED
@@ -95,51 +95,125 @@
95
  # server_port=7860
96
  # )
97
 
98
-
99
- import os, torch, gradio as gr
100
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
101
 
102
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
103
- MODEL_NAME = "google/gemma-7b-it"
104
-
105
- model, tokenizer = None, None # ← グローバルで空のまま
106
-
107
- def load_model():
108
- """初回リクエスト時にのみ GPU を要求してモデルをロード"""
109
- global model, tokenizer
110
- if model is not None:
111
- return
112
- if not torch.cuda.is_available():
113
- # ZeroGPU ならここで一度 False → 数秒待って再度 True になることもある
114
- raise RuntimeError("GPU still not attached (ZeroGPU)。数秒後に再試行してください。")
115
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)
116
- model = AutoModelForCausalLM.from_pretrained(
117
- MODEL_NAME,
118
- device_map="auto",
119
- torch_dtype=torch.float16,
120
- token=HUGGINGFACE_TOKEN
121
- )
122
-
123
- def respond(message, history):
124
- load_model() # ← ここで初めて GPU を確保・モデルロード
125
- inputs = tokenizer(message, return_tensors="pt").to(model.device)
126
- with torch.no_grad():
127
- out = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
128
- return tokenizer.decode(out[0], skip_special_tokens=True)
129
-
130
- iface = gr.ChatInterface(
131
- fn=respond,
132
- title="Gemma-ZeroGPU Demo",
133
- chatbot=gr.Chatbot(
134
- type="messages",
135
- height=600,
136
- show_copy_button=True,
137
- show_share_button=True
138
- )
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
 
141
  if __name__ == "__main__":
142
- iface.launch(
143
- server_name="0.0.0.0",
144
- server_port=7860
145
- )
 
 
 
 
95
  # server_port=7860
96
  # )
97
 
98
+ import os
99
+ import gradio as gr
100
+ from transformers import AutoTokenizer, AutoModelForCausalLM
101
+ import torch
102
 
103
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
104
+ class ChatBot:
105
+ def __init__(self):
106
+ # 軽量なローカルLLMを使用(日本語対応)
107
+ model_name = "google/gemma-7b-it"
108
+ # 日本語対応の場合は "rinna/japanese-gpt2-medium" に変更可能
109
+
110
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
111
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
112
+
113
+ # パディングトークンを設定
114
+ if self.tokenizer.pad_token is None:
115
+ self.tokenizer.pad_token = self.tokenizer.eos_token
116
+
117
+ self.chat_history = []
118
+
119
+ def generate_response(self, message):
120
+ try:
121
+ # 入力をトークン化
122
+ inputs = self.tokenizer.encode(message + self.tokenizer.eos_token, return_tensors='pt')
123
+
124
+ # レスポンス生成
125
+ with torch.no_grad():
126
+ outputs = self.model.generate(
127
+ inputs,
128
+ max_length=inputs.shape[1] + 100,
129
+ num_return_sequences=1,
130
+ temperature=0.7,
131
+ do_sample=True,
132
+ pad_token_id=self.tokenizer.pad_token_id,
133
+ eos_token_id=self.tokenizer.eos_token_id
134
+ )
135
+
136
+ # レスポンスをデコード
137
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
138
+ return response.strip()
139
+
140
+ except Exception as e:
141
+ return f"エラーが発生しました: {str(e)}"
142
+
143
+ def chat_interface(self, message, history):
144
+ if not message.strip():
145
+ return history, ""
146
+
147
+ # レスポンス生成
148
+ bot_response = self.generate_response(message)
149
+
150
+ # 会話履歴を更新
151
+ history.append([message, bot_response])
152
+
153
+ return history, ""
154
+
155
+ # ChatBotインスタンス作成
156
+ chatbot = ChatBot()
157
+
158
+ # Gradioインターフェース設定
159
+ def create_interface():
160
+ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as demo:
161
+ gr.Markdown("# 🤖 ChatGPT Clone")
162
+ gr.Markdown("ローカルLLMを使用したチャットボットです")
163
+
164
+ # チャット履歴表示
165
+ chatbot_display = gr.Chatbot(
166
+ label="チャット",
167
+ height=400,
168
+ show_label=True
169
+ )
170
+
171
+ # 入力欄とボタン
172
+ with gr.Row():
173
+ msg_input = gr.Textbox(
174
+ placeholder="メッセージを入力してください...",
175
+ scale=4,
176
+ show_label=False
177
+ )
178
+ send_button = gr.Button("送信", scale=1)
179
+ clear_button = gr.Button("クリア", scale=1)
180
+
181
+ # イベント処理
182
+ def send_message(message, history):
183
+ return chatbot.chat_interface(message, history)
184
+
185
+ def clear_chat():
186
+ chatbot.chat_history = []
187
+ return []
188
+
189
+ # ボタンクリック時の処理
190
+ send_button.click(
191
+ send_message,
192
+ inputs=[msg_input, chatbot_display],
193
+ outputs=[chatbot_display, msg_input]
194
+ )
195
+
196
+ # Enterキーでも送信
197
+ msg_input.submit(
198
+ send_message,
199
+ inputs=[msg_input, chatbot_display],
200
+ outputs=[chatbot_display, msg_input]
201
+ )
202
+
203
+ # クリアボタン
204
+ clear_button.click(
205
+ clear_chat,
206
+ outputs=[chatbot_display]
207
+ )
208
+
209
+ return demo
210
 
211
+ # アプリケーション起動
212
  if __name__ == "__main__":
213
+ demo = create_interface()
214
+
215
+ # ローカル開発用
216
+ # demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
217
+
218
+ # Hugging Face Spaces用
219
+ demo.launch(share=True)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers>=4.30.0
4
  torch>=2.0.0
5
  accelerate>=0.20.0
6
  sentencepiece>=0.1.99
7
- google-generativeai>=0.3.0
 
 
4
  torch>=2.0.0
5
  accelerate>=0.20.0
6
  sentencepiece>=0.1.99
7
+ google-generativeai>=0.3.0
8
+ tokenizers>=0.13.0