Spaces:
Sleeping
Sleeping
File size: 4,532 Bytes
1b17d16 ca5c473 67bf242 ca5c473 42d8a45 1b17d16 ca5c473 42d8a45 67bf242 1b17d16 67bf242 3681591 67bf242 1b17d16 3681591 42d8a45 67bf242 3681591 67bf242 3681591 67bf242 3681591 42d8a45 3681591 ca5c473 3681591 42d8a45 ca5c473 3681591 ca5c473 621f6b2 ca5c473 0858163 04dc908 0858163 04dc908 0858163 04dc908 0858163 04dc908 0858163 04dc908 0858163 04dc908 ca5c473 0858163 04dc908 0858163 ca5c473 04dc908 0858163 ca5c473 0858163 04dc908 ca5c473 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
import numpy as np
import streamlit as st
import requests
from huggingface_hub import hf_hub_download
def load_model_for_prediction():
    """Build the BiLSTM-attention classifier, its label encoder and tokenizer.

    Steps:
      1. Load the pre-trained BioBERT encoder (``dmis-lab/biobert-base-cased-v1.2``).
      2. Build a ``BiLSTMAttentionBERT`` from a fixed ``BiLSTMConfig`` and attach
         the encoder as ``model.bert``.
      3. Fetch ``model_epoch8_acc72.53.pt`` from the ``joko333/BiLSTM_v01`` Hub
         repo with ``hf_hub_download`` and restore only the non-BERT layers
         from its ``model_state_dict``.
      4. Rebuild a ``LabelEncoder`` from the checkpoint's
         ``label_encoder_classes``.
      5. Load the matching BioBERT tokenizer.

    Progress and problems are reported through Streamlit (``st.write`` /
    ``st.warning`` / ``st.error``) instead of being raised.

    Returns:
        tuple: ``(model, label_encoder, tokenizer)`` on success, or
        ``(None, None, None)`` when the checkpoint is malformed or any step
        raises an exception.
    """
    try:
        st.write("Starting model loading...")

        # Initialize the pre-trained BERT encoder first.
        bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        # NOTE(review): presumably must equal the number of label classes stored
        # in the checkpoint; checked (warning only) further below.
        num_classes = 22
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=num_classes,
            num_layers=2,
            dropout=0.5,
        )
        model = BiLSTMAttentionBERT(config)
        model.bert = bert  # attach the pre-trained encoder as model.bert

        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt",
        )
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
        # a checkpoint that also stores label classes as non-tensor objects.
        # Only pass weights_only=False for a trusted source: it unpickles
        # arbitrary Python objects.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Debug: show the checkpoint structure.
        st.write("Checkpoint keys:", checkpoint.keys())

        if 'model_state_dict' not in checkpoint:
            st.error("Invalid checkpoint format")
            return None, None, None

        # Keep only the custom (non-BERT) layer weights so the pre-trained
        # encoder loaded above is left untouched.
        # NOTE(review): if BERT was fine-tuned during training, this discards
        # those fine-tuned weights. Confirm that this is intended.
        custom_state_dict = {
            key: value
            for key, value in checkpoint['model_state_dict'].items()
            if not key.startswith('bert.')
        }
        # strict=False is required because every ``bert.*`` key is now absent.
        # Inspect the result so that a key mismatch in the custom layers, which
        # would leave them at their initial values, does not pass silently.
        incompatible = model.load_state_dict(custom_state_dict, strict=False)
        missing_custom = [
            key for key in incompatible.missing_keys
            if not key.startswith('bert.')
        ]
        unexpected = list(incompatible.unexpected_keys)
        if missing_custom or unexpected:
            st.warning(
                "Checkpoint/model key mismatch - "
                f"missing custom keys: {missing_custom}; "
                f"unexpected keys: {unexpected}"
            )
        st.write("Model loaded successfully")

        if 'label_encoder_classes' not in checkpoint:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None

        label_encoder = LabelEncoder()
        # sklearn stores fitted classes as an ndarray; normalise in case the
        # checkpoint saved a plain list.
        label_encoder.classes_ = np.asarray(checkpoint['label_encoder_classes'])
        if len(label_encoder.classes_) != num_classes:
            st.warning(
                f"Checkpoint has {len(label_encoder.classes_)} label classes "
                f"but the model is configured for {num_classes}"
            )

        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        return model, label_encoder, tokenizer
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """Classify one sentence and return the top label with its probability.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``. It is
            expected to return 2-D logits (batch, classes); with one sentence
            the batch size is 1.
        sentence: text to classify. Only its first 50 characters are echoed.
        tokenizer: Hugging Face tokenizer that matches the model's encoder.
        label_encoder: fitted ``LabelEncoder`` whose ``classes_`` maps the
            predicted class index to a label name.

    Returns:
        tuple: ``(predicted_label, probability)``, where probability is the
        softmax score of the chosen class. If a component is missing, or if
        tokenization or inference raises, returns ``("Error: ...", 0.0)`` and
        reports the error with ``st.error``.
    """
    import time

    start_time = time.perf_counter()  # monotonic clock for elapsed-time reporting

    # Validation checks. The loader returns None for each component on failure.
    # NOTE(review): the leading "π" here and in "Processing text" below is a
    # mis-decoded emoji whose original character cannot be recovered. Confirm.
    st.write("π Starting prediction process...")
    if model is None:
        st.error("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        st.error("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        st.error("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0

    st.write("⚙️ Preparing model...")
    device = torch.device('cpu')  # inference is forced onto the CPU
    try:
        model = model.to(device)
        model.eval()  # inference mode (e.g. dropout disabled)

        st.write(f"π Processing text: {sentence[:50]}...")
        # Every input is padded/truncated to exactly 512 tokens.
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(device)

        st.write("🤖 Running model inference...")
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)

        predicted_label = label_encoder.classes_[pred_idx.item()]
        elapsed_time = time.perf_counter() - start_time
        st.write(f"✅ Prediction completed in {elapsed_time:.2f} seconds")
        return predicted_label, prob.item()
    except Exception as e:
        # UI boundary: surface the error to the user and return a sentinel result.
        st.error(f"❌ Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
def print_labels(label_encoder, show_counts=False):
    """Write every class label, with its integer index, to standard output.

    Args:
        label_encoder: object exposing a ``classes_`` sequence (for example a
            fitted ``LabelEncoder``); position ``i`` is shown as ``Index i``.
        show_counts: currently unused; accepted only for interface
            compatibility.
    """
    known_classes = label_encoder.classes_
    rule = "-" * 40

    report = ["\nAvailable labels:", rule]
    report.extend(f"Index {position}: {name}" for position, name in enumerate(known_classes))
    report.append(rule)
    # The trailing "\n" plus print's own newline leaves a blank line after the summary.
    report.append(f"Total number of classes: {len(known_classes)}\n")

    print("\n".join(report))
|