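"""Model loading and prediction helpers for the BiLSTM-attention BioBERT app.

The custom BiLSTM/attention head is restored from the joko333/BiLSTM_v01
checkpoint on the Hugging Face Hub and stacked on a pre-trained
dmis-lab/biobert-base-cased-v1.2 encoder; inference runs on CPU.
"""
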
import time

import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from sklearn.preprocessing import LabelEncoder
from transformers import AutoModel, AutoTokenizer

from utils.model import BiLSTMAttentionBERT, BiLSTMConfig



def load_model_for_prediction():
    try:
        st.write("Starting model loading...")
        
        # Initialize BERT first
        bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        
        # Initialize config and model
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )
        
        model = BiLSTMAttentionBERT(config)
        model.bert = bert  # Set pre-trained BERT
        
        # Load custom layers from checkpoint
        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt"
        )
        # weights_only=False: the checkpoint stores non-tensor objects
        # (the label encoder classes), which PyTorch >= 2.6 refuses to load
        # by default; the kwarg itself requires PyTorch >= 1.13.
        checkpoint = torch.load(model_path, map_location='cpu',
                                weights_only=False)
        
        # Debug checkpoint structure
        st.write("Checkpoint keys:", checkpoint.keys())
        
        if 'model_state_dict' in checkpoint:
            # Keep only the custom (non-BERT) layer weights; BERT itself
            # was already initialized from the pre-trained encoder above
            custom_state_dict = {
                key: value
                for key, value in checkpoint['model_state_dict'].items()
                if not key.startswith('bert.')
            }
            
            # Load custom layers
            model.load_state_dict(custom_state_dict, strict=False)
            st.write("Model loaded successfully")
        else:
            st.error("Invalid checkpoint format")
            return None, None, None
            
        # Initialize label encoder from checkpoint
        label_encoder = LabelEncoder()
        if 'label_encoder_classes' in checkpoint:
            label_encoder.classes_ = checkpoint['label_encoder_classes']
        else:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None
            
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        
        return model, label_encoder, tokenizer
        
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
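
# Optional convenience wrapper (a sketch, assuming Streamlit >= 1.18, where
# st.cache_resource was introduced; not confirmed to be part of the original
# app). Memoizing the loader keeps Streamlit reruns from re-downloading
# BioBERT and the checkpoint on every interaction.
@st.cache_resource(show_spinner=False)
def get_cached_model():
    return load_model_for_prediction()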

def predict_sentence(model, sentence, tokenizer, label_encoder):
    """
    Make prediction for a single sentence with label validation.
    """
    start_time = time.time()
    
    # Validation checks
    st.write("πŸ”„ Starting prediction process...")
    if model is None:
        st.error("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        st.error("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        st.error("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0
        
    # Force CPU device
    st.write("βš™οΈ Preparing model...")
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()
    
    # Tokenize
    try:
        st.write(f"πŸ“ Processing text: {sentence[:50]}...")
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        
        st.write("πŸ€– Running model inference...")
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)
            predicted_label = label_encoder.classes_[pred_idx.item()]
            
        elapsed_time = time.time() - start_time
        st.write(f"βœ… Prediction completed in {elapsed_time:.2f} seconds")
        return predicted_label, prob.item()
            
    except Exception as e:
        st.error(f"❌ Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
    
def print_labels(label_encoder):
    """Print all labels and their corresponding indices."""
    print("\nAvailable labels:")
    print("-" * 40)
    for idx, label in enumerate(label_encoder.classes_):
        print(f"Index {idx}: {label}")
    print("-" * 40)
    print(f"Total number of classes: {len(label_encoder.classes_)}\n")