- app.py +34 -33
- requirements.txt +1 -1
app.py
CHANGED
@@ -99,55 +99,47 @@ import os
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
+
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 class ChatBot:
     def __init__(self):
-        #
-        model_name = "
+        # Lightweight model suited to the ZeroGPU environment
+        model_name = "distilgpt2"  # smallest and most stable
 
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                token=HUGGINGFACE_TOKEN,
-                trust_remote_code=True,
-                use_safetensors=True
-
-            )
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
                 token=HUGGINGFACE_TOKEN,
-
-
-                torch_dtype=torch.float16,  # for memory efficiency
-                device_map="auto"
+                torch_dtype=torch.float32,  # use float32 for ZeroGPU compatibility
+                low_cpu_mem_usage=True
             )
+            print(f"Model {model_name} loaded successfully")
+
         except Exception as e:
-            # Fallback: a lighter-weight model
             print(f"Model loading error: {e}")
-
-
-
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16,
-                device_map="auto"
-            )
+            # Simplest possible fallback
+            self.tokenizer = None
+            self.model = None
 
         # Set the padding token
-        if self.tokenizer.pad_token is None:
+        if self.tokenizer and self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.chat_history = []
 
     def generate_response(self, message):
+        # Fallback when the model is unavailable
+        if not self.model or not self.tokenizer:
+            return "Sorry, the AI model is currently unavailable. Responding with the simple reply feature instead."
+
         try:
             # Tokenize the input
             inputs = self.tokenizer.encode(
-                message
+                message,
                 return_tensors='pt',
-                max_length=
+                max_length=256,
                 truncation=True
             )
 
@@ -155,13 +147,13 @@
             with torch.no_grad():
                 outputs = self.model.generate(
                     inputs,
-                    max_new_tokens=
+                    max_new_tokens=30,  # shortened further
                     num_return_sequences=1,
-                    temperature=0.
+                    temperature=0.8,
                     do_sample=True,
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    repetition_penalty=1.
+                    repetition_penalty=1.2
                 )
 
             # Decode the response
@@ -170,14 +162,23 @@
                 skip_special_tokens=True
             )
 
-            #
+            # Fall back to a simple canned reply when the response is empty
            if not response.strip():
-
+                responses = [
+                    "That's interesting.",
+                    "That's true.",
+                    "I see.",
+                    "Could you tell me a bit more?",
+                    "What do you think about that?"
+                ]
+                import random
+                response = random.choice(responses)
 
             return response.strip()
 
         except Exception as e:
-
+            print(f"Generation error: {e}")
+            return "Sorry, an error occurred while generating the response."
 
     def chat_interface(self, message, history):
         if not message.strip():
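For orientation, here is a minimal standalone sketch of the pattern this commit moves to: load distilgpt2 with a None fallback, then generate with the new, shorter sampling settings. The helper names load_chat_model and generate_once are illustrative only, not part of app.py, which keeps this logic inside ChatBot.

import os
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_chat_model(model_name="distilgpt2"):
    # Mirrors ChatBot.__init__: on any load failure return (None, None)
    # so the caller can degrade to canned replies instead of crashing.
    token = os.getenv("HUGGINGFACE_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=token,
            torch_dtype=torch.float32,  # float32 for ZeroGPU compatibility
            low_cpu_mem_usage=True,
        )
        if tokenizer.pad_token is None:  # GPT-2 family ships without a pad token
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer, model
    except Exception as e:
        print(f"Model loading error: {e}")
        return None, None

def generate_once(tokenizer, model, message):
    # Mirrors ChatBot.generate_response with the new sampling settings.
    if not (tokenizer and model):
        return "Sorry, the AI model is currently unavailable."
    inputs = tokenizer.encode(message, return_tensors="pt", max_length=256, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=30,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.8,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return text or random.choice(["I see.", "Could you tell me a bit more?"])

tokenizer, model = load_chat_model()
print(generate_once(tokenizer, model, "Hello!"))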
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
 huggingface_hub>=0.23.0
 gradio>=4.0.0
 transformers>=4.30.0
-torch>=2.6.0
+torch>=2.0.0,<2.6.0
 accelerate>=0.20.0
 sentencepiece>=0.1.99
 google-generativeai>=0.3.0
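The torch pin now excludes 2.6, presumably to match the torch build the ZeroGPU runtime actually provides. To catch a mismatched preinstalled torch early, a startup assertion along these lines (hypothetical, not part of this repo) would work:

from packaging.version import Version
import torch

# Strip local build tags such as "+cu121" before comparing.
v = Version(torch.__version__.split("+")[0])
assert Version("2.0.0") <= v < Version("2.6.0"), f"torch {v} is outside the pinned range"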