import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import spaces
import time
import json
from datetime import datetime
import logging
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

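# ChatBot wraps a Hugging Face causal LM with lazy loading, a small in-memory
# response cache, and per-session conversation history.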
class ChatBot:
    def __init__(self):
        # Display name -> Hugging Face repository ID
        self.available_models = {
            "gemma-2b-it": "google/gemma-2b-it",
            "gemma-7b-it": "google/gemma-7b-it",
            "gemma-9b-it": "google/gemma-2-9b-it",  # the 9B instruct Gemma is published under the Gemma 2 repo
            "gemma-12b-it": "google/gemma-3-12b-it",
            "Llama-3-ELYZA-JP-8B": "elyza/Llama-3-ELYZA-JP-8B",
            "phi-3-mini": "microsoft/Phi-3-mini-4k-instruct"
        }

        self.current_model_name = "Llama-3-ELYZA-JP-8B"
        self.model_path = self.available_models[self.current_model_name]

        # Model and tokenizer are loaded lazily inside the GPU context
        self.tokenizer = None
        self.model = None
        self.model_loaded = False

        # In-memory cache of generated responses
        self.response_cache = {}

        # Conversation histories keyed by session ID
        self.conversations = {}

        # Default generation parameters (adjustable from the UI)
        self.generation_config = {
            "temperature": 0.8,
            "top_p": 0.9,
            "max_new_tokens": 150,
            "repetition_penalty": 1.2
        }

        logger.info("ChatBot初期化完了(モデル読み込みはGPU関数内で実行)")

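    # On ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated
    # function runs, so model loading and text generation both happen inside
    # decorated methods rather than at import time.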
    @spaces.GPU(duration=60)
    def load_model(self):
        """Load the model (runs inside the GPU context)."""
        try:
            logger.info(f"モデル {self.model_path} を読み込み中...")

            # Release any previously loaded model/tokenizer before reloading
            if hasattr(self, 'model') and self.model is not None:
                del self.model
            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                del self.tokenizer
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                device = torch.device("cuda")
                logger.info("CUDA環境を確認済み - GPU使用")

                # 4-bit NF4 quantization so larger models fit into GPU memory
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_quant_storage=torch.uint8
                )
            else:
                device = torch.device("cpu")
                quantization_config = None
                logger.info("CPU環境 - 量子化無効")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                token=HUGGINGFACE_TOKEN,
                trust_remote_code=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            model_kwargs = {
                "token": HUGGINGFACE_TOKEN,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
            }
            if torch.cuda.is_available():
                model_kwargs.update({
                    "quantization_config": quantization_config,
                    "device_map": "auto",
                    "torch_dtype": torch.float16
                })
            else:
                model_kwargs["torch_dtype"] = torch.float32

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                **model_kwargs
            )

            # With device_map="auto" the model is already placed; only move on CPU
            if not torch.cuda.is_available():
                self.model = self.model.to(device)
            self.model.eval()

            self.model_loaded = True
            logger.info(f"モデル {self.model_path} の読み込み完了")

            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / 1024**3
                logger.info(f"GPU メモリ使用量: {memory_allocated:.2f} GB")

            return "モデル読み込み完了"

        except Exception as e:
            logger.error(f"モデル読み込みエラー: {e}")
            self.model_loaded = False
            self.tokenizer = None
            self.model = None
            return f"モデル読み込み失敗: {str(e)}"

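    # Switching models invalidates the response cache, since cached replies are
    # specific to the previously loaded model.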
    @spaces.GPU(duration=60)
    def switch_model(self, model_name):
        """Switch to another model (runs inside the GPU context)."""
        if model_name in self.available_models and model_name != self.current_model_name:
            self.current_model_name = model_name
            self.model_path = self.available_models[model_name]
            self.model_loaded = False

            self.response_cache.clear()

            result = self.load_model()
            return f"モデルを {model_name} に切り替えました: {result}"
        return f"モデル {model_name} は既に選択されているか、利用できません"

    def get_cache_key(self, message, config):
        """Build a cache key from the message, current model, and generation settings."""
        key_data = f"{message.strip().lower()}_{self.current_model_name}_{str(config)}"
        return hash(key_data)

    def get_cached_response(self, cache_key):
        """Return a cached response, or None if not present."""
        return self.response_cache.get(cache_key)

    def set_cached_response(self, cache_key, response):
        """Store a response in the cache, evicting the oldest entry beyond 50 items."""
        if len(self.response_cache) > 50:
            oldest_key = next(iter(self.response_cache))
            del self.response_cache[oldest_key]
        self.response_cache[cache_key] = response

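    # Prompt construction: Gemma and Phi-3 use their own turn markers; other
    # models (e.g. Llama-3-ELYZA-JP-8B) fall back to a plain Human/Assistant format.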
    def create_prompt(self, message, conversation_history):
        """Build the model-specific prompt from the recent conversation history."""
        if self.current_model_name.startswith("gemma"):
            prompt = ""
            for msg in conversation_history[-6:]:
                if msg["role"] == "user":
                    prompt += f"<start_of_turn>user\n{msg['content']}<end_of_turn>\n"
                else:
                    prompt += f"<start_of_turn>model\n{msg['content']}<end_of_turn>\n"
            prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

        elif self.current_model_name.startswith("phi"):
            prompt = "<|system|>\nYou are a helpful AI assistant.<|end|>\n"
            for msg in conversation_history[-6:]:
                if msg["role"] == "user":
                    prompt += f"<|user|>\n{msg['content']}<|end|>\n"
                else:
                    prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
            prompt += f"<|user|>\n{message}<|end|>\n<|assistant|>\n"
        else:
            # Fallback plain-text format; include recent history as the other branches do
            prompt = ""
            for msg in conversation_history[-6:]:
                role = "Human" if msg["role"] == "user" else "Assistant"
                prompt += f"{role}: {msg['content']}\n"
            prompt += f"Human: {message}\nAssistant:"

        return prompt

    @spaces.GPU(duration=45)
    def generate_response(self, message, conversation_history=None):
        """Generate a response (runs inside the GPU context)."""
        if not self.model_loaded:
            load_result = self.load_model()
            if not self.model_loaded:
                return f"モデル読み込みに失敗しました: {load_result}"

        if conversation_history is None:
            conversation_history = []

        try:
            prompt = self.create_prompt(message, conversation_history)

            inputs = self.tokenizer.encode(
                prompt,
                return_tensors='pt',
                max_length=1024,
                truncation=True,
                padding=True
            )
            if torch.cuda.is_available():
                inputs = inputs.to(self.model.device)

            generation_kwargs = {
                "inputs": inputs,
                "max_new_tokens": self.generation_config["max_new_tokens"],
                "temperature": self.generation_config["temperature"],
                "top_p": self.generation_config["top_p"],
                "do_sample": True,
                "repetition_penalty": self.generation_config["repetition_penalty"],
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "use_cache": True,
                "attention_mask": torch.ones_like(inputs)
            }

            with torch.no_grad():
                outputs = self.model.generate(**generation_kwargs)

            # Decode only the newly generated tokens (skip the prompt)
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            ).strip()

            # Trim any leftover turn markers the model may have emitted
            if self.current_model_name.startswith("gemma"):
                if "<end_of_turn>" in response:
                    response = response.split("<end_of_turn>")[0].strip()
            elif self.current_model_name.startswith("phi"):
                if "<|" in response:
                    response = response.split("<|")[0].strip()
            elif "Human:" in response:
                # Fallback format: stop before the model starts a new "Human:" turn
                response = response.split("Human:")[0].strip()

            return response if response else "申し訳ありませんが、適切な応答を生成できませんでした。"

        except Exception as e:
            logger.error(f"応答生成エラー: {e}")
            return f"エラーが発生しました: {str(e)}"

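    # Conversation histories live only in process memory; they are not persisted
    # across restarts.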
    def get_conversation(self, session_id="default"):
        """Return the conversation history for a session, creating it if needed."""
        if session_id not in self.conversations:
            self.conversations[session_id] = []
        return self.conversations[session_id]

    def add_to_conversation(self, session_id, role, content):
        """Append a message to the conversation history."""
        if session_id not in self.conversations:
            self.conversations[session_id] = []

        self.conversations[session_id].append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        })

        # Keep only the most recent 50 messages per session
        if len(self.conversations[session_id]) > 50:
            self.conversations[session_id] = self.conversations[session_id][-50:]

    def clear_conversation(self, session_id="default"):
        """Clear the conversation history for a session."""
        if session_id in self.conversations:
            self.conversations[session_id] = []

    def export_conversation(self, session_id="default"):
        """Export the conversation history as JSON."""
        conversation = self.get_conversation(session_id)
        return json.dumps(conversation, indent=2, ensure_ascii=False)

    def update_generation_config(self, temperature, top_p, max_tokens, repetition_penalty):
        """Update the generation parameters and invalidate the response cache."""
        self.generation_config.update({
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
            "repetition_penalty": repetition_penalty
        })
        self.response_cache.clear()

# Single shared ChatBot instance used by the Gradio callbacks
chatbot = ChatBot()

def create_interface():
    """Build the Gradio interface."""
    with gr.Blocks(
        title="ChatGPT Clone - Advanced AI Chat",
        theme=gr.themes.Soft(),
        css="""
        .chatbot { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
        .message { border-radius: 8px; margin: 4px 0; }
        .user-message { background-color: #e3f2fd; }
        .bot-message { background-color: #f1f8e9; }
        .status-box { background-color: #f5f5f5; padding: 10px; border-radius: 5px; }
        """
    ) as demo:

        gr.Markdown("""
        # 🤖 ChatGPT Clone - Advanced AI Chat
        **Powered by Hugging Face Models** | **GPU Accelerated** | **Multi-Model Support**

        ⚠️ **初回使用時**: 最初のメッセージ送信時にモデルが自動的に読み込まれます。少しお待ちください。
        """)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column():
                        status_display = gr.Markdown("**ステータス**: 待機中(モデル未読み込み)", elem_classes=["status-box"])
                    with gr.Column():
                        model_display = gr.Markdown(f"**現在のモデル**: {chatbot.current_model_name}", elem_classes=["status-box"])

                chatbot_display = gr.Chatbot(
                    label="AI Chat",
                    height=500,
                    show_label=False,
                    type="messages",
                    avatar_images=("👤", "🤖")
                )

                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="メッセージを入力してください(初回は読み込み時間がかかります)...",
                        scale=4,
                        show_label=False,
                        lines=2
                    )
                    send_button = gr.Button("📤 送信", scale=1, variant="primary")

                with gr.Row():
                    clear_button = gr.Button("🗑️ クリア", size="sm")
                    export_button = gr.Button("📁 エクスポート", size="sm")
                    reload_model_button = gr.Button("🔄 モデル読込", size="sm")

            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ 設定")

                model_selector = gr.Dropdown(
                    choices=list(chatbot.available_models.keys()),
                    value=chatbot.current_model_name,
                    label="🔄 モデル選択",
                    interactive=True
                )

                gr.Markdown("#### 生成パラメータ")

                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                    label="🌡️ Temperature", info="創造性を調整"
                )

                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="🎯 Top-p", info="多様性を調整"
                )

                max_tokens_slider = gr.Slider(
                    minimum=50, maximum=500, value=150, step=25,
                    label="📝 最大トークン数"
                )

                repetition_penalty_slider = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                    label="🔄 繰り返し抑制"
                )

                gr.Markdown("#### 📊 統計")
                stats_display = gr.Markdown("**総会話数**: 0\n**キャッシュ数**: 0")

        export_output = gr.Textbox(label="エクスポートされた会話", visible=False)

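        # Event handler: serve cached replies when possible, otherwise generate a
        # new one, then update the conversation history and statistics.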
        def send_message(message, history):
            if not message.strip():
                return history, "", "**ステータス**: 空のメッセージです", get_stats()

            start_time = time.time()
            session_id = "default"

            try:
                cache_key = chatbot.get_cache_key(message, chatbot.generation_config)
                cached_response = chatbot.get_cached_response(cache_key)

                if cached_response:
                    response = cached_response
                    status = "**ステータス**: ✅ キャッシュから取得"
                else:
                    conversation_history = chatbot.get_conversation(session_id)

                    if not chatbot.model_loaded:
                        status = "**ステータス**: 🔄 モデル読み込み中..."
                    else:
                        status = "**ステータス**: 🔄 応答生成中..."

                    response = chatbot.generate_response(message, conversation_history)

                    # Surface load/generation failures without touching the history
                    if response.startswith("エラーが発生しました") or response.startswith("モデル読み込みに失敗"):
                        return history, message, f"**ステータス**: ❌ {response}", get_stats()

                    chatbot.set_cached_response(cache_key, response)
                    status = "**ステータス**: ✅ 新規生成完了"

                chatbot.add_to_conversation(session_id, "user", message)
                chatbot.add_to_conversation(session_id, "assistant", response)

                new_history = history + [{"role": "user", "content": message}]
                new_history = new_history + [{"role": "assistant", "content": response}]

                end_time = time.time()
                process_time = end_time - start_time
                status += f" ({process_time:.2f}秒)"

                return new_history, "", status, get_stats()

            except Exception as e:
                logger.error(f"メッセージ送信エラー: {e}")
                return history, message, f"**ステータス**: ❌ エラー: {str(e)}", get_stats()

        def clear_chat():
            session_id = "default"
            chatbot.clear_conversation(session_id)
            return [], "**ステータス**: 🗑️ チャットをクリアしました", get_stats()

        def change_model(model_name):
            result = chatbot.switch_model(model_name)
            return f"**現在のモデル**: {chatbot.current_model_name}", f"**ステータス**: {result}"

        def reload_model():
            result = chatbot.load_model()
            status = f"✅ {result}" if chatbot.model_loaded else f"❌ {result}"
            return f"**ステータス**: {status}"

        def update_params(temp, top_p, max_tokens, rep_penalty):
            chatbot.update_generation_config(temp, top_p, max_tokens, rep_penalty)
            return "**ステータス**: ⚙️ パラメータを更新しました(キャッシュクリア済み)"

        def export_chat():
            session_id = "default"
            conversation_json = chatbot.export_conversation(session_id)
            return conversation_json, "**ステータス**: 📁 会話をエクスポートしました"

        def get_stats():
            total_conversations = sum(len(conv) for conv in chatbot.conversations.values())
            cache_count = len(chatbot.response_cache)
            model_status = "読み込み済み" if chatbot.model_loaded else "未読み込み"
            return f"**総会話数**: {total_conversations}\n**キャッシュ数**: {cache_count}\n**モデル**: {model_status}"

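        # Wire the UI events to the handlers defined above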
        send_button.click(
            send_message,
            inputs=[msg_input, chatbot_display],
            outputs=[chatbot_display, msg_input, status_display, stats_display]
        )

        msg_input.submit(
            send_message,
            inputs=[msg_input, chatbot_display],
            outputs=[chatbot_display, msg_input, status_display, stats_display]
        )

        clear_button.click(
            clear_chat,
            outputs=[chatbot_display, status_display, stats_display]
        )

        model_selector.change(
            change_model,
            inputs=[model_selector],
            outputs=[model_display, status_display]
        )

        reload_model_button.click(
            reload_model,
            outputs=[status_display]
        )

        export_button.click(
            export_chat,
            outputs=[export_output, status_display]
        )

        for slider in [temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider]:
            slider.change(
                update_params,
                inputs=[temperature_slider, top_p_slider, max_tokens_slider, repetition_penalty_slider],
                outputs=[status_display]
            )

    return demo

if __name__ == "__main__":
    demo = create_interface()

    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860
    )