Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -215,48 +215,68 @@ def prepare_features_for_qa_inference(examples, tokenizer, pad_on_right, max_seq
|
|
215 |
|
216 |
final_batch = {}
|
217 |
if not processed_features:
|
218 |
-
logger.warning(f"
|
219 |
-
#
|
220 |
for key_to_ensure in ['input_ids', 'attention_mask', 'token_type_ids', 'example_id', 'offset_mapping']:
|
221 |
final_batch[key_to_ensure] = []
|
222 |
return final_batch
|
223 |
|
224 |
-
# 1.
|
225 |
for key in processed_features[0].keys(): # 假設所有特徵字典有相同的鍵
|
226 |
final_batch[key] = [feature[key] for feature in processed_features]
|
227 |
|
228 |
-
# 2.
|
229 |
keys_to_fix_for_tensor_conversion = ["input_ids", "attention_mask", "token_type_ids"]
|
230 |
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
231 |
-
cls_token_id = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else 101
|
232 |
sep_token_id = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 102
|
233 |
-
|
234 |
-
for
|
235 |
-
if
|
236 |
-
# final_batch[
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
|
|
|
|
|
|
|
|
242 |
default_seq = [cls_token_id, sep_token_id] + [pad_token_id] * (max_seq_len - 2)
|
243 |
-
|
244 |
-
elif
|
245 |
default_mask = [1, 1] + [0] * (max_seq_len - 2)
|
246 |
-
|
247 |
-
elif
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
else:
|
255 |
-
|
256 |
-
final_batch[key] = corrected_list_of_lists
|
257 |
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
return final_batch
|
261 |
|
262 |
# postprocess_qa_predictions 函數也需要從 utils_qa.py 複製或導入
|
|
|
215 |
|
216 |
final_batch = {}
|
217 |
if not processed_features:
|
218 |
+
logger.warning(f"在 prepare_features_for_qa_inference 中,由於 tokenizer 沒有為 ID {examples.get('id', ['N/A'])[0]} 生成任何有效特徵 (processed_features 為空), 將返回空的特徵結構。")
|
219 |
+
# 確保所有期望的鍵都存在,並且值是空列表,以匹配 .map 的期望輸出結構
|
220 |
for key_to_ensure in ['input_ids', 'attention_mask', 'token_type_ids', 'example_id', 'offset_mapping']:
|
221 |
final_batch[key_to_ensure] = []
|
222 |
return final_batch
|
223 |
|
224 |
+
# 1. 將 processed_features (list of dicts) 轉換為 final_batch (dict of lists)
|
225 |
for key in processed_features[0].keys(): # 假設所有特徵字典有相同的鍵
|
226 |
final_batch[key] = [feature[key] for feature in processed_features]
|
227 |
|
228 |
+
# 2. 對 final_batch 中需要轉換為張量的字段進行健壯性檢查和修正
|
229 |
keys_to_fix_for_tensor_conversion = ["input_ids", "attention_mask", "token_type_ids"]
|
230 |
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
231 |
+
cls_token_id = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else 101
|
232 |
sep_token_id = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 102
|
233 |
+
|
234 |
+
for key_to_fix in keys_to_fix_for_tensor_conversion:
|
235 |
+
if key_to_fix in final_batch:
|
236 |
+
# final_batch[key_to_fix] 應該是一個列表的列表,例如 [[ids_for_feature1], [ids_for_feature2], ...]
|
237 |
+
list_of_feature_sequences = final_batch[key_to_fix]
|
238 |
+
corrected_list_of_feature_sequences = []
|
239 |
+
|
240 |
+
for i, single_feature_sequence in enumerate(list_of_feature_sequences):
|
241 |
+
current_example_id = final_batch.get("example_id", [f"unknown_example_index_{i}"]*len(list_of_feature_sequences) )[i]
|
242 |
+
|
243 |
+
if single_feature_sequence is None:
|
244 |
+
logger.warning(f"對於樣本 {current_example_id} 的特徵 {i}, 字段 '{key_to_fix}' 的整個序列是 None。將用默認安全序列替換。")
|
245 |
+
if key_to_fix == "input_ids":
|
246 |
default_seq = [cls_token_id, sep_token_id] + [pad_token_id] * (max_seq_len - 2)
|
247 |
+
corrected_list_of_feature_sequences.append(default_seq[:max_seq_len])
|
248 |
+
elif key_to_fix == "attention_mask":
|
249 |
default_mask = [1, 1] + [0] * (max_seq_len - 2)
|
250 |
+
corrected_list_of_feature_sequences.append(default_mask[:max_seq_len])
|
251 |
+
elif key_to_fix == "token_type_ids":
|
252 |
+
corrected_list_of_feature_sequences.append([0] * max_seq_len)
|
253 |
+
else: # 不應該發生,因為我們只檢查這三個鍵
|
254 |
+
corrected_list_of_feature_sequences.append([0] * max_seq_len) # 一個備用安全值
|
255 |
+
elif not all(isinstance(x, int) for x in single_feature_sequence):
|
256 |
+
logger.warning(f"對於樣本 {current_example_id} 的特徵 {i}, 字段 '{key_to_fix}' 列表內部包含非整數值: {str(single_feature_sequence)[:50]}... 將嘗試修正 None 值。")
|
257 |
+
default_val_for_element = pad_token_id if key_to_fix == "input_ids" else 0
|
258 |
+
|
259 |
+
fixed_sequence = []
|
260 |
+
for x_val in single_feature_sequence:
|
261 |
+
if x_val is None: # 如果列表中的某個元素是 None
|
262 |
+
fixed_sequence.append(default_val_for_element)
|
263 |
+
elif not isinstance(x_val, int): # 如果不是整數也不是 None (異常情況)
|
264 |
+
logger.error(f"嚴重錯誤:在 {key_to_fix} 中發現了既不是 int 也不是 None 的值: {x_val} (類型: {type(x_val)})。用默認值替換。")
|
265 |
+
fixed_sequence.append(default_val_for_element)
|
266 |
+
else:
|
267 |
+
fixed_sequence.append(x_val)
|
268 |
+
corrected_list_of_feature_sequences.append(fixed_sequence)
|
269 |
else:
|
270 |
+
corrected_list_of_feature_sequences.append(single_feature_sequence) # 列表本身是好的
|
|
|
271 |
|
272 |
+
final_batch[key_to_fix] = corrected_list_of_feature_sequences
|
273 |
+
|
274 |
+
# (可選) 添加最終調試打印,檢查修正後的 final_batch
|
275 |
+
logger.info(f"DEBUG: Final batch being returned by prepare_features_for_qa_inference for example {examples.get('id', ['N/A'])[0]}:")
|
276 |
+
for key_to_log in ["input_ids", "attention_mask", "token_type_ids"]:
|
277 |
+
if key_to_log in final_batch:
|
278 |
+
logger.info(f" {key_to_log}: {str(final_batch[key_to_log])[:200]}...") # 打印部分內容
|
279 |
+
|
280 |
return final_batch
|
281 |
|
282 |
# postprocess_qa_predictions 函數也需要從 utils_qa.py 複製或導入
|