chatbot / app.py
doropiza's picture
c
754bc2f
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import spaces
import time
import json
from datetime import datetime
import logging
import gc
# ログ設定
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
class ChatBot:
def __init__(self):
# モデル選択オプション
self.available_models = {
"gemma-2b-it": "google/gemma-2b-it",
"gemma-7b-it": "google/gemma-7b-it",
"gemma-9b-it": "google/gemma-9b-it",
"gemma-12b-it":"google/gemma-3-12b-it",
"Llama-3-ELYZA-JP-8B":"elyza/Llama-3-ELYZA-JP-8B",
"phi-3-mini": "microsoft/Phi-3-mini-4k-instruct"
}
self.current_model_name = "Llama-3-ELYZA-JP-8B"
self.model_path = self.available_models[self.current_model_name]
# モデルとトークナイザーの初期化(GPU処理は後で実行)
self.tokenizer = None
self.model = None
self.model_loaded = False
# キャッシュ設定
self.response_cache = {}
# 会話管理
self.conversations = {}
# 設定
self.generation_config = {
"temperature": 0.8,
"top_p": 0.9,
"max_new_tokens": 150,
"repetition_penalty": 1.2
}
logger.info("ChatBot初期化完了(モデル読み込みはGPU関数内で実行)")
@spaces.GPU(duration=60)
def load_model(self):
"""モデルの読み込み(GPU環境内で実行)"""
try:
logger.info(f"モデル {self.model_path} を読み込み中...")
# 既存モデルのクリーンアップ(GPU環境内で安全に実行)
if hasattr(self, 'model') and self.model is not None:
del self.model
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
del self.tokenizer
# ガベージコレクション
gc.collect()
# CUDA利用可能性チェック(GPU環境内でのみ実行)
if torch.cuda.is_available():
torch.cuda.empty_cache()
device = torch.device("cuda")
logger.info("CUDA環境を確認済み - GPU使用")
# 量子化設定
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_storage=torch.uint8
)
# quantization_config = BitsAndBytesConfig(
# load_in_8bit=True,
# llm_int8_threshold=6.0,
# llm_int8_has_fp16_weight=False
# )
else:
device = torch.device("cpu")
quantization_config = None
logger.info("CPU環境 - 量子化無効")
# トークナイザー読み込み
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
token=HUGGINGFACE_TOKEN,
trust_remote_code=True
)
# パディングトークン設定
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# モデル読み込み設定
model_kwargs = {
"token": HUGGINGFACE_TOKEN,
"trust_remote_code": True,
"low_cpu_mem_usage": True,
}
if torch.cuda.is_available():
model_kwargs.update({
"quantization_config": quantization_config,
"device_map": "auto",
"torch_dtype": torch.float16
})
else:
model_kwargs["torch_dtype"] = torch.float32
# モデル読み込み
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path,
**model_kwargs
)
# CPU使用時の明示的なデバイス移動
if not torch.cuda.is_available():
self.model = self.model.to(device)
# モデルを評価モードに設定
self.model.eval()
self.model_loaded = True
logger.info(f"モデル {self.model_path} の読み込み完了")
# メモリ使用量ログ
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated() / 1024**3
logger.info(f"GPU メモリ使用量: {memory_allocated:.2f} GB")
return "モデル読み込み完了"
except Exception as e:
logger.error(f"モデル読み込みエラー: {e}")
self.model_loaded = False
self.tokenizer = None
self.model = None
return f"モデル読み込み失敗: {str(e)}"
@spaces.GPU(duration=60)
def switch_model(self, model_name):
"""モデルの切り替え(GPU環境内で実行)"""
if model_name in self.available_models and model_name != self.current_model_name:
self.current_model_name = model_name
self.model_path = self.available_models[model_name]
self.model_loaded = False
# キャッシュクリア
self.response_cache.clear()
result = self.load_model()
return f"モデルを {model_name} に切り替えました: {result}"
return f"モデル {model_name} は既に選択されているか、利用できません"
def get_cache_key(self, message, config):
"""キャッシュキーを生成"""
key_data = f"{message.strip().lower()}_{self.current_model_name}_{str(config)}"
return hash(key_data)
def get_cached_response(self, cache_key):
"""キャッシュから応答を取得"""
return self.response_cache.get(cache_key)
def set_cached_response(self, cache_key, response):
"""応答をキャッシュに保存"""
if len(self.response_cache) > 50:
oldest_key = next(iter(self.response_cache))
del self.response_cache[oldest_key]
self.response_cache[cache_key] = response
def create_prompt(self, message, conversation_history):
"""プロンプトの作成"""
if self.current_model_name.startswith("gemma"):
prompt = ""
for msg in conversation_history[-6:]:
if msg["role"] == "user":
prompt += f"<start_of_turn>user\n{msg['content']}<end_of_turn>\n"
else:
prompt += f"<start_of_turn>model\n{msg['content']}<end_of_turn>\n"
prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
elif self.current_model_name.startswith("phi"):
prompt = "<|system|>\nYou are a helpful AI assistant.<|end|>\n"
for msg in conversation_history[-6:]:
if msg["role"] == "user":
prompt += f"<|user|>\n{msg['content']}<|end|>\n"
else:
prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
prompt += f"<|user|>\n{message}<|end|>\n<|assistant|>\n"
else:
prompt = f"Human: {message}\nAssistant:"
return prompt
@spaces.GPU(duration=45)
def generate_response(self, message, conversation_history=None):
"""応答生成(GPU環境内で実行)"""
# モデルが読み込まれていない場合は読み込み
if not self.model_loaded:
load_result = self.load_model()
if not self.model_loaded:
return f"モデル読み込みに失敗しました: {load_result}"
if conversation_history is None:
conversation_history = []
try:
# プロンプト作成
prompt = self.create_prompt(message, conversation_history)
# トークン化
inputs = self.tokenizer.encode(
prompt,
return_tensors='pt',
max_length=1024,
truncation=True,
padding=True
)
# デバイスに移動(GPU環境内で安全に実行)
if torch.cuda.is_available():
inputs = inputs.to(self.model.device)
# 生成パラメータ
generation_kwargs = {
"inputs": inputs,
"max_new_tokens": self.generation_config["max_new_tokens"],
"temperature": self.generation_config["temperature"],
"top_p": self.generation_config["top_p"],
"do_sample": True,
"repetition_penalty": self.generation_config["repetition_penalty"],
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"use_cache": True,
"attention_mask": torch.ones_like(inputs)
}
# 生成実行
with torch.no_grad():
outputs = self.model.generate(**generation_kwargs)
# デコード
response = self.tokenizer.decode(
outputs[0][inputs.shape[1]:],
skip_special_tokens=True
).strip()
# 応答の後処理
if self.current_model_name.startswith("gemma"):
if "<end_of_turn>" in response:
response = response.split("<end_of_turn>")[0].strip()
elif self.current_model_name.startswith("phi"):
if "<|" in response:
response = response.split("<|")[0].strip()
return response if response else "申し訳ありませんが、適切な応答を生成できませんでした。"
except Exception as e:
logger.error(f"応答生成エラー: {e}")
return f"エラーが発生しました: {str(e)}"
def get_conversation(self, session_id="default"):
"""会話履歴を取得"""
if session_id not in self.conversations:
self.conversations[session_id] = []
return self.conversations[session_id]
def add_to_conversation(self, session_id, role, content):
"""会話履歴に追加"""
if session_id not in self.conversations:
self.conversations[session_id] = []
self.conversations[session_id].append({
"role": role,
"content": content,
"timestamp": datetime.now().isoformat()
})
# 履歴制限
if len(self.conversations[session_id]) > 50:
self.conversations[session_id] = self.conversations[session_id][-50:]
def clear_conversation(self, session_id="default"):
"""会話履歴をクリア"""
if session_id in self.conversations:
self.conversations[session_id] = []
def export_conversation(self, session_id="default"):
"""会話履歴をエクスポート"""
conversation = self.get_conversation(session_id)
return json.dumps(conversation, indent=2, ensure_ascii=False)
def update_generation_config(self, temperature, top_p, max_tokens, repetition_penalty):
"""生成設定を更新"""
self.generation_config.update({
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_tokens,
"repetition_penalty": repetition_penalty
})
self.response_cache.clear()
# グローバルChatBotインスタンス
chatbot = ChatBot()
def create_interface():
"""Gradioインターフェースの作成"""
with gr.Blocks(
title="ChatGPT Clone - Advanced AI Chat",
theme=gr.themes.Soft(),
css="""
.chatbot { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
.message { border-radius: 8px; margin: 4px 0; }
.user-message { background-color: #e3f2fd; }
.bot-message { background-color: #f1f8e9; }
.status-box { background-color: #f5f5f5; padding: 10px; border-radius: 5px; }
"""
) as demo:
# ヘッダー
gr.Markdown("""
# 🤖 ChatGPT Clone - Advanced AI Chat
**Powered by Hugging Face Models** | **GPU Accelerated** | **Multi-Model Support**
⚠️ **初回使用時**: 最初のメッセージ送信時にモデルが自動的に読み込まれます。少しお待ちください。
""")
with gr.Row():
# メインチャットエリア
with gr.Column(scale=3):
# ステータス表示
with gr.Row():
with gr.Column():
status_display = gr.Markdown("**ステータス**: 待機中(モデル未読み込み)", elem_classes=["status-box"])
with gr.Column():
model_display = gr.Markdown(f"**現在のモデル**: {chatbot.current_model_name}", elem_classes=["status-box"])
# チャットボット
chatbot_display = gr.Chatbot(
label="AI Chat",
height=500,
show_label=False,
type="messages",
avatar_images=("👤", "🤖")
)
# 入力エリア
with gr.Row():
msg_input = gr.Textbox(
placeholder="メッセージを入力してください(初回は読み込み時間がかかります)...",
scale=4,
show_label=False,
lines=2
)
send_button = gr.Button("📤 送信", scale=1, variant="primary")
# コントロールボタン
with gr.Row():
clear_button = gr.Button("🗑️ クリア", size="sm")
export_button = gr.Button("📁 エクスポート", size="sm")
reload_model_button = gr.Button("🔄 モデル読込", size="sm")
# サイドバー
with gr.Column(scale=1):
gr.Markdown("### ⚙️ 設定")
# モデル選択
model_selector = gr.Dropdown(
choices=list(chatbot.available_models.keys()),
value=chatbot.current_model_name,
label="🔄 モデル選択",
interactive=True
)
# 生成パラメータ
gr.Markdown("#### 生成パラメータ")
temperature_slider = gr.Slider(
minimum=0.1, maximum=2.0, value=0.8, step=0.1,
label="🌡️ Temperature", info="創造性を調整"
)
top_p_slider = gr.Slider(
minimum=0.1, maximum=1.0, value=0.9, step=0.05,
label="🎯 Top-p", info="多様性を調整"
)
max_tokens_slider = gr.Slider(
minimum=50, maximum=500, value=150, step=25,
label="📝 最大トークン数"
)
repetition_penalty_slider = gr.Slider(
minimum=1.0, maximum=2.0, value=1.2, step=0.1,
label="🔄 繰り返し抑制"
)
# 統計情報
gr.Markdown("#### 📊 統計")
stats_display = gr.Markdown("**総会話数**: 0\n**キャッシュ数**: 0")
# 出力用テキストボックス
export_output = gr.Textbox(label="エクスポートされた会話", visible=False)
# イベント処理関数
def send_message(message, history):
if not message.strip():
return history, "", "**ステータス**: 空のメッセージです", get_stats()
start_time = time.time()
session_id = "default"
try:
# キャッシュチェック
cache_key = chatbot.get_cache_key(message, chatbot.generation_config)
cached_response = chatbot.get_cached_response(cache_key)
if cached_response:
response = cached_response
status = "**ステータス**: ✅ キャッシュから取得"
else:
# 新しい応答を生成
conversation_history = chatbot.get_conversation(session_id)
# 初回読み込みの場合の状態表示
if not chatbot.model_loaded:
status = "**ステータス**: 🔄 モデル読み込み中..."
else:
status = "**ステータス**: 🔄 応答生成中..."
response = chatbot.generate_response(message, conversation_history)
# エラーチェック
if response.startswith("エラーが発生しました") or response.startswith("モデル読み込みに失敗"):
return history, message, f"**ステータス**: ❌ {response}", get_stats()
chatbot.set_cached_response(cache_key, response)
status = "**ステータス**: ✅ 新規生成完了"
# 会話履歴に追加
chatbot.add_to_conversation(session_id, "user", message)
chatbot.add_to_conversation(session_id, "assistant", response)
# 履歴更新
new_history = history + [{"role": "user", "content": message}]
new_history = new_history + [{"role": "assistant", "content": response}]
end_time = time.time()
process_time = end_time - start_time
status += f" ({process_time:.2f}秒)"
return new_history, "", status, get_stats()
except Exception as e:
logger.error(f"メッセージ送信エラー: {e}")
return history, message, f"**ステータス**: ❌ エラー: {str(e)}", get_stats()
def clear_chat():
session_id = "default"
chatbot.clear_conversation(session_id)
return [], "**ステータス**: 🗑️ チャットをクリアしました", get_stats()
def change_model(model_name):
result = chatbot.switch_model(model_name)
return f"**現在のモデル**: {chatbot.current_model_name}", f"**ステータス**: {result}"
def reload_model():
result = chatbot.load_model()
status = f"✅ {result}" if chatbot.model_loaded else f"❌ {result}"
return f"**ステータス**: {status}"
def update_params(temp, top_p, max_tokens, rep_penalty):
chatbot.update_generation_config(temp, top_p, max_tokens, rep_penalty)
return "**ステータス**: ⚙️ パラメータを更新しました(キャッシュクリア済み)"
def export_chat():
session_id = "default"
conversation_json = chatbot.export_conversation(session_id)
return conversation_json, "**ステータス**: 📁 会話をエクスポートしました"
def get_stats():
total_conversations = sum(len(conv) for conv in chatbot.conversations.values())
cache_count = len(chatbot.response_cache)
model_status = "読み込み済み" if chatbot.model_loaded else "未読み込み"
return f"**総会話数**: {total_conversations}\n**キャッシュ数**: {cache_count}\n**モデル**: {model_status}"
# イベントバインディング
send_button.click(
send_message,
inputs=[msg_input, chatbot_display],
outputs=[chatbot_display, msg_input, status_display, stats_display]
)
msg_input.submit(
send_message,
inputs=[msg_input, chatbot_display],
outputs=[chatbot_display, msg_input, status_display, stats_display]
)
clear_button.click(
clear_chat,
outputs=[chatbot_display, status_display, stats_display]
)
model_selector.change(
change_model,
inputs=[model_selector],
outputs=[model_display, status_display]
)
reload_model_button.click(
reload_model,
outputs=[status_display]
)
export_button.click(
export_chat,
outputs=[export_output, status_display]
)
# パラメータ更新
for slider in [temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider]:
slider.change(
update_params,
inputs=[temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider],
outputs=[status_display]
)
return demo
# アプリケーション起動
if __name__ == "__main__":
demo = create_interface()
# Hugging Face Spaces用
demo.launch(
share=False,
show_error=True,
server_name="0.0.0.0",
server_port=7860
)