# sentiment-multi / analysis.py
# Repository page header captured with this file:
#   author: entropy25
#   commit: "Create analysis.py" (c185c85, verified)
import torch
import numpy as np
import logging
import plotly.graph_objects as go
from typing import Tuple, Dict
# Advanced analysis imports
import shap
import lime
from lime.lime_text import LimeTextExplainer
from config import config
from models import ModelManager, handle_errors
logger = logging.getLogger(__name__)
class AdvancedAnalysisEngine:
    """Explain sentiment predictions token-by-token with SHAP and LIME.

    Models and tokenizers are obtained from a ``ModelManager``. Both public
    ``analyze_with_*`` methods return a ``(markdown_summary, figure, data)``
    triple; on a handled failure ``figure`` is ``None`` and ``data`` is ``{}``.
    """

    def __init__(self):
        self.model_manager = ModelManager()

    @staticmethod
    def _neutral_probabilities(count, num_classes):
        """Return ``count`` rows of near-uniform probabilities.

        Used as a stand-in when a model batch fails, so that the explainers
        still receive an array of the expected shape.
        """
        num_classes = max(int(num_classes), 1)
        if num_classes == 3:
            row = [0.33, 0.34, 0.33]
        elif num_classes == 2:
            row = [0.5, 0.5]
        else:
            row = [1.0 / num_classes] * num_classes
        return np.array([row] * count)

    def _resolve_model(self, text, language):
        """Return ``(language_code, model, tokenizer)`` for ``text``.

        ``language == 'auto'`` asks the model manager to detect the language;
        any other value is passed to the model manager unchanged.
        """
        if language == 'auto':
            detected_lang = self.model_manager.detect_language(text)
        else:
            detected_lang = language
        model, tokenizer = self.model_manager.get_model(detected_lang)
        return detected_lang, model, tokenizer

    def create_prediction_function(self, model, tokenizer, device, batch_size=16):
        """Build a ``texts -> class probabilities`` callable for SHAP/LIME.

        Args:
            model: Sequence-classification model; its output must expose
                ``.logits``.
            tokenizer: Tokenizer matching ``model``.
            device: Device the tokenized inputs are moved to.
            batch_size: Number of texts per forward pass (default 16).

        Returns:
            A function accepting a string, a list of strings or a NumPy array
            and returning a NumPy array with one row of softmax probabilities
            per input text.
        """

        def predict_proba(texts):
            # Explainers hand over str, list or ndarray; normalise to list[str].
            if isinstance(texts, str):
                texts = [texts]
            elif isinstance(texts, np.ndarray):
                texts = texts.tolist()
            texts = [str(item) for item in texts]

            rows = []
            for start in range(0, len(texts), batch_size):
                batch = texts[start:start + batch_size]
                try:
                    with torch.no_grad():
                        inputs = tokenizer(
                            batch,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=config.MAX_TEXT_LENGTH,
                        ).to(device)
                        logits = model(**inputs).logits
                        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
                    rows.extend(probs)
                except Exception as exc:
                    # Best effort: one bad batch must not abort the whole
                    # explanation run, so substitute neutral predictions.
                    logger.error("Prediction batch failed: %s", exc)
                    if rows:
                        # Match rows already produced so the final array
                        # stays rectangular.
                        width = len(rows[0])
                    else:
                        model_config = getattr(model, "config", None)
                        width = getattr(model_config, "num_labels", None) or 2
                    rows.extend(self._neutral_probabilities(len(batch), width))

            return np.array(rows)

        return predict_proba

    @handle_errors(default_return=("Analysis failed", None, None))
    def analyze_with_shap(self, text: str, language: str = 'auto', num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
        """Compute per-token SHAP attributions for the positive class.

        Args:
            text: Text to explain.
            language: Language code, or ``'auto'`` to detect it from ``text``.
            num_samples: Passed to the SHAP explainer as ``max_evals``.

        Returns:
            ``(summary_markdown, bar_chart, analysis_data)``. On a handled
            failure the chart is ``None`` and the data dict is empty.
        """
        if not text.strip():
            return "Please enter text for analysis", None, {}

        detected_lang, model, tokenizer = self._resolve_model(text, language)

        try:
            predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)

            # Smoke-test the prediction function before the costly SHAP run.
            test_pred = predict_fn([text])
            if test_pred is None or len(test_pred) == 0:
                return "Prediction function test failed", None, {}

            explainer = shap.Explainer(predict_fn, masker=shap.maskers.Text(tokenizer))
            shap_values = explainer([text], max_evals=num_samples)

            if not (hasattr(shap_values, 'data') and hasattr(shap_values, 'values')):
                return "SHAP values extraction failed", None, {}

            # Only one text was explained, so take the first entry.
            tokens = shap_values.data[0] if len(shap_values.data) > 0 else []
            values = shap_values.values[0] if len(shap_values.values) > 0 else []
            if len(tokens) == 0 or len(values) == 0:
                return "No tokens or values extracted from SHAP", None, {}

            values = np.asarray(values)
            # With one column per class, the last column is treated as the
            # positive class (index 1 of 2, or index 2 of 3).
            pos_values = values[:, -1] if values.ndim > 1 else values

            # Guard against tokens/values of different lengths.
            min_len = min(len(tokens), len(pos_values))
            tokens = tokens[:min_len]
            pos_values = pos_values[:min_len]

            positions = list(range(len(tokens)))
            colors = ['red' if v < 0 else 'green' for v in pos_values]
            fig = go.Figure()
            fig.add_trace(go.Bar(
                x=positions,
                y=pos_values,
                text=tokens,
                textposition='outside',
                marker_color=colors,
                name='SHAP Values',
                hovertemplate='<b>%{text}</b><br>SHAP Value: %{y:.4f}<extra></extra>'
            ))
            fig.update_layout(
                title=f"SHAP Analysis - Token Importance (Samples: {num_samples})",
                xaxis_title="Token Index",
                yaxis_title="SHAP Value",
                height=500,
                xaxis=dict(tickmode='array', tickvals=positions, ticktext=tokens)
            )

            # Strongest attribution first, matching the ordering of the
            # LIME feature list.
            ranked = np.argsort(np.abs(pos_values))[::-1][:5]
            analysis_data = {
                'method': 'SHAP',
                'language': detected_lang,
                'total_tokens': len(tokens),
                'samples_used': num_samples,
                'positive_influence': sum(1 for v in pos_values if v > 0),
                'negative_influence': sum(1 for v in pos_values if v < 0),
                'most_important_tokens': [(str(tokens[i]), float(pos_values[i])) for i in ranked]
            }

            top_tokens = ', '.join(
                f"{token}({score:.3f})"
                for token, score in analysis_data['most_important_tokens']
            )
            summary_text = (
                "\n"
                "**SHAP Analysis Results:**\n"
                f"- **Language:** {detected_lang.upper()}\n"
                f"- **Total Tokens:** {analysis_data['total_tokens']}\n"
                f"- **Samples Used:** {num_samples}\n"
                f"- **Positive Influence Tokens:** {analysis_data['positive_influence']}\n"
                f"- **Negative Influence Tokens:** {analysis_data['negative_influence']}\n"
                f"- **Most Important Tokens:** {top_tokens}\n"
                "- **Status:** SHAP analysis completed successfully\n"
            )
            return summary_text, fig, analysis_data

        except Exception as e:
            # Boundary handler: report the failure to the user as markdown
            # instead of raising.
            logger.error("SHAP analysis failed: %s", e)
            error_msg = (
                "\n"
                "**SHAP Analysis Failed:**\n"
                f"- **Error:** {str(e)}\n"
                f"- **Language:** {detected_lang.upper()}\n"
                "- **Suggestion:** Try with a shorter text or reduce number of samples\n"
                "**Common fixes:**\n"
                "- Reduce sample size to 50-100\n"
                "- Use shorter input text (< 200 words)\n"
                "- Check if model supports the text language\n"
            )
            return error_msg, None, {}

    @handle_errors(default_return=("Analysis failed", None, None))
    def analyze_with_lime(self, text: str, language: str = 'auto', num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
        """Compute LIME word importances for the positive class.

        Args:
            text: Text to explain.
            language: Language code, or ``'auto'`` to detect it from ``text``.
            num_samples: Number of perturbed samples LIME generates.

        Returns:
            ``(summary_markdown, bar_chart, analysis_data)``. On a handled
            failure the chart is ``None`` and the data dict is empty.
        """
        if not text.strip():
            return "Please enter text for analysis", None, {}

        detected_lang, model, tokenizer = self._resolve_model(text, language)

        try:
            predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)

            # Smoke-test the prediction function; its output width also
            # tells us how many classes the model has.
            test_pred = predict_fn([text])
            if test_pred is None or len(test_pred) == 0:
                return "Prediction function test failed", None, {}

            num_classes = test_pred.shape[1] if len(test_pred.shape) > 1 else 2
            if num_classes == 3:
                class_names = ['Negative', 'Neutral', 'Positive']
            else:
                class_names = ['Negative', 'Positive']

            # LIME explains label 1 by default, which is "Neutral" for a
            # 3-class model. Explain the last (positive) class instead so the
            # sign of each score means "pushes towards / away from positive",
            # consistent with the SHAP analysis.
            positive_label = num_classes - 1

            # LimeTextExplainer takes no 'mode' argument.
            explainer = LimeTextExplainer(class_names=class_names)
            exp = explainer.explain_instance(
                text,
                predict_fn,
                labels=(positive_label,),
                num_features=min(20, len(text.split())),  # cap at 20 features
                num_samples=num_samples
            )

            lime_data = exp.as_list(label=positive_label)
            if not lime_data:
                return "No LIME features extracted", None, {}

            words = [word for word, _ in lime_data]
            scores = [score for _, score in lime_data]
            colors = ['red' if s < 0 else 'green' for s in scores]
            fig = go.Figure()
            fig.add_trace(go.Bar(
                y=words,
                x=scores,
                orientation='h',
                marker_color=colors,
                text=[f'{s:.3f}' for s in scores],
                textposition='auto',
                name='LIME Importance',
                hovertemplate='<b>%{y}</b><br>Importance: %{x:.4f}<extra></extra>'
            ))
            fig.update_layout(
                title=f"LIME Analysis - Feature Importance (Samples: {num_samples})",
                xaxis_title="Importance Score",
                yaxis_title="Words/Phrases",
                height=500
            )

            analysis_data = {
                'method': 'LIME',
                'language': detected_lang,
                'features_analyzed': len(lime_data),
                'samples_used': num_samples,
                'positive_features': sum(1 for _, score in lime_data if score > 0),
                'negative_features': sum(1 for _, score in lime_data if score < 0),
                'feature_importance': lime_data
            }

            top_features = ', '.join(f"{word}({score:.3f})" for word, score in lime_data[:5])
            summary_text = (
                "\n"
                "**LIME Analysis Results:**\n"
                f"- **Language:** {detected_lang.upper()}\n"
                f"- **Features Analyzed:** {analysis_data['features_analyzed']}\n"
                f"- **Classes:** {', '.join(class_names)}\n"
                f"- **Samples Used:** {num_samples}\n"
                f"- **Positive Features:** {analysis_data['positive_features']}\n"
                f"- **Negative Features:** {analysis_data['negative_features']}\n"
                f"- **Top Features:** {top_features}\n"
                "- **Status:** LIME analysis completed successfully\n"
            )
            return summary_text, fig, analysis_data

        except Exception as e:
            # Boundary handler: report the failure to the user as markdown
            # instead of raising.
            logger.error("LIME analysis failed: %s", e)
            error_msg = (
                "\n"
                "**LIME Analysis Failed:**\n"
                f"- **Error:** {str(e)}\n"
                f"- **Language:** {detected_lang.upper()}\n"
                "- **Suggestion:** Try with a shorter text or reduce number of samples\n"
                "**Bug Fix Applied:**\n"
                "- ✅ Removed 'mode' parameter from LimeTextExplainer initialization\n"
                "- ✅ This should resolve the \"unexpected keyword argument 'mode'\" error\n"
                "**Common fixes:**\n"
                "- Reduce sample size to 50-100\n"
                "- Use shorter input text (< 200 words)\n"
                "- Check if model supports the text language\n"
            )
            return error_msg, None, {}