from utils.model import BiLSTMAttentionBERT, BiLSTMConfig

import time

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from huggingface_hub import hf_hub_download


def load_model_for_prediction():
    """Load the BiLSTM-attention model, label encoder, and tokenizer."""
    try:
        st.write("Starting model loading...")

        # Initialize the pre-trained BioBERT backbone first
        bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        # Initialize config and model
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )
        model = BiLSTMAttentionBERT(config)
        model.bert = bert  # Set pre-trained BERT

        # Load custom layers from checkpoint
        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt"
        )
        # weights_only=False is needed on newer PyTorch because the checkpoint
        # stores non-tensor objects (the label encoder classes) alongside weights.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        # Debug checkpoint structure
        st.write("Checkpoint keys:", checkpoint.keys())

        if 'model_state_dict' in checkpoint:
            # Extract only the custom layer weights; the BERT weights come
            # from the pre-trained model loaded above.
            state_dict = checkpoint['model_state_dict']
            custom_state_dict = {
                key: value for key, value in state_dict.items()
                if not key.startswith('bert.')
            }
            # strict=False because the BERT keys are intentionally absent
            model.load_state_dict(custom_state_dict, strict=False)
            st.write("Model loaded successfully")
        else:
            st.error("Invalid checkpoint format")
            return None, None, None

        # Initialize label encoder from checkpoint
        label_encoder = LabelEncoder()
        if 'label_encoder_classes' in checkpoint:
            label_encoder.classes_ = checkpoint['label_encoder_classes']
        else:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        return model, label_encoder, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None


def predict_sentence(model, sentence, tokenizer, label_encoder):
    """Make a prediction for a single sentence, validating that the model,
    tokenizer, and label encoder are loaded first."""
    start_time = time.time()

    # Validation checks
    st.write("🔄 Starting prediction process...")
    if model is None:
        st.error("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        st.error("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        st.error("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0

    # Force CPU inference
    st.write("⚙️ Preparing model...")
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()

    try:
        # Tokenize
        st.write(f"📝 Processing text: {sentence[:50]}...")
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        st.write("🤖 Running model inference...")
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)
            predicted_label = label_encoder.classes_[pred_idx.item()]

        elapsed_time = time.time() - start_time
        st.write(f"✅ Prediction completed in {elapsed_time:.2f} seconds")
        return predicted_label, prob.item()
    except Exception as e:
        st.error(f"❌ Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0


def print_labels(label_encoder, show_counts=False):
    """Print all labels and their corresponding indices.

    Note: show_counts is currently unused; per-label counts are not
    stored in the encoder.
    """
    print("\nAvailable labels:")
    print("-" * 40)
    for idx, label in enumerate(label_encoder.classes_):
        print(f"Index {idx}: {label}")
    print("-" * 40)
    print(f"Total number of classes: {len(label_encoder.classes_)}\n")
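
# ---------------------------------------------------------------------------
# Usage sketch (assumption): the original file does not show how these
# functions are wired into a Streamlit page, so the block below is a minimal
# illustration, not the author's app code. It assumes this file is executed
# directly via `streamlit run <this_file>.py`. st.cache_resource,
# st.text_input, and st.button are standard Streamlit API; caching the model
# load keeps Streamlit from reloading it on every script rerun.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    @st.cache_resource
    def _load_artifacts():
        # Cache the expensive model load so reruns reuse the same objects.
        return load_model_for_prediction()

    model, label_encoder, tokenizer = _load_artifacts()

    if model is not None:
        sentence = st.text_input("Enter a sentence to classify:")
        if st.button("Predict") and sentence:
            label, confidence = predict_sentence(
                model, sentence, tokenizer, label_encoder
            )
            st.write(f"Predicted label: {label} (confidence: {confidence:.2%})")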