File size: 4,139 Bytes
7767876
2aa9275
 
 
 
a327f7b
 
f333515
 
a327f7b
 
2aa9275
 
 
 
 
 
 
 
 
 
 
 
a327f7b
8b32594
 
 
 
 
 
 
2aa9275
f333515
2aa9275
 
f333515
 
8b32594
 
2aa9275
691edc1
f333515
 
 
 
 
2aa9275
 
 
 
 
 
 
 
 
 
 
c266967
2aa9275
 
 
8b32594
 
 
2aa9275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c266967
2aa9275
 
c266967
a327f7b
173f32e
 
f333515
173f32e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a327f7b
 
2aa9275
 
 
c266967
 
2aa9275
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")

"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""

# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"

print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True
)
print("モデルの読み込みが完了しました。")

# グローバル設定変数
SYSTEM_MESSAGE = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
MAX_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.95

def respond(message, history):
    """
    チャットボットの応答を生成する関数
    """
    try:
        # システムメッセージと会話履歴を含むプロンプトを構築
        conversation = ""
        if SYSTEM_MESSAGE.strip():
            conversation += f"システム: {SYSTEM_MESSAGE}\n"
        
        # 会話履歴を追加
        for user_msg, bot_msg in history:
            if user_msg:
                conversation += f"ユーザー: {user_msg}\n"
            if bot_msg:
                conversation += f"アシスタント: {bot_msg}\n"
        
        # 現在のメッセージを追加
        conversation += f"ユーザー: {message}\nアシスタント: "
        
        # トークン化
        inputs = tokenizer.encode(conversation, return_tensors="pt")
        
        # GPU使用時はCUDAに移動
        if torch.cuda.is_available():
            inputs = inputs.cuda()
        
        # 応答生成
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        # 生成されたテキストをデコード
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # 応答部分のみを抽出
        full_response = generated[len(conversation):].strip()
        
        # 不要な部分を除去
        if "ユーザー:" in full_response:
            full_response = full_response.split("ユーザー:")[0].strip()
        
        return full_response
            
    except Exception as e:
        return f"エラーが発生しました: {str(e)}"

# Gradio Blocksを使用したチャットインターフェース
with gr.Blocks(
    title="🤖 Sarashina Chatbot",
    theme=gr.themes.Soft()
) as demo:
    
    gr.Markdown("# 🤖 Sarashina Chatbot")
    gr.Markdown("Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。")
    
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(
        label="メッセージを入力してください",
        placeholder="こんにちは!何かお手伝いできることはありますか?",
        lines=2
    )
    clear = gr.Button("会話をクリア")
    
    def user(message, history):
        return "", history + [[message, None]]
    
    def bot(history):
        history[-1][1] = respond(history[-1][0], history[:-1])
        return history
    
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # パブリックリンクを作成
        show_api=True,
        debug=True
    )