import torch
import numpy as np
import gc
import threading
import langdetect
import logging
from collections import OrderedDict, Counter
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache, wraps
from contextlib import contextmanager
from typing import List, Dict, Optional, Tuple, Any, Callable
import re

from config import config

logger = logging.getLogger(__name__)

# Decorators and Context Managers
def handle_errors(default_return=None):
    """Centralized error handling decorator: logs failures and returns a fallback value."""
    def decorator(func: Callable) -> Callable:
        @wraps(func)  # preserve the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{func.__name__} failed: {e}")
                return default_return if default_return is not None else f"Error: {str(e)}"
        return wrapper
    return decorator
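
# Illustrative usage sketch (not part of the original module): any function whose
# failures should be logged and replaced by a fallback value can be wrapped.
# `_demo_parse_score` is a hypothetical helper used only for illustration.
@handle_errors(default_return=0.0)
def _demo_parse_score(raw: str) -> float:
    return float(raw)  # a ValueError here is logged and 0.0 is returned instead
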
@contextmanager
def memory_cleanup():
    """Context manager for memory cleanup"""
    try:
        yield
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
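
# Illustrative usage sketch (assumption, not original code): run any callable under
# memory_cleanup() so memory is reclaimed even if the call raises.
def _demo_run_with_cleanup(fn: Callable, *args, **kwargs):
    with memory_cleanup():
        return fn(*args, **kwargs)
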
class ThemeContext:
    """Theme management context"""
    def __init__(self, theme: str = 'default'):
        self.theme = theme
        self.colors = config.THEMES.get(theme, config.THEMES['default'])
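
# Illustrative sketch (not original code): resolve a palette for the UI; unknown
# theme names fall back to the 'default' palette defined in config.THEMES.
def _demo_theme_colors(name: str = 'default'):
    return ThemeContext(name).colors
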
class LRUModelCache:
    """LRU Cache for models with memory management"""
    def __init__(self, max_size: int = 2):
        self.max_size = max_size
        self.cache = OrderedDict()
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            if key in self.cache:
                # Move to end (most recently used)
                self.cache.move_to_end(key)
                return self.cache[key]
            return None

    def put(self, key, value):
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
            else:
                if len(self.cache) >= self.max_size:
                    # Remove least recently used
                    oldest_key = next(iter(self.cache))
                    old_model, old_tokenizer = self.cache.pop(oldest_key)
                    # Force cleanup
                    del old_model, old_tokenizer
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            self.cache[key] = value

    def clear(self):
        with self.lock:
            for model, tokenizer in self.cache.values():
                del model, tokenizer
            self.cache.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
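
# Illustrative sketch (not original code): entries are (model, tokenizer) tuples
# keyed by language group; a miss returns None so the caller decides when to load.
def _demo_cache_roundtrip() -> None:
    cache = LRUModelCache(max_size=1)
    cache.put('multilingual', ('model_placeholder', 'tokenizer_placeholder'))  # stand-ins, not real models
    assert cache.get('multilingual') is not None
    assert cache.get('zh') is None  # never cached, so the caller would load it
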
# Enhanced Model Manager with Optimized Memory Management
class ModelManager:
    """Optimized multi-language model manager with LRU cache and lazy loading"""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model_cache = LRUModelCache(config.MODEL_CACHE_SIZE)
            self.loading_lock = threading.Lock()
            self._initialized = True
            logger.info(f"ModelManager initialized on device: {self.device}")

    def _load_model(self, model_name: str, cache_key: str):
        """Load model with memory optimization"""
        try:
            logger.info(f"Loading model: {model_name}")
            # Load with memory optimization
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            if not torch.cuda.is_available():
                model.to(self.device)
            # Set to eval mode to save memory
            model.eval()
            # Cache the model
            self.model_cache.put(cache_key, (model, tokenizer))
            logger.info(f"Model {model_name} loaded and cached successfully")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            raise

    def get_model(self, language='en'):
        """Get model for specific language with lazy loading and caching"""
        # Determine cache key and model name
        if language == 'zh':
            cache_key = 'zh'
            model_name = config.MODELS['zh']
        else:
            cache_key = 'multilingual'
            model_name = config.MODELS['multilingual']

        # Try to get from cache first
        cached_model = self.model_cache.get(cache_key)
        if cached_model is not None:
            return cached_model

        # Load model if not in cache (with thread safety)
        with self.loading_lock:
            # Double-check pattern
            cached_model = self.model_cache.get(cache_key)
            if cached_model is not None:
                return cached_model
            return self._load_model(model_name, cache_key)
    @staticmethod
    def detect_language(text: str) -> str:
        """Detect text language"""
        try:
            detected = langdetect.detect(text)
            language_mapping = {
                'zh-cn': 'zh',
                'zh-tw': 'zh'
            }
            detected = language_mapping.get(detected, detected)
            return detected if detected in config.SUPPORTED_LANGUAGES else 'en'
        except Exception:
            return 'en'
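
# Illustrative sketch (not original code): language detection routes Chinese
# variants to the dedicated 'zh' model and everything else to the multilingual one.
def _demo_model_routing(text: str) -> str:
    manager = ModelManager()  # singleton; reuses the already initialized instance
    return 'zh' if manager.detect_language(text) == 'zh' else 'multilingual'
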
# Core Sentiment Analysis Engine with Performance Optimizations
class SentimentEngine:
    """Optimized multi-language sentiment analysis engine"""
    def __init__(self):
        self.model_manager = ModelManager()
        self.executor = ThreadPoolExecutor(max_workers=4)

    def analyze_single(self, text: str, language: str = 'auto', preprocessing_options: Dict = None) -> Dict:
        """Optimized single text analysis"""
        if not text.strip():
            raise ValueError("Empty text provided")

        # Detect language
        if language == 'auto':
            detected_lang = self.model_manager.detect_language(text)
        else:
            detected_lang = language

        # Get appropriate model
        model, tokenizer = self.model_manager.get_model(detected_lang)

        # Preprocessing
        options = preprocessing_options or {}
        processed_text = text
        if options.get('clean_text', False) and not re.search(r'[\u4e00-\u9fff]', text):
            from data_utils import TextProcessor
            processed_text = TextProcessor.clean_text(
                text,
                options.get('remove_punctuation', True),
                options.get('remove_numbers', False)
            )

        # Tokenize and analyze with memory optimization
        inputs = tokenizer(processed_text, return_tensors="pt", padding=True,
                           truncation=True, max_length=config.MAX_TEXT_LENGTH).to(self.model_manager.device)

        # Use no_grad for inference to save memory
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

        # Clear GPU cache after inference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Handle different model outputs
        if len(probs) == 3:  # negative, neutral, positive
            sentiment_idx = np.argmax(probs)
            sentiment_labels = ['Negative', 'Neutral', 'Positive']
            sentiment = sentiment_labels[sentiment_idx]
            confidence = float(probs[sentiment_idx])
            result = {
                'sentiment': sentiment,
                'confidence': confidence,
                'neg_prob': float(probs[0]),
                'neu_prob': float(probs[1]),
                'pos_prob': float(probs[2]),
                'has_neutral': True
            }
        else:  # negative, positive
            pred = np.argmax(probs)
            sentiment = "Positive" if pred == 1 else "Negative"
            confidence = float(probs[pred])
            result = {
                'sentiment': sentiment,
                'confidence': confidence,
                'neg_prob': float(probs[0]),
                'pos_prob': float(probs[1]),
                'neu_prob': 0.0,
                'has_neutral': False
            }

        # Add metadata
        result.update({
            'language': detected_lang,
            'word_count': len(text.split()),
            'char_count': len(text)
        })
        return result

    def _analyze_text_batch(self, text: str, language: str, preprocessing_options: Dict, index: int) -> Dict:
        """Single text analysis for batch processing"""
        try:
            result = self.analyze_single(text, language, preprocessing_options)
            result['batch_index'] = index
            result['text'] = text[:100] + '...' if len(text) > 100 else text
            result['full_text'] = text
            return result
        except Exception as e:
            return {
                'sentiment': 'Error',
                'confidence': 0.0,
                'error': str(e),
                'batch_index': index,
                'text': text[:100] + '...' if len(text) > 100 else text,
                'full_text': text
            }

    def analyze_batch(self, texts: List[str], language: str = 'auto',
                      preprocessing_options: Dict = None, progress_callback=None) -> List[Dict]:
        """Optimized parallel batch processing"""
        if len(texts) > config.BATCH_SIZE_LIMIT:
            texts = texts[:config.BATCH_SIZE_LIMIT]
        if not texts:
            return []

        # Pre-load model to avoid race conditions
        self.model_manager.get_model(language if language != 'auto' else 'en')

        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor:
            futures = []
            for i, text in enumerate(texts):
                future = executor.submit(
                    self._analyze_text_batch,
                    text, language, preprocessing_options, i
                )
                futures.append(future)

            results = []
            for i, future in enumerate(futures):
                if progress_callback:
                    progress_callback((i + 1) / len(futures))
                try:
                    result = future.result(timeout=30)  # 30 second timeout per text
                    results.append(result)
                except Exception as e:
                    results.append({
                        'sentiment': 'Error',
                        'confidence': 0.0,
                        'error': f"Timeout or error: {str(e)}",
                        'batch_index': i,
                        'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i],
                        'full_text': texts[i]
                    })
        return results
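
# Minimal usage sketch (not part of the original module): analyze one text and a
# small batch from the command line. Model names, supported languages, and size
# limits come from `config`, so behavior depends on that configuration and on the
# models being downloadable in this environment.
if __name__ == "__main__":
    engine = SentimentEngine()
    print(engine.analyze_single("This library is surprisingly pleasant to use."))
    for row in engine.analyze_batch(["Great product!", "Terrible experience.", "还不错"]):
        print(row['batch_index'], row['sentiment'], round(row['confidence'], 3))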