import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM import warnings warnings.filterwarnings("ignore") """ Sarashinaモデルを使用したGradioチャットボット Hugging Face Transformersライブラリを使用してローカルでモデルを実行 """ # モデルとトークナイザーの初期化 MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" print("モデルを読み込み中...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, trust_remote_code=True ) print("モデルの読み込みが完了しました。") # グローバル設定変数 SYSTEM_MESSAGE = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。" MAX_TOKENS = 512 TEMPERATURE = 0.7 TOP_P = 0.95 def respond(message, history): """ チャットボットの応答を生成する関数 """ try: # システムメッセージと会話履歴を含むプロンプトを構築 conversation = "" if SYSTEM_MESSAGE.strip(): conversation += f"システム: {SYSTEM_MESSAGE}\n" # 会話履歴を追加 for user_msg, bot_msg in history: if user_msg: conversation += f"ユーザー: {user_msg}\n" if bot_msg: conversation += f"アシスタント: {bot_msg}\n" # 現在のメッセージを追加 conversation += f"ユーザー: {message}\nアシスタント: " # トークン化 inputs = tokenizer.encode(conversation, return_tensors="pt") # GPU使用時はCUDAに移動 if torch.cuda.is_available(): inputs = inputs.cuda() # 応答生成 with torch.no_grad(): outputs = model.generate( inputs, max_new_tokens=MAX_TOKENS, temperature=TEMPERATURE, top_p=TOP_P, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, repetition_penalty=1.1 ) # 生成されたテキストをデコード generated = tokenizer.decode(outputs[0], skip_special_tokens=True) # 応答部分のみを抽出 full_response = generated[len(conversation):].strip() # 不要な部分を除去 if "ユーザー:" in full_response: full_response = full_response.split("ユーザー:")[0].strip() return full_response except Exception as e: return f"エラーが発生しました: {str(e)}" # Gradio Blocksを使用したチャットインターフェース with gr.Blocks( title="🤖 Sarashina Chatbot", theme=gr.themes.Soft() ) as demo: gr.Markdown("# 🤖 Sarashina Chatbot") gr.Markdown("Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。") chatbot = gr.Chatbot(height=500) msg = gr.Textbox( label="メッセージを入力してください", placeholder="こんにちは!何かお手伝いできることはありますか?", lines=2 ) clear = gr.Button("会話をクリア") def user(message, history): return "", history + [[message, None]] def bot(history): history[-1][1] = respond(history[-1][0], history[:-1]) return history msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( bot, chatbot, chatbot ) clear.click(lambda: None, None, chatbot, queue=False) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=True, # パブリックリンクを作成 show_api=True, debug=True )