joko333 committed on
Commit
ca5c473
·
1 Parent(s): 7fd3f3b

Implement sentence analysis functionality in Analysis page; add BiLSTM model and prediction utilities

Browse files
pages/Analysis.py CHANGED
@@ -1,15 +1,103 @@
1
  import streamlit as st
2
- #import transformers
3
- #from transformers import pipeline
 
4
 
5
- st.title('Text Analysis')
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # input_text = st.text_area("Enter text to analyze:")
10
- # if st.button('Analyze'):
11
- # if input_text:
12
- # result = classifier(input_text)
13
- # st.write(result)
14
- # else:
15
- # st.warning('Please enter some text')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import re
4
+ from utils.prediction import predict_sentence
5
 
6
def split_sentences_regex(text, pattern=r'[.]'):
    """Split free text into a list of cleaned sentence fragments.

    The text is normalised first: line breaks become spaces, single and
    double quote characters are removed, and runs of whitespace collapse to
    one space. The result is then split on ``pattern`` and every fragment is
    stripped; empty fragments are discarded.

    Args:
        text: The raw text to split.
        pattern: Regular expression marking sentence boundaries. Defaults to
            a literal period. Pass e.g. ``r'[.!?]+'`` to also split on
            exclamation and question marks. The pattern should not contain
            capturing groups, because ``re.split`` would then return the
            captured delimiters as fragments too.

    Returns:
        A list of non-empty, whitespace-stripped fragments. The delimiter
        itself is not part of any fragment.
    """
    # Flatten line breaks first so a sentence wrapped across lines stays whole.
    text = re.sub(r'[\n\r]', ' ', text)
    # Quote characters are dropped entirely (this also removes apostrophes).
    text = re.sub(r'["\']', '', text)
    text = re.sub(r'\s+', ' ', text)

    # Strip every fragment, then drop the ones that end up empty (for example
    # the tail after a final period, or whitespace between two delimiters).
    stripped = (piece.strip() for piece in re.split(pattern, text) if piece)
    return [fragment for fragment in stripped if fragment]
20
 
21
def split_sentences_with_abbrev(text):
    """Split text on ``'. '`` while keeping common abbreviations intact.

    The text is first cut at every ``'. '``. A cut is undone (the two pieces
    are joined again with ``'. '``) when the piece before it ends in one of
    the known abbreviations, so ``"Dr. Smith arrived. He left."`` yields
    ``["Dr. Smith arrived", "He left."]``.

    Args:
        text: The raw text to split.

    Returns:
        A list of sentences. Because the separator ``'. '`` is consumed by
        the split, only the last sentence can still carry its final period.
        An empty ``text`` yields ``['']``.
    """
    # Common abbreviations that must not be treated as sentence endings.
    abbreviations = {'mr.', 'mrs.', 'dr.', 'sr.', 'jr.', 'vs.', 'e.g.', 'i.e.', 'etc.'}

    def _ends_with_abbreviation(fragment):
        """Return True when the last word of ``fragment`` is an abbreviation."""
        words = fragment.split()
        if not words:
            return False
        # The '. ' split already consumed the trailing period, so it has to be
        # restored before comparing against the abbreviation set. Comparing
        # the whole last word (not a suffix) avoids matching inside longer
        # words; leading brackets/quotes are ignored, e.g. "(e.g".
        last_word = words[-1].lower().lstrip('("\'[').rstrip('.') + '.'
        return last_word in abbreviations

    parts = text.split('. ')
    sentences = []
    current = parts[0]

    for part in parts[1:]:
        if _ends_with_abbreviation(current):
            # Not a real sentence boundary: glue the pieces back together.
            current = current + '. ' + part
        else:
            sentences.append(current)
            current = part

    sentences.append(current)
    return sentences
42
 
43
def _render_predictions(sentences, model, tokenizer, label_encoder):
    """Classify each sentence and render the usable results on the page.

    Sentences for which ``predict_sentence`` returns the label ``"Unknown"``
    or ``"Error"`` are skipped, so only successful predictions are shown.
    """
    for sentence in sentences:
        with st.container():
            label, confidence = predict_sentence(
                model, sentence, tokenizer, label_encoder
            )
            if label in ("Unknown", "Error"):
                continue
            st.write("---")
            st.write(f"**Sentence:** {sentence}")
            st.write(f"**Predicted:** {label}")
            st.progress(confidence)


def show_analysis():
    """Render the "Text Analysis" page.

    The page needs ``model``, ``label_encoder`` and ``tokenizer`` in
    ``st.session_state``; if any of them is missing or ``None`` an error is
    shown and nothing else is rendered. Otherwise the page offers:

    * a text area whose content is split with ``split_sentences_regex`` and
      classified sentence by sentence when "Analyze" is pressed, and
    * an optional example section that classifies the first five entries of
      the ``Sentence`` column of ``data/raw/history_01.csv``.

    Any unexpected exception is reported on the page via ``st.error``.
    """
    st.title("Text Analysis")
    st.write("Use this section to analyze the logical structure of your text.")

    try:
        # All three components are read below, so check all of them rather
        # than only 'model'; a failed load may also have stored None.
        required_keys = ('model', 'label_encoder', 'tokenizer')
        if any(
            key not in st.session_state or st.session_state[key] is None
            for key in required_keys
        ):
            st.error("Please initialize the model from the home page first.")
            return

        model = st.session_state.model
        label_encoder = st.session_state.label_encoder
        tokenizer = st.session_state.tokenizer

        # Text input section
        st.header("Analyze Your Text")
        user_text = st.text_area("Enter your text here (multiple sentences allowed):", height=150)

        if st.button("Analyze"):
            # Whitespace-only input contains no sentences; treat it as empty.
            if user_text and user_text.strip():
                sentences = split_sentences_regex(user_text)

                st.subheader("Analysis Results:")
                _render_predictions(sentences, model, tokenizer, label_encoder)
            else:
                st.warning("Please enter some text to analyze.")

        # Example analysis section
        st.header("Example Analysis")
        show_examples = st.checkbox("Show example analysis", key='show_examples')

        if show_examples:
            try:
                df = pd.read_csv('data/raw/history_01.csv')
                # Limit to 5 examples to keep the page responsive.
                _render_predictions(
                    df['Sentence'].head(5), model, tokenizer, label_encoder
                )
            except FileNotFoundError:
                st.warning("Example file not found. Please check the data path.")

    except Exception as e:
        # Top-level UI boundary: surface the problem on the page instead of
        # letting the script crash.
        st.error(f"Error: {str(e)}")


if __name__ == "__main__":
    show_analysis()
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  streamlit
2
- #transformers
3
- #torch
 
 
 
1
  streamlit
2
+ pandas
3
+ numpy
4
+ transformers
5
+ torch
utils/BiLSTM.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoModel
4
+
5
+
6
class BiLSTMAttentionBERT(nn.Module):
    """Sentence classifier: pretrained encoder -> BiLSTM -> self-attention -> MLP.

    Token embeddings from a pretrained transformer encoder are passed through
    a multi-layer bidirectional LSTM, layer-normalised, refined by a
    single-head self-attention layer, mean-pooled over the sequence axis and
    fed to a small classification head that returns raw logits.
    """

    def __init__(self,
                 hidden_dim=256,
                 num_classes=22,  # Based on the label distribution
                 num_layers=2,  # Multiple LSTM layers
                 dropout=0.1,
                 bert_model_name='dmis-lab/biobert-base-cased-v1.2'):
        """Build the network.

        Args:
            hidden_dim: Hidden size of each LSTM direction. The BiLSTM output
                (and the attention embedding) is ``2 * hidden_dim`` wide.
            num_classes: Number of output logits.
            num_layers: Number of stacked LSTM layers.
            dropout: Base dropout probability used throughout the network.
            bert_model_name: Name or path of the pretrained encoder passed to
                ``AutoModel.from_pretrained``. Defaults to BioBERT.
        """
        super().__init__()

        # Pretrained encoder (BioBERT by default).
        self.bert_model = AutoModel.from_pretrained(bert_model_name)
        # Read the width from the config instead of assuming a fixed size.
        bert_dim = self.bert_model.config.hidden_size
        # Dropout for encoder outputs
        self.dropout_bert = nn.Dropout(dropout)
        # Multi-layer BiLSTM
        self.lstm = nn.LSTM(
            input_size=bert_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # Inter-layer dropout is only passed when the LSTM is stacked.
            dropout=dropout if num_layers > 1 else 0
        )

        # Self-attention over the BiLSTM outputs (a single head).
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # *2 for bidirectional
            num_heads=1,
            dropout=dropout,
            batch_first=True
        )

        # Regularization layers
        # NOTE: dropout1 is not used in forward(); it is kept so the module's
        # attributes stay as they were.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout + 0.1)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.batch_norm = nn.BatchNorm1d(hidden_dim * 2)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask):
        """Compute class logits for a batch of tokenised sentences.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Attention mask, forwarded to the encoder only.

        Returns:
            The output of the classification head (unnormalised logits, one
            row per batch element).
        """
        # Encoder pass
        bert_output = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        sequence_output = self.dropout_bert(bert_output.last_hidden_state)

        # BiLSTM processing
        lstm_out, _ = self.lstm(sequence_output)
        lstm_out = self.layer_norm(lstm_out)

        # Self-attention
        # NOTE: attention_mask is not passed on as a key padding mask, so
        # padded positions take part in the attention and in the mean below.
        attn_out, _ = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            need_weights=False
        )

        # Pooling over the sequence axis (dim=1 because batch_first=True),
        # followed by normalization.
        pooled = torch.mean(attn_out, dim=1)
        pooled = self.batch_norm(pooled)
        pooled = self.dropout2(pooled)

        # Classification
        return self.classifier(pooled)
utils/__init__.py ADDED
File without changes
utils/prediction.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ from sklearn.preprocessing import LabelEncoder
4
+ from BiLSTM import BiLSTMAttentionBERT
5
+ import numpy as np
6
+
7
+
8
+
9
def load_model_for_prediction():
    """Load the classifier, its label encoder and its tokenizer on the CPU.

    Returns:
        A ``(model, label_encoder, tokenizer)`` tuple. If anything fails
        while loading, the error is printed and ``(None, None, None)`` is
        returned instead of raising, so callers must check for ``None``.
    """
    # Force CPU. Placement is enforced by `.to(device)` below.
    device = torch.device('cpu')

    try:
        # Predefined label set; the position in this array is the class index.
        class_names = np.array(['Addition', 'Causal', 'Cause and Effect',
                                'Clarification', 'Comparison', 'Concession',
                                'Conditional', 'Contrast', 'Contrastive Emphasis',
                                'Definition', 'Elaboration', 'Emphasis',
                                'Enumeration', 'Explanation', 'Generalization',
                                'Illustration', 'Inference', 'Problem Solution',
                                'Purpose', 'Sequential', 'Summary',
                                'Temporal Sequence'])

        # Load model from Hugging Face Hub.
        # NOTE(review): BiLSTMAttentionBERT in utils/BiLSTM.py is a plain
        # nn.Module and does not define `from_pretrained` itself - confirm
        # that a hub mixin (or similar) provides it, otherwise this call
        # fails and the function returns (None, None, None).
        model = BiLSTMAttentionBERT.from_pretrained(
            "joko333/BiLSTM_v01",
            hidden_dim=128,
            # Derived from the label list so the two cannot drift apart.
            num_classes=len(class_names),
            num_layers=2,
            dropout=0.5
        ).to(device)

        model.eval()

        # Initialize label encoder with the predefined classes (no fitting).
        label_encoder = LabelEncoder()
        label_encoder.classes_ = class_names

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            'dmis-lab/biobert-base-cased-v1.2'
        )

        return model, label_encoder, tokenizer

    except Exception as e:
        print(f"Error loading model components: {str(e)}")
        return None, None, None
47
+
48
def predict_sentence(model, sentence, tokenizer, label_encoder, device=None):
    """Predict the label of a single sentence, with label validation.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` and returning
            a logits tensor with one row per input.
        sentence: The sentence to classify.
        tokenizer: Tokenizer whose result supports ``.to(device)`` and
            provides ``'input_ids'`` and ``'attention_mask'``.
        label_encoder: Object whose ``classes_`` sequence maps class index
            to label.
        device: Device to run on. ``None`` (the default) means CPU.

    Returns:
        ``(label, probability)`` on success. ``("Unknown", 0.0)`` when the
        predicted index is not covered by ``label_encoder.classes_`` and
        ``("Error", 0.0)`` when the forward pass or decoding raises.
    """
    # Honour an explicitly requested device; fall back to CPU otherwise.
    device = torch.device('cpu') if device is None else torch.device(device)
    model = model.to(device)
    model.eval()

    # Tokenize. Every input is padded to the full max_length.
    encoding = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    try:
        with torch.no_grad():
            # Get model outputs
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)

            # Get prediction and probability
            prob, pred_idx = torch.max(probabilities, dim=1)

        index = pred_idx.item()

        # The model may have more outputs than there are known labels.
        if index >= len(label_encoder.classes_):
            print(f"Warning: Model predicted invalid label index {index}")
            return "Unknown", 0.0

        # Convert to label
        return label_encoder.classes_[index], prob.item()

    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return "Error", 0.0
91
+
92
def print_labels(label_encoder, show_counts=False):
    """Print every label of ``label_encoder`` next to its class index.

    The listing is framed by two 40-character rules and followed by the
    total number of classes. ``show_counts`` is accepted but not used.
    """
    rule = "-" * 40
    class_names = label_encoder.classes_

    report = ["\nAvailable labels:", rule]
    report.extend(
        f"Index {position}: {name}" for position, name in enumerate(class_names)
    )
    report.append(rule)
    report.append(f"Total number of classes: {len(class_names)}\n")

    # One write; the joined text is the same as printing line by line.
    print("\n".join(report))
100
+
101
def predict_sentence2(sentence, model, tokenizer, label_encoder):
    """Predict the label of a single sentence and return only the label.

    Unlike ``predict_sentence`` the argument order is
    ``(sentence, model, ...)``, the tokenizer output is passed to the model
    as keyword arguments, and no probability is returned.

    Args:
        sentence: The sentence to classify.
        model: Model called as ``model(**inputs)``. It may return either a
            raw logits tensor or an object exposing ``.logits``.
        tokenizer: Tokenizer returning a mapping of tensors.
        label_encoder: Object providing ``inverse_transform`` for class
            indices.

    Returns:
        The label produced by ``label_encoder.inverse_transform`` for the
        arg-max class.
    """
    # Tokenize the input
    inputs = tokenizer(sentence,
                       padding=True,
                       truncation=True,
                       return_tensors='pt',
                       max_length=512)

    # Move inputs to the same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference mode: disables dropout and uses batch-norm running statistics,
    # which is required for a batch of one.
    model.eval()

    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        # Models returning a plain tensor have no `.logits` attribute.
        logits = getattr(outputs, "logits", outputs)
        predictions = torch.argmax(logits, dim=1)

    # Convert prediction to label
    predicted_label = label_encoder.inverse_transform(predictions.cpu().numpy())[0]

    return predicted_label
122
+