Implement sentence analysis functionality in Analysis page; add BiLSTM model and prediction utilities
Files changed:
- pages/Analysis.py +99 -11
- requirements.txt +4 -2
- utils/BiLSTM.py +79 -0
- utils/__init__.py +0 -0
- utils/prediction.py +122 -0
pages/Analysis.py CHANGED
@@ -1,15 +1,103 @@
 import streamlit as st
-
-
+import pandas as pd
+import re
+from utils.prediction import predict_sentence
 
-
+def split_sentences_regex(text):
+    # Clean the text
+    text = re.sub(r'[\n\r]', ' ', text)  # Remove newlines
+    text = re.sub(r'["\']', '', text)    # Remove quotes
+    text = re.sub(r'\s+', ' ', text)     # Normalize whitespace
+
+    # More aggressive pattern that looks for sentence endings
+    #pattern = r'[.!?]+[\s]+|[.!?]+$'
+    pattern = r'[.]'
+    # Split and clean resulting sentences
+    sentences = [s.strip() for s in re.split(pattern, text) if s]
+
+    # Filter out empty strings but keep sentences that don't start with capitals
+    return [s for s in sentences if len(s) > 0]
 
-
+def split_sentences_with_abbrev(text):
+    # Common abbreviations to ignore
+    abbreviations = {'mr.', 'mrs.', 'dr.', 'sr.', 'jr.', 'vs.', 'e.g.', 'i.e.', 'etc.'}
+
+    # Split initially by potential sentence endings
+    parts = text.split('. ')
+    sentences = []
+    current = parts[0]
+
+    for part in parts[1:]:
+        # Check if the previous part ends with an abbreviation
+        ends_with_abbrev = any(current.lower().endswith(abbr) for abbr in abbreviations)
+
+        if ends_with_abbrev:
+            current = current + '. ' + part
+        else:
+            sentences.append(current)
+            current = part
+
+    sentences.append(current)
+    return sentences
 
-
-
-
-
-
-
-
+def show_analysis():
+    st.title("Text Analysis")
+    st.write("Use this section to analyze the logical structure of your text.")
+
+    try:
+        if 'model' not in st.session_state:
+            st.error("Please initialize the model from the home page first.")
+            return
+
+        model = st.session_state.model
+        label_encoder = st.session_state.label_encoder
+        tokenizer = st.session_state.tokenizer
+
+        # Text input section
+        st.header("Analyze Your Text")
+        user_text = st.text_area("Enter your text here (multiple sentences allowed):", height=150)
+
+        if st.button("Analyze"):
+            if user_text:
+                # Split and analyze sentences
+                sentences = split_sentences_regex(user_text)
+
+                st.subheader("Analysis Results:")
+                for i, sentence in enumerate(sentences, 1):
+                    with st.container():
+                        label, confidence = predict_sentence(
+                            model, sentence, tokenizer, label_encoder
+                        )
+                        if label not in ("Unknown", "Error"):
+                            st.write("---")
+                            st.write(f"**Sentence:** {sentence}")
+                            st.write(f"**Predicted:** {label}")
+                            st.progress(confidence)
+            else:
+                st.warning("Please enter some text to analyze.")
+
+        # Example Analysis Section
+        st.header("Example Analysis")
+        show_examples = st.checkbox("Show example analysis", key='show_examples')
+
+        if show_examples:
+            try:
+                df = pd.read_csv('data/raw/history_01.csv')
+                for sentence in df['Sentence'].head(5):  # Limit to 5 examples
+                    with st.container():
+                        label, confidence = predict_sentence(
+                            model, sentence, tokenizer, label_encoder
+                        )
+                        if label not in ("Unknown", "Error"):
+                            st.write("---")
+                            st.write(f"**Sentence:** {sentence}")
+                            st.write(f"**Predicted:** {label}")
+                            st.progress(confidence)
+            except FileNotFoundError:
+                st.warning("Example file not found. Please check the data path.")
+
+    except Exception as e:
+        st.error(f"Error: {str(e)}")
+
+if __name__ == "__main__":
+    show_analysis()
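For context: show_analysis() only works once st.session_state.model, st.session_state.label_encoder, and st.session_state.tokenizer have been populated by the home page, which is not part of this commit. A minimal sketch of what that initialization could look like, assuming a hypothetical entry script named Home.py that reuses load_model_for_prediction from utils/prediction.py:

# Home.py -- hypothetical entry script; the real home page is not in this commit
import streamlit as st
from utils.prediction import load_model_for_prediction

st.title("Home")

# Load the model once and cache the pieces in session state so that
# pages/Analysis.py can pick them up without reloading.
if 'model' not in st.session_state:
    model, label_encoder, tokenizer = load_model_for_prediction()
    if model is not None:
        st.session_state.model = model
        st.session_state.label_encoder = label_encoder
        st.session_state.tokenizer = tokenizer
        st.success("Model initialized.")
    else:
        st.error("Model components failed to load; check the logs.")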
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 streamlit
-
-
+pandas
+numpy
+transformers
+torch
utils/BiLSTM.py ADDED
@@ -0,0 +1,79 @@
+import torch
+from torch import nn
+from transformers import AutoModel
+
+
+class BiLSTMAttentionBERT(nn.Module):
+    def __init__(self,
+                 hidden_dim=256,
+                 num_classes=22,  # Based on the label distribution
+                 num_layers=2,    # Multiple LSTM layers
+                 dropout=0.1):
+        super().__init__()
+
+        # Load BioBERT instead of BERT
+        self.bert_model = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
+        bert_dim = self.bert_model.config.hidden_size  # Still 768 for BioBERT base
+        # Dropout for BERT outputs
+        self.dropout_bert = nn.Dropout(dropout)
+        # Multi-layer BiLSTM
+        self.lstm = nn.LSTM(
+            input_size=bert_dim,
+            hidden_size=hidden_dim,
+            num_layers=num_layers,
+            bidirectional=True,
+            batch_first=True,
+            dropout=dropout if num_layers > 1 else 0
+        )
+
+        # Multi-head attention
+        self.attention = nn.MultiheadAttention(
+            embed_dim=hidden_dim * 2,  # *2 for bidirectional
+            num_heads=1,
+            dropout=dropout,
+            batch_first=True
+        )
+
+        # Regularization layers
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout + 0.1)
+        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
+        self.batch_norm = nn.BatchNorm1d(hidden_dim * 2)
+
+        # Classification head
+        self.classifier = nn.Sequential(
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.BatchNorm1d(hidden_dim),
+            nn.Linear(hidden_dim, num_classes)
+        )
+
+    def forward(self, input_ids, attention_mask):
+        # BERT encoding
+        bert_output = self.bert_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            return_dict=True
+        )
+        sequence_output = self.dropout_bert(bert_output.last_hidden_state)
+
+        # BiLSTM processing
+        lstm_out, _ = self.lstm(sequence_output)
+        lstm_out = self.layer_norm(lstm_out)
+
+        # Self-attention
+        attn_out, _ = self.attention(
+            query=lstm_out,
+            key=lstm_out,
+            value=lstm_out,
+            need_weights=False
+        )
+
+        # Pooling and normalization
+        pooled = torch.mean(attn_out, dim=1)
+        pooled = self.batch_norm(pooled)
+        pooled = self.dropout2(pooled)
+
+        # Classification
+        return self.classifier(pooled)
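As a sanity check on the architecture above, a forward pass over random token ids should yield one logit per class. A minimal sketch, assuming the default constructor arguments (instantiating the class downloads BioBERT on first run):

import torch
from utils.BiLSTM import BiLSTMAttentionBERT

model = BiLSTMAttentionBERT()  # hidden_dim=256, num_classes=22
model.eval()  # BatchNorm1d layers need eval mode for single/small batches

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, model.bert_model.config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

with torch.no_grad():
    logits = model(input_ids, attention_mask)

print(logits.shape)  # torch.Size([2, 22]) -- one logit per relation class

Mean pooling over the attended sequence keeps the classifier input at a fixed width of 2 * hidden_dim regardless of sentence length, which is what lets the same classification head serve variable-length inputs.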
utils/__init__.py ADDED
File without changes
utils/prediction.py ADDED
@@ -0,0 +1,122 @@
+import torch
+from transformers import AutoTokenizer
+from sklearn.preprocessing import LabelEncoder
+from utils.BiLSTM import BiLSTMAttentionBERT  # package-qualified so it resolves when the app runs from the repo root
+import numpy as np
+
+
+
+def load_model_for_prediction():
+    # Force CPU
+    device = torch.device('cpu')
+    torch.backends.mps.enabled = False
+
+    try:
+        # Load model from Hugging Face Hub
+        model = BiLSTMAttentionBERT.from_pretrained(
+            "joko333/BiLSTM_v01",
+            hidden_dim=128,
+            num_classes=22,
+            num_layers=2,
+            dropout=0.5
+        ).to(device)
+
+        model.eval()
+
+        # Initialize label encoder with predefined classes
+        label_encoder = LabelEncoder()
+        label_encoder.classes_ = np.array(['Addition', 'Causal', 'Cause and Effect',
+                                           'Clarification', 'Comparison', 'Concession',
+                                           'Conditional', 'Contrast', 'Contrastive Emphasis',
+                                           'Definition', 'Elaboration', 'Emphasis',
+                                           'Enumeration', 'Explanation', 'Generalization',
+                                           'Illustration', 'Inference', 'Problem Solution',
+                                           'Purpose', 'Sequential', 'Summary',
+                                           'Temporal Sequence'])
+
+        # Initialize tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            'dmis-lab/biobert-base-cased-v1.2'
+        )
+
+        return model, label_encoder, tokenizer
+
+    except Exception as e:
+        print(f"Error loading model components: {str(e)}")
+        return None, None, None
+
+def predict_sentence(model, sentence, tokenizer, label_encoder, device=None):
+    """
+    Make prediction for a single sentence with label validation.
+    """
+    device = torch.device('cpu')
+    model = model.to(device)
+    model.eval()
+
+    # Tokenize
+    encoding = tokenizer(
+        sentence,
+        add_special_tokens=True,
+        max_length=512,
+        padding='max_length',
+        truncation=True,
+        return_tensors='pt'
+    ).to(device)
+
+    try:
+        with torch.no_grad():
+            # Get model outputs
+            outputs = model(encoding['input_ids'], encoding['attention_mask'])
+            probabilities = torch.softmax(outputs, dim=1)
+
+            # Get prediction and probability
+            prob, pred_idx = torch.max(probabilities, dim=1)
+
+            # Validate prediction index
+            if pred_idx.item() >= len(label_encoder.classes_):
+                print(f"Warning: Model predicted invalid label index {pred_idx.item()}")
+                return "Unknown", 0.0
+
+            # Convert to label
+            try:
+                predicted_class = label_encoder.classes_[pred_idx.item()]
+                return predicted_class, prob.item()
+            except IndexError:
+                print(f"Warning: Invalid label index {pred_idx.item()}")
+                return "Unknown", 0.0
+
+    except Exception as e:
+        print(f"Prediction error: {str(e)}")
+        return "Error", 0.0
+
+def print_labels(label_encoder, show_counts=False):
+    """Print all labels and their corresponding indices"""
+    print("\nAvailable labels:")
+    print("-" * 40)
+    for idx, label in enumerate(label_encoder.classes_):
+        print(f"Index {idx}: {label}")
+    print("-" * 40)
+    print(f"Total number of classes: {len(label_encoder.classes_)}\n")
+
+def predict_sentence2(sentence, model, tokenizer, label_encoder):
+    # Tokenize the input
+    inputs = tokenizer(sentence,
+                       padding=True,
+                       truncation=True,
+                       return_tensors='pt',
+                       max_length=512)
+
+    # Move inputs to the same device as model
+    device = next(model.parameters()).device
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Make prediction
+    with torch.no_grad():
+        outputs = model(**inputs)
+        predictions = torch.argmax(outputs.logits, dim=1)
+
+    # Convert prediction to label
+    predicted_label = label_encoder.inverse_transform(predictions.cpu().numpy())[0]
+
+    return predicted_label
+
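End to end, the intended call pattern looks like the sketch below. One caveat: from_pretrained is not defined on a plain nn.Module, so BiLSTMAttentionBERT.from_pretrained presumably relies on a mixin such as huggingface_hub's PyTorchModelHubMixin being added to the class; if the call fails, load_model_for_prediction catches the exception and returns None for all three components. The sample sentence and printed output are illustrative only:

from utils.prediction import load_model_for_prediction, predict_sentence, print_labels

# Load the model, label encoder, and tokenizer (all None on failure)
model, label_encoder, tokenizer = load_model_for_prediction()

if model is not None:
    print_labels(label_encoder)  # prints the 22 class names with their indices
    label, confidence = predict_sentence(
        model,
        "Aspirin reduces fever because it inhibits prostaglandin synthesis.",
        tokenizer,
        label_encoder,
    )
    print(f"{label} ({confidence:.1%})")  # e.g. Causal (87.3%) -- illustrative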