File size: 4,432 Bytes
a327f7b 2aa9275 a327f7b 2aa9275 a327f7b 2aa9275 a327f7b 465af28 2aa9275 465af28 2aa9275 465af28 2aa9275 465af28 2aa9275 465af28 2aa9275 a327f7b 2aa9275 465af28 a327f7b 465af28 2aa9275 465af28 2aa9275 a327f7b 2aa9275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""
# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"
print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
print("モデルの読み込みが完了しました。")
def respond(message, history):
"""
チャットボットの応答を生成する関数
※ type="messages"の場合、historyはメッセージ辞書のリスト形式になります
"""
try:
# システムメッセージとして使用するプロンプト
system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
# 会話履歴を文字列形式に変換
conversation = ""
if system_message.strip():
conversation += f"システム: {system_message}\n"
# 会話履歴を追加(messages形式の場合)
for msg in history:
if msg["role"] == "user":
conversation += f"ユーザー: {msg['content']}\n"
elif msg["role"] == "assistant":
conversation += f"アシスタント: {msg['content']}\n"
# 現在のメッセージを追加
conversation += f"ユーザー: {message}\nアシスタント: "
# トークン化
inputs = tokenizer.encode(conversation, return_tensors="pt")
# GPU使用時はCUDAに移動
if torch.cuda.is_available():
inputs = inputs.cuda()
# 応答生成(ストリーミング対応)
response = ""
with torch.no_grad():
# 一度に生成してからストリーミング風に出力
outputs = model.generate(
inputs,
max_new_tokens=512, # デフォルト値を使用
temperature=0.7, # デフォルト値を使用
top_p=0.95, # デフォルト値を使用
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 生成されたテキストをデコード
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 応答部分のみを抽出
full_response = generated[len(conversation):].strip()
# 不要な部分を除去
if "ユーザー:" in full_response:
full_response = full_response.split("ユーザー:")[0].strip()
# ストリーミング風の出力
for i in range(len(full_response)):
response = full_response[:i+1]
yield response
except Exception as e:
yield f"エラーが発生しました: {str(e)}"
"""
Gradio ChatInterfaceを使用したシンプルなチャットボット
type="messages"を設定してOpenAI形式のメッセージを使用
"""
demo = gr.ChatInterface(
respond,
type="messages", # 重要:これによりhistoryがメッセージ辞書形式になります
title="🤖 Sarashina Chatbot",
description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
theme=gr.themes.Soft(),
examples=[
"こんにちは!今日はどんなことを話しましょうか?",
"日本の文化について教えてください。",
"簡単なレシピを教えてもらえますか?",
"プログラミングについて質問があります。",
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_api=True, # API documentation を表示
debug=True
) |