Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer | |
from sklearn.preprocessing import LabelEncoder | |
from utils.BiLSTM import BiLSTMAttentionBERT | |
import numpy as np | |
def load_model_for_prediction():
    """Load the classifier, its label encoder, and the tokenizer.

    The three components are created in that order inside a single
    best-effort block: if any step raises, the error is printed and the
    caller receives ``None`` placeholders instead of an exception.

    Returns:
        tuple: ``(model, label_encoder, tokenizer)`` on success, or
        ``(None, None, None)`` when any component fails to load.
    """
    # 22 names, matching ``num_classes=22`` below. ``predict_sentence`` in
    # this file looks labels up by position in ``classes_``, so the order
    # here is significant.
    class_names = [
        'Addition', 'Causal', 'Cause and Effect',
        'Clarification', 'Comparison', 'Concession',
        'Conditional', 'Contrast', 'Contrastive Emphasis',
        'Definition', 'Elaboration', 'Emphasis',
        'Enumeration', 'Explanation', 'Generalization',
        'Illustration', 'Inference', 'Problem Solution',
        'Purpose', 'Sequential', 'Summary',
        'Temporal Sequence',
    ]

    try:
        # Load model from Hugging Face Hub and put it in eval mode.
        classifier = BiLSTMAttentionBERT.from_pretrained(
            "joko333/BiLSTM_v01",
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5,
        )
        classifier.eval()

        # The encoder is never fitted; its classes are assigned directly.
        encoder = LabelEncoder()
        encoder.classes_ = np.array(class_names)

        bert_tokenizer = AutoTokenizer.from_pretrained(
            'dmis-lab/biobert-base-cased-v1.2'
        )
    except Exception as err:
        print(f"Error loading model components: {str(err)}")
        return None, None, None

    return classifier, encoder, bert_tokenizer
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """Classify a single sentence and return its label and confidence.

    The whole pipeline (switching the model to eval mode, tokenizing, the
    forward pass and the label lookup) runs inside one error boundary, so a
    missing component or a tokenizer failure is reported the same way as a
    failure in the forward pass.

    Args:
        model: Classifier invoked as ``model(input_ids, attention_mask)``.
            Its output is softmaxed over dim 1, so it is presumably a
            ``(batch, num_classes)`` tensor of logits -- confirm against
            the model definition.
        sentence: Text to classify.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask`` entries.
        label_encoder: Object whose ``classes_`` sequence maps a predicted
            index to its label name.

    Returns:
        tuple: ``(label, probability)`` for a valid prediction;
        ``("Unknown", 0.0)`` when the predicted index is not covered by
        ``label_encoder.classes_``; ``("Error", 0.0)`` when any step raises
        (the exception message is printed).
    """
    try:
        model.eval()

        # Tokenize
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)

        index = pred_idx.item()
        classes = label_encoder.classes_

        # An index returned by torch.max is never negative, so the upper
        # bound is the only way the lookup below could go out of range.
        if index >= len(classes):
            print(f"Warning: Model predicted invalid label index {index}")
            return "Unknown", 0.0

        return classes[index], prob.item()
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error", 0.0
def print_labels(label_encoder, show_counts=False):
    """Print every class label next to its integer index.

    Args:
        label_encoder: Object exposing a ``classes_`` sequence of labels.
        show_counts: Accepted for interface compatibility; not used here.
    """
    classes = label_encoder.classes_
    rule = "-" * 40

    report = ["\nAvailable labels:", rule]
    report.extend(
        f"Index {position}: {name}" for position, name in enumerate(classes)
    )
    report.append(rule)
    # The trailing "\n" plus print's own newline leaves one blank line.
    report.append(f"Total number of classes: {len(classes)}\n")

    print("\n".join(report))