entropy25 commited on
Commit
c185c85
·
verified ·
1 Parent(s): 33fba47

Create analysis.py

Browse files
Files changed (1) hide show
  1. analysis.py +301 -0
analysis.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import logging
4
+ import plotly.graph_objects as go
5
+ from typing import Tuple, Dict
6
+
7
+ # Advanced analysis imports
8
+ import shap
9
+ import lime
10
+ from lime.lime_text import LimeTextExplainer
11
+
12
+ from config import config
13
+ from models import ModelManager, handle_errors
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class AdvancedAnalysisEngine:
18
+ """Advanced analysis using SHAP and LIME with FIXED implementation"""
19
+
20
+ def __init__(self):
21
+ self.model_manager = ModelManager()
22
+
23
+ def create_prediction_function(self, model, tokenizer, device):
24
+ """Create FIXED prediction function for SHAP/LIME"""
25
+ def predict_proba(texts):
26
+ # Ensure texts is a list
27
+ if isinstance(texts, str):
28
+ texts = [texts]
29
+ elif isinstance(texts, np.ndarray):
30
+ texts = texts.tolist()
31
+
32
+ # Convert all elements to strings
33
+ texts = [str(text) for text in texts]
34
+
35
+ results = []
36
+ batch_size = 16 # Process in smaller batches
37
+
38
+ for i in range(0, len(texts), batch_size):
39
+ batch_texts = texts[i:i + batch_size]
40
+
41
+ try:
42
+ with torch.no_grad():
43
+ # Tokenize batch
44
+ inputs = tokenizer(
45
+ batch_texts,
46
+ return_tensors="pt",
47
+ padding=True,
48
+ truncation=True,
49
+ max_length=config.MAX_TEXT_LENGTH
50
+ ).to(device)
51
+
52
+ # Batch inference
53
+ outputs = model(**inputs)
54
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
55
+
56
+ results.extend(probs)
57
+
58
+ except Exception as e:
59
+ logger.error(f"Prediction batch failed: {e}")
60
+ # Return neutral predictions for failed batch
61
+ batch_size_actual = len(batch_texts)
62
+ if hasattr(model.config, 'num_labels') and model.config.num_labels == 3:
63
+ neutral_probs = np.array([[0.33, 0.34, 0.33]] * batch_size_actual)
64
+ else:
65
+ neutral_probs = np.array([[0.5, 0.5]] * batch_size_actual)
66
+ results.extend(neutral_probs)
67
+
68
+ return np.array(results)
69
+
70
+ return predict_proba
71
+
72
+ @handle_errors(default_return=("Analysis failed", None, None))
73
+ def analyze_with_shap(self, text: str, language: str = 'auto', num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
74
+ """FIXED SHAP analysis implementation"""
75
+ if not text.strip():
76
+ return "Please enter text for analysis", None, {}
77
+
78
+ # Detect language and get model
79
+ if language == 'auto':
80
+ detected_lang = self.model_manager.detect_language(text)
81
+ else:
82
+ detected_lang = language
83
+
84
+ model, tokenizer = self.model_manager.get_model(detected_lang)
85
+
86
+ try:
87
+ # Create FIXED prediction function
88
+ predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)
89
+
90
+ # Test the prediction function first
91
+ test_pred = predict_fn([text])
92
+ if test_pred is None or len(test_pred) == 0:
93
+ return "Prediction function test failed", None, {}
94
+
95
+ # Use SHAP Text Explainer instead of generic Explainer
96
+ explainer = shap.Explainer(predict_fn, masker=shap.maskers.Text(tokenizer))
97
+
98
+ # Get SHAP values with proper text input
99
+ shap_values = explainer([text], max_evals=num_samples)
100
+
101
+ # Extract data safely
102
+ if hasattr(shap_values, 'data') and hasattr(shap_values, 'values'):
103
+ tokens = shap_values.data[0] if len(shap_values.data) > 0 else []
104
+ values = shap_values.values[0] if len(shap_values.values) > 0 else []
105
+ else:
106
+ return "SHAP values extraction failed", None, {}
107
+
108
+ if len(tokens) == 0 or len(values) == 0:
109
+ return "No tokens or values extracted from SHAP", None, {}
110
+
111
+ # Handle multi-dimensional values
112
+ if len(values.shape) > 1:
113
+ # Use positive class values (last column for 3-class, second for 2-class)
114
+ pos_values = values[:, -1] if values.shape[1] >= 2 else values[:, 0]
115
+ else:
116
+ pos_values = values
117
+
118
+ # Ensure we have matching lengths
119
+ min_len = min(len(tokens), len(pos_values))
120
+ tokens = tokens[:min_len]
121
+ pos_values = pos_values[:min_len]
122
+
123
+ # Create visualization
124
+ fig = go.Figure()
125
+
126
+ colors = ['red' if v < 0 else 'green' for v in pos_values]
127
+
128
+ fig.add_trace(go.Bar(
129
+ x=list(range(len(tokens))),
130
+ y=pos_values,
131
+ text=tokens,
132
+ textposition='outside',
133
+ marker_color=colors,
134
+ name='SHAP Values',
135
+ hovertemplate='<b>%{text}</b><br>SHAP Value: %{y:.4f}<extra></extra>'
136
+ ))
137
+
138
+ fig.update_layout(
139
+ title=f"SHAP Analysis - Token Importance (Samples: {num_samples})",
140
+ xaxis_title="Token Index",
141
+ yaxis_title="SHAP Value",
142
+ height=500,
143
+ xaxis=dict(tickmode='array', tickvals=list(range(len(tokens))), ticktext=tokens)
144
+ )
145
+
146
+ # Create analysis summary
147
+ analysis_data = {
148
+ 'method': 'SHAP',
149
+ 'language': detected_lang,
150
+ 'total_tokens': len(tokens),
151
+ 'samples_used': num_samples,
152
+ 'positive_influence': sum(1 for v in pos_values if v > 0),
153
+ 'negative_influence': sum(1 for v in pos_values if v < 0),
154
+ 'most_important_tokens': [(str(tokens[i]), float(pos_values[i]))
155
+ for i in np.argsort(np.abs(pos_values))[-5:]]
156
+ }
157
+
158
+ summary_text = f"""
159
+ **SHAP Analysis Results:**
160
+ - **Language:** {detected_lang.upper()}
161
+ - **Total Tokens:** {analysis_data['total_tokens']}
162
+ - **Samples Used:** {num_samples}
163
+ - **Positive Influence Tokens:** {analysis_data['positive_influence']}
164
+ - **Negative Influence Tokens:** {analysis_data['negative_influence']}
165
+ - **Most Important Tokens:** {', '.join([f"{token}({score:.3f})" for token, score in analysis_data['most_important_tokens']])}
166
+ - **Status:** SHAP analysis completed successfully
167
+ """
168
+
169
+ return summary_text, fig, analysis_data
170
+
171
+ except Exception as e:
172
+ logger.error(f"SHAP analysis failed: {e}")
173
+ error_msg = f"""
174
+ **SHAP Analysis Failed:**
175
+ - **Error:** {str(e)}
176
+ - **Language:** {detected_lang.upper()}
177
+ - **Suggestion:** Try with a shorter text or reduce number of samples
178
+
179
+ **Common fixes:**
180
+ - Reduce sample size to 50-100
181
+ - Use shorter input text (< 200 words)
182
+ - Check if model supports the text language
183
+ """
184
+ return error_msg, None, {}
185
+
186
+ @handle_errors(default_return=("Analysis failed", None, None))
187
+ def analyze_with_lime(self, text: str, language: str = 'auto', num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
188
+ """FIXED LIME analysis implementation - Bug Fix for mode parameter"""
189
+ if not text.strip():
190
+ return "Please enter text for analysis", None, {}
191
+
192
+ # Detect language and get model
193
+ if language == 'auto':
194
+ detected_lang = self.model_manager.detect_language(text)
195
+ else:
196
+ detected_lang = language
197
+
198
+ model, tokenizer = self.model_manager.get_model(detected_lang)
199
+
200
+ try:
201
+ # Create FIXED prediction function
202
+ predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)
203
+
204
+ # Test the prediction function first
205
+ test_pred = predict_fn([text])
206
+ if test_pred is None or len(test_pred) == 0:
207
+ return "Prediction function test failed", None, {}
208
+
209
+ # Determine class names based on model output
210
+ num_classes = test_pred.shape[1] if len(test_pred.shape) > 1 else 2
211
+ if num_classes == 3:
212
+ class_names = ['Negative', 'Neutral', 'Positive']
213
+ else:
214
+ class_names = ['Negative', 'Positive']
215
+
216
+ # Initialize LIME explainer - FIXED: Remove 'mode' parameter
217
+ explainer = LimeTextExplainer(class_names=class_names)
218
+
219
+ # Get LIME explanation
220
+ exp = explainer.explain_instance(
221
+ text,
222
+ predict_fn,
223
+ num_features=min(20, len(text.split())), # Limit features
224
+ num_samples=num_samples
225
+ )
226
+
227
+ # Extract feature importance
228
+ lime_data = exp.as_list()
229
+
230
+ if not lime_data:
231
+ return "No LIME features extracted", None, {}
232
+
233
+ # Create visualization
234
+ words = [item[0] for item in lime_data]
235
+ scores = [item[1] for item in lime_data]
236
+
237
+ fig = go.Figure()
238
+
239
+ colors = ['red' if s < 0 else 'green' for s in scores]
240
+
241
+ fig.add_trace(go.Bar(
242
+ y=words,
243
+ x=scores,
244
+ orientation='h',
245
+ marker_color=colors,
246
+ text=[f'{s:.3f}' for s in scores],
247
+ textposition='auto',
248
+ name='LIME Importance',
249
+ hovertemplate='<b>%{y}</b><br>Importance: %{x:.4f}<extra></extra>'
250
+ ))
251
+
252
+ fig.update_layout(
253
+ title=f"LIME Analysis - Feature Importance (Samples: {num_samples})",
254
+ xaxis_title="Importance Score",
255
+ yaxis_title="Words/Phrases",
256
+ height=500
257
+ )
258
+
259
+ # Create analysis summary
260
+ analysis_data = {
261
+ 'method': 'LIME',
262
+ 'language': detected_lang,
263
+ 'features_analyzed': len(lime_data),
264
+ 'samples_used': num_samples,
265
+ 'positive_features': sum(1 for _, score in lime_data if score > 0),
266
+ 'negative_features': sum(1 for _, score in lime_data if score < 0),
267
+ 'feature_importance': lime_data
268
+ }
269
+
270
+ summary_text = f"""
271
+ **LIME Analysis Results:**
272
+ - **Language:** {detected_lang.upper()}
273
+ - **Features Analyzed:** {analysis_data['features_analyzed']}
274
+ - **Classes:** {', '.join(class_names)}
275
+ - **Samples Used:** {num_samples}
276
+ - **Positive Features:** {analysis_data['positive_features']}
277
+ - **Negative Features:** {analysis_data['negative_features']}
278
+ - **Top Features:** {', '.join([f"{word}({score:.3f})" for word, score in lime_data[:5]])}
279
+ - **Status:** LIME analysis completed successfully
280
+ """
281
+
282
+ return summary_text, fig, analysis_data
283
+
284
+ except Exception as e:
285
+ logger.error(f"LIME analysis failed: {e}")
286
+ error_msg = f"""
287
+ **LIME Analysis Failed:**
288
+ - **Error:** {str(e)}
289
+ - **Language:** {detected_lang.upper()}
290
+ - **Suggestion:** Try with a shorter text or reduce number of samples
291
+
292
+ **Bug Fix Applied:**
293
+ - ✅ Removed 'mode' parameter from LimeTextExplainer initialization
294
+ - ✅ This should resolve the "unexpected keyword argument 'mode'" error
295
+
296
+ **Common fixes:**
297
+ - Reduce sample size to 50-100
298
+ - Use shorter input text (< 200 words)
299
+ - Check if model supports the text language
300
+ """
301
+ return error_msg, None, {}