TheWeeeed committed
Commit 4c26b67 · verified · 1 Parent(s): 2bb107e

Update app.py

Files changed (1)
  1. app.py +48 -28
app.py CHANGED
@@ -215,48 +215,68 @@ def prepare_features_for_qa_inference(examples, tokenizer, pad_on_right, max_seq
 
     final_batch = {}
     if not processed_features:
-        logger.warning(f"No features generated for example IDs: {examples.get('id', ['N/A'])}. Returning empty structure.")
-        # Make sure the returned structure matches what .map expects: dict keys are column names, values are empty lists
+        logger.warning(f"In prepare_features_for_qa_inference, the tokenizer generated no valid features for ID {examples.get('id', ['N/A'])[0]} (processed_features is empty); returning an empty feature structure.")
+        # Ensure every expected key exists with an empty-list value, to match the output structure .map expects
         for key_to_ensure in ['input_ids', 'attention_mask', 'token_type_ids', 'example_id', 'offset_mapping']:
             final_batch[key_to_ensure] = []
         return final_batch
 
-    # 1. First, convert processed_features (a list of dicts) into final_batch (a dict of lists)
+    # 1. Convert processed_features (a list of dicts) into final_batch (a dict of lists)
     for key in processed_features[0].keys():  # assume all feature dicts share the same keys
         final_batch[key] = [feature[key] for feature in processed_features]
 
-    # 2. Then run robustness checks and fixes on the final_batch fields that must be converted to tensors
+    # 2. Robustness checks and fixes on the final_batch fields that must be converted to tensors
     keys_to_fix_for_tensor_conversion = ["input_ids", "attention_mask", "token_type_ids"]
     pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-    cls_token_id = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else 101
+    cls_token_id = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else 101
     sep_token_id = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 102
-
-    for key in keys_to_fix_for_tensor_conversion:
-        if key in final_batch:
-            # final_batch[key] is a list of lists, e.g. [[ids_for_feature1], [ids_for_feature2], ...]
-            corrected_list_of_lists = []
-            for i, single_feature_list in enumerate(final_batch[key]):
-                if single_feature_list is None:
-                    logger.warning(f"Feature list for '{key}' at index {i} is None. Replacing with default for max_seq_len {max_seq_len}.")
-                    if key == "input_ids":
+
+    for key_to_fix in keys_to_fix_for_tensor_conversion:
+        if key_to_fix in final_batch:
+            # final_batch[key_to_fix] should be a list of lists, e.g. [[ids_for_feature1], [ids_for_feature2], ...]
+            list_of_feature_sequences = final_batch[key_to_fix]
+            corrected_list_of_feature_sequences = []
+
+            for i, single_feature_sequence in enumerate(list_of_feature_sequences):
+                current_example_id = final_batch.get("example_id", [f"unknown_example_index_{i}"] * len(list_of_feature_sequences))[i]
+
+                if single_feature_sequence is None:
+                    logger.warning(f"For example {current_example_id}, feature {i}: the entire sequence for field '{key_to_fix}' is None. Replacing it with a default safe sequence.")
+                    if key_to_fix == "input_ids":
                         default_seq = [cls_token_id, sep_token_id] + [pad_token_id] * (max_seq_len - 2)
-                        corrected_list_of_lists.append(default_seq[:max_seq_len])
-                    elif key == "attention_mask":
+                        corrected_list_of_feature_sequences.append(default_seq[:max_seq_len])
+                    elif key_to_fix == "attention_mask":
                         default_mask = [1, 1] + [0] * (max_seq_len - 2)
-                        corrected_list_of_lists.append(default_mask[:max_seq_len])
-                    elif key == "token_type_ids":
-                        corrected_list_of_lists.append([0] * max_seq_len)
-                elif not all(isinstance(x, int) for x in single_feature_list):
-                    logger.warning(f"Feature list for '{key}' at index {i} contains non-integers: {str(single_feature_list)[:50]}... Fixing Nones.")
-                    default_val = pad_token_id if key == "input_ids" else 0
-                    fixed_list = [default_val if not isinstance(x, int) else x for x in single_feature_list]
-                    corrected_list_of_lists.append(fixed_list)
+                        corrected_list_of_feature_sequences.append(default_mask[:max_seq_len])
+                    elif key_to_fix == "token_type_ids":
+                        corrected_list_of_feature_sequences.append([0] * max_seq_len)
+                    else:  # should not happen, since only these three keys are checked
+                        corrected_list_of_feature_sequences.append([0] * max_seq_len)  # fallback safe value
+                elif not all(isinstance(x, int) for x in single_feature_sequence):
+                    logger.warning(f"For example {current_example_id}, feature {i}: the list for field '{key_to_fix}' contains non-integer values: {str(single_feature_sequence)[:50]}... Will try to fix None values.")
+                    default_val_for_element = pad_token_id if key_to_fix == "input_ids" else 0
+
+                    fixed_sequence = []
+                    for x_val in single_feature_sequence:
+                        if x_val is None:  # an element inside the list is None
+                            fixed_sequence.append(default_val_for_element)
+                        elif not isinstance(x_val, int):  # neither an int nor None (unexpected)
+                            logger.error(f"Critical: found a value in {key_to_fix} that is neither int nor None: {x_val} (type: {type(x_val)}). Replacing it with the default value.")
+                            fixed_sequence.append(default_val_for_element)
+                        else:
+                            fixed_sequence.append(x_val)
+                    corrected_list_of_feature_sequences.append(fixed_sequence)
                 else:
-                    corrected_list_of_lists.append(single_feature_list)  # list is already good
-            final_batch[key] = corrected_list_of_lists
+                    corrected_list_of_feature_sequences.append(single_feature_sequence)  # the list itself is fine
 
-    # Before returning, an extra debug print could be added to confirm the structure of the corrected final_batch
-    # logger.debug(f"Returning final_batch from prepare_features: { {k: str(v)[:200] + '...' for k,v in final_batch.items()} }")
+            final_batch[key_to_fix] = corrected_list_of_feature_sequences
+
+    # (Optional) final debug logging to inspect the corrected final_batch
+    logger.info(f"DEBUG: Final batch being returned by prepare_features_for_qa_inference for example {examples.get('id', ['N/A'])[0]}:")
+    for key_to_log in ["input_ids", "attention_mask", "token_type_ids"]:
+        if key_to_log in final_batch:
+            logger.info(f"  {key_to_log}: {str(final_batch[key_to_log])[:200]}...")  # log only part of the content
+
     return final_batch
 
 # postprocess_qa_predictions also needs to be copied or imported from utils_qa.py
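
Below is a minimal usage sketch (not part of this commit) of how a feature-preparation function like prepare_features_for_qa_inference is typically applied through datasets.Dataset.map. The checkpoint name, the max_seq_len value, the toy id/question/context columns, and the import path are illustrative assumptions, and the function may accept parameters beyond the four visible in the hunk header. The sketch also shows why the commit insists on returning a dict of lists even when no features are produced: with batched=True, .map expects every output key to map to a list.

# Hypothetical usage sketch -- not from app.py; anything marked "assumed" is not confirmed by the diff.
from datasets import Dataset
from transformers import AutoTokenizer

from app import prepare_features_for_qa_inference  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint

# A toy batch in the column layout a QA preprocessing function usually reads (assumed columns).
examples = Dataset.from_dict({
    "id": ["ex-0"],
    "question": ["What is the capital of France?"],
    "context": ["The capital of France is Paris."],
})

features = examples.map(
    lambda batch: prepare_features_for_qa_inference(
        batch,
        tokenizer=tokenizer,
        pad_on_right=(tokenizer.padding_side == "right"),
        max_seq_len=384,  # assumed value
    ),
    batched=True,
    remove_columns=examples.column_names,  # keep only the feature columns the function returns
)

print(features.column_names)
# Expected to include: input_ids, attention_mask, token_type_ids, example_id, offset_mapping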