joko333's picture
Enhance prediction function with validation checks and improved error handling
04dc908
raw
history blame
3.34 kB
import torch
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
from utils.BiLSTM import BiLSTMAttentionBERT
import numpy as np
def load_model_for_prediction(
    model_repo="joko333/BiLSTM_v01",
    tokenizer_name='dmis-lab/biobert-base-cased-v1.2',
):
    """Load the classifier, a label encoder with fixed classes, and the tokenizer.

    Args:
        model_repo: Hugging Face Hub repo id passed to
            ``BiLSTMAttentionBERT.from_pretrained``.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(model, label_encoder, tokenizer)`` with the model in eval mode.
        If any component fails to load, the error is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    # Class names in model-output index order. LabelEncoder.fit() stores
    # classes_ sorted, so this list is kept alphabetically sorted to match
    # what a fitted encoder would contain (assumes the checkpoint was trained
    # with such an encoder -- TODO confirm against the training script).
    labels = [
        'Addition', 'Causal', 'Cause and Effect',
        'Clarification', 'Comparison', 'Concession',
        'Conditional', 'Contrast', 'Contrastive Emphasis',
        'Definition', 'Elaboration', 'Emphasis',
        'Enumeration', 'Explanation', 'Generalization',
        'Illustration', 'Inference', 'Problem Solution',
        'Purpose', 'Sequential', 'Summary',
        'Temporal Sequence',
    ]
    try:
        # Load model from Hugging Face Hub
        model = BiLSTMAttentionBERT.from_pretrained(
            model_repo,
            hidden_dim=128,
            # Derived from the label list so the output layer size and the
            # index -> label mapping cannot drift apart.
            num_classes=len(labels),
            num_layers=2,
            dropout=0.5,
        )
        model.eval()
        # Initialize label encoder with predefined classes (no fit() needed)
        label_encoder = LabelEncoder()
        label_encoder.classes_ = np.array(labels)
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        return model, label_encoder, tokenizer
    except Exception as e:
        # Deliberate best-effort: callers check for None components.
        print(f"Error loading model components: {str(e)}")
        return None, None, None
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """
    Make prediction for a single sentence with label validation.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``. Its
            output goes through ``torch.softmax(..., dim=1)``, so it is
            presumably a ``(batch, num_classes)`` logits tensor (TODO confirm).
        sentence: Text to classify.
        tokenizer: Hugging Face-style tokenizer callable.
        label_encoder: Object exposing ``classes_`` (index -> label name).

    Returns:
        ``(label, probability)`` for the highest-scoring class. On any
        failure, an ``("Error: ...", 0.0)`` tuple is returned instead of
        raising.

    Note:
        The model is moved to CPU in place and switched to eval mode.
    """
    # Validation checks
    if model is None:
        print("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        print("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        print("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0
    # Nothing to classify: avoid returning a meaningless prediction.
    if sentence is None or (isinstance(sentence, str) and not sentence.strip()):
        print("Error: Empty sentence")
        return "Error: Empty sentence", 0.0
    # Force CPU device
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()
    try:
        # Fixed-length padding is kept on purpose: the BiLSTM sees the padded
        # sequence, so changing it could change predictions.
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)
        idx = pred_idx.item()
        num_labels = len(label_encoder.classes_)
        # Label validation: the model head and the encoder must agree.
        if not 0 <= idx < num_labels:
            msg = f"Error: Predicted index {idx} out of range for {num_labels} labels"
            print(msg)
            return msg, 0.0
        predicted_label = label_encoder.classes_[idx]
        return predicted_label, prob.item()
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
def print_labels(label_encoder, show_counts=False):
    """Print all labels and their corresponding indices.

    Args:
        label_encoder: Object exposing ``classes_`` (index -> label name).
            If ``None``, an error message is printed and nothing else happens.
        show_counts: Currently unused; kept for backward compatibility.
    """
    # Consistent with predict_sentence: a failed load yields None components.
    if label_encoder is None:
        print("Error: Label encoder not loaded")
        return
    classes = label_encoder.classes_
    print("\nAvailable labels:")
    print("-" * 40)
    for idx, label in enumerate(classes):
        print(f"Index {idx}: {label}")
    print("-" * 40)
    print(f"Total number of classes: {len(classes)}\n")