File size: 3,977 Bytes
7767876 2aa9275 a327f7b f95231e a327f7b 2aa9275 a327f7b f95231e 2aa9275 f95231e 2aa9275 f95231e 691edc1 f95231e 2aa9275 691edc1 f95231e 2aa9275 f95231e 2aa9275 f95231e 2aa9275 a327f7b f95231e a327f7b f95231e 2aa9275 f95231e 2aa9275 a327f7b 2aa9275 f95231e 2aa9275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
"""
デバッグ用のシンプルなSarashinaチャットボット
additional_inputsなしでテスト
"""
# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"
print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
print("モデルの読み込みが完了しました。")
def respond(message, history):
"""
シンプルなチャットボット応答関数
additional_inputsなし
"""
try:
# デバッグ情報を出力
print(f"DEBUG - message: {message} (type: {type(message)})")
print(f"DEBUG - history: {history} (type: {type(history)})")
# システムメッセージ(固定)
system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
# 会話履歴を含むプロンプトを構築
conversation = f"システム: {system_message}\n"
# 会話履歴を追加
if history and isinstance(history, list):
for item in history:
if isinstance(item, (list, tuple)) and len(item) >= 2:
user_msg, bot_msg = item[0], item[1]
if user_msg:
conversation += f"ユーザー: {user_msg}\n"
if bot_msg:
conversation += f"アシスタント: {bot_msg}\n"
# 現在のメッセージを追加
conversation += f"ユーザー: {message}\nアシスタント: "
# トークン化
inputs = tokenizer.encode(conversation, return_tensors="pt")
# GPU使用時はCUDAに移動
if torch.cuda.is_available():
inputs = inputs.cuda()
# 応答生成
with torch.no_grad():
outputs = model.generate(
inputs,
max_new_tokens=512,
temperature=0.7,
top_p=0.95,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 生成されたテキストをデコード
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 応答部分のみを抽出
full_response = generated[len(conversation):].strip()
# 不要な部分を除去
if "ユーザー:" in full_response:
full_response = full_response.split("ユーザー:")[0].strip()
# ストリーミング風の出力
for i in range(len(full_response)):
response = full_response[:i+1]
yield response
except Exception as e:
yield f"エラーが発生しました: {str(e)}"
"""
シンプルなChatInterface(additional_inputsなし)
"""
demo = gr.ChatInterface(
respond,
title="🤖 Sarashina Chatbot (Simple)",
description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。(デバッグ用)",
theme=gr.themes.Soft(),
examples=[
"こんにちは!",
"日本について教えて",
"プログラミングの質問があります",
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_api=True,
debug=True
) |