doropiza commited on
Commit
3fd04b5
·
1 Parent(s): 18ad7bf
Files changed (2) hide show
  1. app.py +34 -33
  2. requirements.txt +1 -1
app.py CHANGED
@@ -99,55 +99,47 @@ import os
99
  import gradio as gr
100
  from transformers import AutoTokenizer, AutoModelForCausalLM
101
  import torch
102
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
103
 
 
 
104
  class ChatBot:
105
  def __init__(self):
106
- # 軽量なローカルLLMを使用(safetensors対応モデル)
107
- model_name = "microsoft/DialoGPT-small" # smallバージョンでメモリ使用量を削減
108
 
109
  try:
110
- self.tokenizer = AutoTokenizer.from_pretrained(
111
- model_name,
112
- token=HUGGINGFACE_TOKEN,
113
- trust_remote_code=True,
114
- use_safetensors=True
115
-
116
- )
117
  self.model = AutoModelForCausalLM.from_pretrained(
118
  model_name,
119
  token=HUGGINGFACE_TOKEN,
120
- trust_remote_code=True,
121
- use_safetensors=True,
122
- torch_dtype=torch.float16, # メモリ効率化
123
- device_map="auto"
124
  )
 
 
125
  except Exception as e:
126
- # フォールバック:より軽量なモデル
127
  print(f"モデル読み込みエラー: {e}")
128
- print("軽量モデルにフォールバック中...")
129
- model_name = "distilgpt2"
130
-
131
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
132
- self.model = AutoModelForCausalLM.from_pretrained(
133
- model_name,
134
- torch_dtype=torch.float16,
135
- device_map="auto"
136
- )
137
 
138
  # パディングトークンを設定
139
- if self.tokenizer.pad_token is None:
140
  self.tokenizer.pad_token = self.tokenizer.eos_token
141
 
142
  self.chat_history = []
143
 
144
  def generate_response(self, message):
 
 
 
 
145
  try:
146
  # 入力をトークン化
147
  inputs = self.tokenizer.encode(
148
- message + self.tokenizer.eos_token,
149
  return_tensors='pt',
150
- max_length=512,
151
  truncation=True
152
  )
153
 
@@ -155,13 +147,13 @@ class ChatBot:
155
  with torch.no_grad():
156
  outputs = self.model.generate(
157
  inputs,
158
- max_new_tokens=50, # 新しいトークン数を制限
159
  num_return_sequences=1,
160
- temperature=0.7,
161
  do_sample=True,
162
  pad_token_id=self.tokenizer.pad_token_id,
163
  eos_token_id=self.tokenizer.eos_token_id,
164
- repetition_penalty=1.1
165
  )
166
 
167
  # レスポンスをデコード
@@ -170,14 +162,23 @@ class ChatBot:
170
  skip_special_tokens=True
171
  )
172
 
173
- # 空のレスポンスの場合はデフォルト応答
174
  if not response.strip():
175
- response = "申し訳ありませんが、適切な応答を生成できませんでした。"
 
 
 
 
 
 
 
 
176
 
177
  return response.strip()
178
 
179
  except Exception as e:
180
- return f"エラーが発生しました: {str(e)}"
 
181
 
182
  def chat_interface(self, message, history):
183
  if not message.strip():
 
99
  import gradio as gr
100
  from transformers import AutoTokenizer, AutoModelForCausalLM
101
  import torch
 
102
 
103
+
104
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
105
  class ChatBot:
106
  def __init__(self):
107
+ # ZeroGPU環境対応の軽量モデル
108
+ model_name = "distilgpt2" # 最も軽量で安定
109
 
110
  try:
111
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name,token=HUGGINGFACE_TOKEN)
 
 
 
 
 
 
112
  self.model = AutoModelForCausalLM.from_pretrained(
113
  model_name,
114
  token=HUGGINGFACE_TOKEN,
115
+ torch_dtype=torch.float32, # ZeroGPU互換性のためfloat32使用
116
+ low_cpu_mem_usage=True
 
 
117
  )
118
+ print(f"モデル {model_name} を正常に読み込みました")
119
+
120
  except Exception as e:
 
121
  print(f"モデル読み込みエラー: {e}")
122
+ # 最もシンプルなフォールバック
123
+ self.tokenizer = None
124
+ self.model = None
 
 
 
 
 
 
125
 
126
  # パディングトークンを設定
127
+ if self.tokenizer and self.tokenizer.pad_token is None:
128
  self.tokenizer.pad_token = self.tokenizer.eos_token
129
 
130
  self.chat_history = []
131
 
132
  def generate_response(self, message):
133
+ # モデルが利用できない場合のフォールバック
134
+ if not self.model or not self.tokenizer:
135
+ return "申し訳ありませんが、現在AIモデルが利用できません。シンプルな応答機能で対応いたします。"
136
+
137
  try:
138
  # 入力をトークン化
139
  inputs = self.tokenizer.encode(
140
+ message,
141
  return_tensors='pt',
142
+ max_length=256,
143
  truncation=True
144
  )
145
 
 
147
  with torch.no_grad():
148
  outputs = self.model.generate(
149
  inputs,
150
+ max_new_tokens=30, # さらに短縮
151
  num_return_sequences=1,
152
+ temperature=0.8,
153
  do_sample=True,
154
  pad_token_id=self.tokenizer.pad_token_id,
155
  eos_token_id=self.tokenizer.eos_token_id,
156
+ repetition_penalty=1.2
157
  )
158
 
159
  # レスポンスをデコード
 
162
  skip_special_tokens=True
163
  )
164
 
165
+ # 空のレスポンスの場合はシンプルな応答
166
  if not response.strip():
167
+ responses = [
168
+ "興味深いですね。",
169
+ "そうですね。",
170
+ "なるほど。",
171
+ "もう少し詳しく教えてください。",
172
+ "それについてどう思いますか?"
173
+ ]
174
+ import random
175
+ response = random.choice(responses)
176
 
177
  return response.strip()
178
 
179
  except Exception as e:
180
+ print(f"生成エラー: {e}")
181
+ return "申し訳ありませんが、応答の生成中にエラーが発生しました。"
182
 
183
  def chat_interface(self, message, history):
184
  if not message.strip():
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  huggingface_hub>=0.23.0
2
  gradio>=4.0.0
3
  transformers>=4.30.0
4
- torch>=2.6.0
5
  accelerate>=0.20.0
6
  sentencepiece>=0.1.99
7
  google-generativeai>=0.3.0
 
1
  huggingface_hub>=0.23.0
2
  gradio>=4.0.0
3
  transformers>=4.30.0
4
+ torch>=2.0.0,<2.6.0
5
  accelerate>=0.20.0
6
  sentencepiece>=0.1.99
7
  google-generativeai>=0.3.0