import os import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig import torch import spaces import time import json from datetime import datetime import logging import gc # ログ設定 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") class ChatBot: def __init__(self): # モデル選択オプション self.available_models = { "gemma-2b-it": "google/gemma-2b-it", "gemma-7b-it": "google/gemma-7b-it", "gemma-9b-it": "google/gemma-9b-it", "gemma-12b-it":"google/gemma-3-12b-it", "Llama-3-ELYZA-JP-8B":"elyza/Llama-3-ELYZA-JP-8B", "phi-3-mini": "microsoft/Phi-3-mini-4k-instruct" } self.current_model_name = "Llama-3-ELYZA-JP-8B" self.model_path = self.available_models[self.current_model_name] # モデルとトークナイザーの初期化(GPU処理は後で実行) self.tokenizer = None self.model = None self.model_loaded = False # キャッシュ設定 self.response_cache = {} # 会話管理 self.conversations = {} # 設定 self.generation_config = { "temperature": 0.8, "top_p": 0.9, "max_new_tokens": 150, "repetition_penalty": 1.2 } logger.info("ChatBot初期化完了(モデル読み込みはGPU関数内で実行)") @spaces.GPU(duration=60) def load_model(self): """モデルの読み込み(GPU環境内で実行)""" try: logger.info(f"モデル {self.model_path} を読み込み中...") # 既存モデルのクリーンアップ(GPU環境内で安全に実行) if hasattr(self, 'model') and self.model is not None: del self.model if hasattr(self, 'tokenizer') and self.tokenizer is not None: del self.tokenizer # ガベージコレクション gc.collect() # CUDA利用可能性チェック(GPU環境内でのみ実行) if torch.cuda.is_available(): torch.cuda.empty_cache() device = torch.device("cuda") logger.info("CUDA環境を確認済み - GPU使用") # 量子化設定 quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_storage=torch.uint8 ) # quantization_config = BitsAndBytesConfig( # load_in_8bit=True, # llm_int8_threshold=6.0, # llm_int8_has_fp16_weight=False # ) else: device = torch.device("cpu") quantization_config = None logger.info("CPU環境 - 量子化無効") # トークナイザー読み込み self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, token=HUGGINGFACE_TOKEN, trust_remote_code=True ) # パディングトークン設定 if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # モデル読み込み設定 model_kwargs = { "token": HUGGINGFACE_TOKEN, "trust_remote_code": True, "low_cpu_mem_usage": True, } if torch.cuda.is_available(): model_kwargs.update({ "quantization_config": quantization_config, "device_map": "auto", "torch_dtype": torch.float16 }) else: model_kwargs["torch_dtype"] = torch.float32 # モデル読み込み self.model = AutoModelForCausalLM.from_pretrained( self.model_path, **model_kwargs ) # CPU使用時の明示的なデバイス移動 if not torch.cuda.is_available(): self.model = self.model.to(device) # モデルを評価モードに設定 self.model.eval() self.model_loaded = True logger.info(f"モデル {self.model_path} の読み込み完了") # メモリ使用量ログ if torch.cuda.is_available(): memory_allocated = torch.cuda.memory_allocated() / 1024**3 logger.info(f"GPU メモリ使用量: {memory_allocated:.2f} GB") return "モデル読み込み完了" except Exception as e: logger.error(f"モデル読み込みエラー: {e}") self.model_loaded = False self.tokenizer = None self.model = None return f"モデル読み込み失敗: {str(e)}" @spaces.GPU(duration=60) def switch_model(self, model_name): """モデルの切り替え(GPU環境内で実行)""" if model_name in self.available_models and model_name != self.current_model_name: self.current_model_name = model_name self.model_path = self.available_models[model_name] self.model_loaded = False # キャッシュクリア self.response_cache.clear() result = self.load_model() return f"モデルを {model_name} に切り替えました: {result}" return f"モデル {model_name} は既に選択されているか、利用できません" def get_cache_key(self, message, config): """キャッシュキーを生成""" key_data = f"{message.strip().lower()}_{self.current_model_name}_{str(config)}" return hash(key_data) def get_cached_response(self, cache_key): """キャッシュから応答を取得""" return self.response_cache.get(cache_key) def set_cached_response(self, cache_key, response): """応答をキャッシュに保存""" if len(self.response_cache) > 50: oldest_key = next(iter(self.response_cache)) del self.response_cache[oldest_key] self.response_cache[cache_key] = response def create_prompt(self, message, conversation_history): """プロンプトの作成""" if self.current_model_name.startswith("gemma"): prompt = "" for msg in conversation_history[-6:]: if msg["role"] == "user": prompt += f"user\n{msg['content']}\n" else: prompt += f"model\n{msg['content']}\n" prompt += f"user\n{message}\nmodel\n" elif self.current_model_name.startswith("phi"): prompt = "<|system|>\nYou are a helpful AI assistant.<|end|>\n" for msg in conversation_history[-6:]: if msg["role"] == "user": prompt += f"<|user|>\n{msg['content']}<|end|>\n" else: prompt += f"<|assistant|>\n{msg['content']}<|end|>\n" prompt += f"<|user|>\n{message}<|end|>\n<|assistant|>\n" else: prompt = f"Human: {message}\nAssistant:" return prompt @spaces.GPU(duration=45) def generate_response(self, message, conversation_history=None): """応答生成(GPU環境内で実行)""" # モデルが読み込まれていない場合は読み込み if not self.model_loaded: load_result = self.load_model() if not self.model_loaded: return f"モデル読み込みに失敗しました: {load_result}" if conversation_history is None: conversation_history = [] try: # プロンプト作成 prompt = self.create_prompt(message, conversation_history) # トークン化 inputs = self.tokenizer.encode( prompt, return_tensors='pt', max_length=1024, truncation=True, padding=True ) # デバイスに移動(GPU環境内で安全に実行) if torch.cuda.is_available(): inputs = inputs.to(self.model.device) # 生成パラメータ generation_kwargs = { "inputs": inputs, "max_new_tokens": self.generation_config["max_new_tokens"], "temperature": self.generation_config["temperature"], "top_p": self.generation_config["top_p"], "do_sample": True, "repetition_penalty": self.generation_config["repetition_penalty"], "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, "use_cache": True, "attention_mask": torch.ones_like(inputs) } # 生成実行 with torch.no_grad(): outputs = self.model.generate(**generation_kwargs) # デコード response = self.tokenizer.decode( outputs[0][inputs.shape[1]:], skip_special_tokens=True ).strip() # 応答の後処理 if self.current_model_name.startswith("gemma"): if "" in response: response = response.split("")[0].strip() elif self.current_model_name.startswith("phi"): if "<|" in response: response = response.split("<|")[0].strip() return response if response else "申し訳ありませんが、適切な応答を生成できませんでした。" except Exception as e: logger.error(f"応答生成エラー: {e}") return f"エラーが発生しました: {str(e)}" def get_conversation(self, session_id="default"): """会話履歴を取得""" if session_id not in self.conversations: self.conversations[session_id] = [] return self.conversations[session_id] def add_to_conversation(self, session_id, role, content): """会話履歴に追加""" if session_id not in self.conversations: self.conversations[session_id] = [] self.conversations[session_id].append({ "role": role, "content": content, "timestamp": datetime.now().isoformat() }) # 履歴制限 if len(self.conversations[session_id]) > 50: self.conversations[session_id] = self.conversations[session_id][-50:] def clear_conversation(self, session_id="default"): """会話履歴をクリア""" if session_id in self.conversations: self.conversations[session_id] = [] def export_conversation(self, session_id="default"): """会話履歴をエクスポート""" conversation = self.get_conversation(session_id) return json.dumps(conversation, indent=2, ensure_ascii=False) def update_generation_config(self, temperature, top_p, max_tokens, repetition_penalty): """生成設定を更新""" self.generation_config.update({ "temperature": temperature, "top_p": top_p, "max_new_tokens": max_tokens, "repetition_penalty": repetition_penalty }) self.response_cache.clear() # グローバルChatBotインスタンス chatbot = ChatBot() def create_interface(): """Gradioインターフェースの作成""" with gr.Blocks( title="ChatGPT Clone - Advanced AI Chat", theme=gr.themes.Soft(), css=""" .chatbot { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } .message { border-radius: 8px; margin: 4px 0; } .user-message { background-color: #e3f2fd; } .bot-message { background-color: #f1f8e9; } .status-box { background-color: #f5f5f5; padding: 10px; border-radius: 5px; } """ ) as demo: # ヘッダー gr.Markdown(""" # 🤖 ChatGPT Clone - Advanced AI Chat **Powered by Hugging Face Models** | **GPU Accelerated** | **Multi-Model Support** ⚠️ **初回使用時**: 最初のメッセージ送信時にモデルが自動的に読み込まれます。少しお待ちください。 """) with gr.Row(): # メインチャットエリア with gr.Column(scale=3): # ステータス表示 with gr.Row(): with gr.Column(): status_display = gr.Markdown("**ステータス**: 待機中(モデル未読み込み)", elem_classes=["status-box"]) with gr.Column(): model_display = gr.Markdown(f"**現在のモデル**: {chatbot.current_model_name}", elem_classes=["status-box"]) # チャットボット chatbot_display = gr.Chatbot( label="AI Chat", height=500, show_label=False, type="messages", avatar_images=("👤", "🤖") ) # 入力エリア with gr.Row(): msg_input = gr.Textbox( placeholder="メッセージを入力してください(初回は読み込み時間がかかります)...", scale=4, show_label=False, lines=2 ) send_button = gr.Button("📤 送信", scale=1, variant="primary") # コントロールボタン with gr.Row(): clear_button = gr.Button("🗑️ クリア", size="sm") export_button = gr.Button("📁 エクスポート", size="sm") reload_model_button = gr.Button("🔄 モデル読込", size="sm") # サイドバー with gr.Column(scale=1): gr.Markdown("### ⚙️ 設定") # モデル選択 model_selector = gr.Dropdown( choices=list(chatbot.available_models.keys()), value=chatbot.current_model_name, label="🔄 モデル選択", interactive=True ) # 生成パラメータ gr.Markdown("#### 生成パラメータ") temperature_slider = gr.Slider( minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="🌡️ Temperature", info="創造性を調整" ) top_p_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="🎯 Top-p", info="多様性を調整" ) max_tokens_slider = gr.Slider( minimum=50, maximum=500, value=150, step=25, label="📝 最大トークン数" ) repetition_penalty_slider = gr.Slider( minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="🔄 繰り返し抑制" ) # 統計情報 gr.Markdown("#### 📊 統計") stats_display = gr.Markdown("**総会話数**: 0\n**キャッシュ数**: 0") # 出力用テキストボックス export_output = gr.Textbox(label="エクスポートされた会話", visible=False) # イベント処理関数 def send_message(message, history): if not message.strip(): return history, "", "**ステータス**: 空のメッセージです", get_stats() start_time = time.time() session_id = "default" try: # キャッシュチェック cache_key = chatbot.get_cache_key(message, chatbot.generation_config) cached_response = chatbot.get_cached_response(cache_key) if cached_response: response = cached_response status = "**ステータス**: ✅ キャッシュから取得" else: # 新しい応答を生成 conversation_history = chatbot.get_conversation(session_id) # 初回読み込みの場合の状態表示 if not chatbot.model_loaded: status = "**ステータス**: 🔄 モデル読み込み中..." else: status = "**ステータス**: 🔄 応答生成中..." response = chatbot.generate_response(message, conversation_history) # エラーチェック if response.startswith("エラーが発生しました") or response.startswith("モデル読み込みに失敗"): return history, message, f"**ステータス**: ❌ {response}", get_stats() chatbot.set_cached_response(cache_key, response) status = "**ステータス**: ✅ 新規生成完了" # 会話履歴に追加 chatbot.add_to_conversation(session_id, "user", message) chatbot.add_to_conversation(session_id, "assistant", response) # 履歴更新 new_history = history + [{"role": "user", "content": message}] new_history = new_history + [{"role": "assistant", "content": response}] end_time = time.time() process_time = end_time - start_time status += f" ({process_time:.2f}秒)" return new_history, "", status, get_stats() except Exception as e: logger.error(f"メッセージ送信エラー: {e}") return history, message, f"**ステータス**: ❌ エラー: {str(e)}", get_stats() def clear_chat(): session_id = "default" chatbot.clear_conversation(session_id) return [], "**ステータス**: 🗑️ チャットをクリアしました", get_stats() def change_model(model_name): result = chatbot.switch_model(model_name) return f"**現在のモデル**: {chatbot.current_model_name}", f"**ステータス**: {result}" def reload_model(): result = chatbot.load_model() status = f"✅ {result}" if chatbot.model_loaded else f"❌ {result}" return f"**ステータス**: {status}" def update_params(temp, top_p, max_tokens, rep_penalty): chatbot.update_generation_config(temp, top_p, max_tokens, rep_penalty) return "**ステータス**: ⚙️ パラメータを更新しました(キャッシュクリア済み)" def export_chat(): session_id = "default" conversation_json = chatbot.export_conversation(session_id) return conversation_json, "**ステータス**: 📁 会話をエクスポートしました" def get_stats(): total_conversations = sum(len(conv) for conv in chatbot.conversations.values()) cache_count = len(chatbot.response_cache) model_status = "読み込み済み" if chatbot.model_loaded else "未読み込み" return f"**総会話数**: {total_conversations}\n**キャッシュ数**: {cache_count}\n**モデル**: {model_status}" # イベントバインディング send_button.click( send_message, inputs=[msg_input, chatbot_display], outputs=[chatbot_display, msg_input, status_display, stats_display] ) msg_input.submit( send_message, inputs=[msg_input, chatbot_display], outputs=[chatbot_display, msg_input, status_display, stats_display] ) clear_button.click( clear_chat, outputs=[chatbot_display, status_display, stats_display] ) model_selector.change( change_model, inputs=[model_selector], outputs=[model_display, status_display] ) reload_model_button.click( reload_model, outputs=[status_display] ) export_button.click( export_chat, outputs=[export_output, status_display] ) # パラメータ更新 for slider in [temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider]: slider.change( update_params, inputs=[temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider], outputs=[status_display] ) return demo # アプリケーション起動 if __name__ == "__main__": demo = create_interface() # Hugging Face Spaces用 demo.launch( share=False, show_error=True, server_name="0.0.0.0", server_port=7860 )