import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
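# Note: "google/gemma-7b-it" is a gated model, so HUGGINGFACE_TOKEN must hold a
# Hugging Face access token with permission to download it. The setup command
# below is an assumption about your shell; the token value is a placeholder:
#   export HUGGINGFACE_TOKEN=hf_xxxxxxxx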
class ChatBot:
    def __init__(self):
        # Instruction-tuned Gemma model. Note: at 7B parameters this is not a
        # lightweight model; for CPU-only or Japanese-focused setups a smaller
        # model such as "rinna/japanese-gpt2-medium" can be swapped in.
        model_name = "google/gemma-7b-it"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
        # Fall back to the EOS token when the tokenizer defines no padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.chat_history = []
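    # Assumption about the target hardware: with no dtype/device arguments the
    # model loads on CPU in float32, which is very slow for a 7B model. On a
    # CUDA machine you could instead pass torch_dtype=torch.float16 and
    # device_map="auto" to from_pretrained (device_map requires accelerate).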
    def generate_response(self, message):
        try:
            # Tokenize the user message with Gemma's chat template so the
            # instruction-tuned model sees the prompt format it was trained on
            inputs = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.model.device)
            # Generate a response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=100,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens (everything after the prompt)
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            return f"エラーが発生しました: {str(e)}"
    def chat_interface(self, message, history):
        # Ignore empty or whitespace-only input
        if not message.strip():
            return history, ""
        # Generate a response and append the (user, bot) pair to the
        # tuple-format history that gr.Chatbot displays; the empty string
        # returned alongside it clears the textbox
        bot_response = self.generate_response(message)
        history.append([message, bot_response])
        return history, ""
# Create the ChatBot instance (downloads and loads the model, which can take a while)
chatbot = ChatBot()
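# Optional smoke test before wiring up the UI (illustrative only, not part of
# the app flow):
#   print(chatbot.generate_response("こんにちは"))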
# Build the Gradio interface
def create_interface():
    with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 ChatGPT Clone")
        gr.Markdown("ローカルLLMを使用したチャットボットです")
        # Chat history display
        chatbot_display = gr.Chatbot(
            label="チャット",
            height=400,
            show_label=True
        )
        # Input box and buttons
        with gr.Row():
            msg_input = gr.Textbox(
                placeholder="メッセージを入力してください...",
                scale=4,
                show_label=False
            )
            send_button = gr.Button("送信", scale=1)
            clear_button = gr.Button("クリア", scale=1)
        # Event handlers
        def send_message(message, history):
            return chatbot.chat_interface(message, history)

        def clear_chat():
            chatbot.chat_history = []
            return []

        # Send on button click
        send_button.click(
            send_message,
            inputs=[msg_input, chatbot_display],
            outputs=[chatbot_display, msg_input]
        )
        # Send on Enter as well
        msg_input.submit(
            send_message,
            inputs=[msg_input, chatbot_display],
            outputs=[chatbot_display, msg_input]
        )
        # Clear button resets both the stored and the displayed history
        clear_button.click(
            clear_chat,
            outputs=[chatbot_display]
        )
    return demo
# Launch the application
if __name__ == "__main__":
    demo = create_interface()
    # For local development:
    # demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
    # On Hugging Face Spaces a plain launch() is enough; share=True only
    # creates a public tunnel when running on a local machine
    demo.launch(share=True)
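# Assumed dependencies for running this script locally (versions not pinned here):
#   pip install gradio transformers torch
# then:
#   HUGGINGFACE_TOKEN=hf_xxxxxxxx python app.py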