chatdemo2 / app.py
oggata's picture
Update app.py
173f32e verified
raw
history blame
4.14 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""
# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"
print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
print("モデルの読み込みが完了しました。")
# グローバル設定変数
SYSTEM_MESSAGE = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
MAX_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.95
def respond(message, history):
"""
チャットボットの応答を生成する関数
"""
try:
# システムメッセージと会話履歴を含むプロンプトを構築
conversation = ""
if SYSTEM_MESSAGE.strip():
conversation += f"システム: {SYSTEM_MESSAGE}\n"
# 会話履歴を追加
for user_msg, bot_msg in history:
if user_msg:
conversation += f"ユーザー: {user_msg}\n"
if bot_msg:
conversation += f"アシスタント: {bot_msg}\n"
# 現在のメッセージを追加
conversation += f"ユーザー: {message}\nアシスタント: "
# トークン化
inputs = tokenizer.encode(conversation, return_tensors="pt")
# GPU使用時はCUDAに移動
if torch.cuda.is_available():
inputs = inputs.cuda()
# 応答生成
with torch.no_grad():
outputs = model.generate(
inputs,
max_new_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 生成されたテキストをデコード
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 応答部分のみを抽出
full_response = generated[len(conversation):].strip()
# 不要な部分を除去
if "ユーザー:" in full_response:
full_response = full_response.split("ユーザー:")[0].strip()
return full_response
except Exception as e:
return f"エラーが発生しました: {str(e)}"
# Gradio Blocksを使用したチャットインターフェース
with gr.Blocks(
title="🤖 Sarashina Chatbot",
theme=gr.themes.Soft()
) as demo:
gr.Markdown("# 🤖 Sarashina Chatbot")
gr.Markdown("Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。")
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(
label="メッセージを入力してください",
placeholder="こんにちは!何かお手伝いできることはありますか?",
lines=2
)
clear = gr.Button("会話をクリア")
def user(message, history):
return "", history + [[message, None]]
def bot(history):
history[-1][1] = respond(history[-1][0], history[:-1])
return history
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True, # パブリックリンクを作成
show_api=True,
debug=True
)