entropy25 committed on
Commit 33fba47 · verified · 1 Parent(s): 12cdc6e

Create models.py

Files changed (1)
  1. models.py +320 -0
models.py ADDED
@@ -0,0 +1,320 @@
+ import torch
+ import numpy as np
+ import gc
+ import threading
+ import langdetect
+ import logging
+ from collections import OrderedDict, Counter
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import lru_cache, wraps
+ from contextlib import contextmanager
+ from typing import List, Dict, Optional, Tuple, Any, Callable
+ import re
+
+ from config import config
+
+ logger = logging.getLogger(__name__)
+
+ # Decorators and Context Managers
+ def handle_errors(default_return=None):
+     """Centralized error handling decorator"""
+     def decorator(func: Callable) -> Callable:
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             try:
+                 return func(*args, **kwargs)
+             except Exception as e:
+                 logger.error(f"{func.__name__} failed: {e}")
+                 return default_return if default_return is not None else f"Error: {str(e)}"
+         return wrapper
+     return decorator
+
+ @contextmanager
+ def memory_cleanup():
+     """Context manager for memory cleanup"""
+     try:
+         yield
+     finally:
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
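# Illustrative usage sketch (hypothetical, not part of models.py): how the two
# helpers above are meant to combine. With default_return=0.0 the decorator
# logs the exception and returns 0.0 instead of raising; memory_cleanup()
# guarantees gc/CUDA cache cleanup afterwards.
@handle_errors(default_return=0.0)
def risky_score(x: float) -> float:   # hypothetical example function
    return 1.0 / x

with memory_cleanup():
    score = risky_score(0.0)   # ZeroDivisionError is logged, score == 0.0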
+ class ThemeContext:
+     """Theme management context"""
+     def __init__(self, theme: str = 'default'):
+         self.theme = theme
+         self.colors = config.THEMES.get(theme, config.THEMES['default'])
+
+ class LRUModelCache:
+     """LRU Cache for models with memory management"""
+     def __init__(self, max_size: int = 2):
+         self.max_size = max_size
+         self.cache = OrderedDict()
+         self.lock = threading.Lock()
+
+     def get(self, key):
+         with self.lock:
+             if key in self.cache:
+                 # Move to end (most recently used)
+                 self.cache.move_to_end(key)
+                 return self.cache[key]
+             return None
+
+     def put(self, key, value):
+         with self.lock:
+             if key in self.cache:
+                 self.cache.move_to_end(key)
+             else:
+                 if len(self.cache) >= self.max_size:
+                     # Remove least recently used
+                     oldest_key = next(iter(self.cache))
+                     old_model, old_tokenizer = self.cache.pop(oldest_key)
+                     # Force cleanup
+                     del old_model, old_tokenizer
+                     gc.collect()
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+
+             self.cache[key] = value
+
+     def clear(self):
+         with self.lock:
+             for model, tokenizer in self.cache.values():
+                 del model, tokenizer
+             self.cache.clear()
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
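# Illustrative usage sketch (hypothetical, not part of models.py): the cache
# above evicts the least recently used (model, tokenizer) pair once max_size is
# exceeded. Plain strings stand in for the real model/tokenizer objects here,
# and the 'extra' key exists only to demonstrate eviction.
cache = LRUModelCache(max_size=2)
cache.put('multilingual', ('model_a', 'tokenizer_a'))
cache.put('zh', ('model_b', 'tokenizer_b'))
cache.get('multilingual')                       # marks 'multilingual' as most recently used
cache.put('extra', ('model_c', 'tokenizer_c'))  # evicts 'zh', the least recently used entry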
+ # Enhanced Model Manager with Optimized Memory Management
+ class ModelManager:
+     """Optimized multi-language model manager with LRU cache and lazy loading"""
+     _instance = None
+
+     def __new__(cls):
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+             cls._instance._initialized = False
+         return cls._instance
+
+     def __init__(self):
+         if not self._initialized:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             self.model_cache = LRUModelCache(config.MODEL_CACHE_SIZE)
+             self.loading_lock = threading.Lock()
+             self._initialized = True
+             logger.info(f"ModelManager initialized on device: {self.device}")
+
+     def _load_model(self, model_name: str, cache_key: str):
+         """Load model with memory optimization"""
+         try:
+             logger.info(f"Loading model: {model_name}")
+
+             # Load with memory optimization
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+             model = AutoModelForSequenceClassification.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                 device_map="auto" if torch.cuda.is_available() else None
+             )
+
+             if not torch.cuda.is_available():
+                 model.to(self.device)
+
+             # Set to eval mode to save memory
+             model.eval()
+
+             # Cache the model
+             self.model_cache.put(cache_key, (model, tokenizer))
+             logger.info(f"Model {model_name} loaded and cached successfully")
+
+             return model, tokenizer
+
+         except Exception as e:
+             logger.error(f"Failed to load model {model_name}: {e}")
+             raise
+
+     def get_model(self, language='en'):
+         """Get model for specific language with lazy loading and caching"""
+         # Determine cache key and model name
+         if language == 'zh':
+             cache_key = 'zh'
+             model_name = config.MODELS['zh']
+         else:
+             cache_key = 'multilingual'
+             model_name = config.MODELS['multilingual']
+
+         # Try to get from cache first
+         cached_model = self.model_cache.get(cache_key)
+         if cached_model is not None:
+             return cached_model
+
+         # Load model if not in cache (with thread safety)
+         with self.loading_lock:
+             # Double-check pattern
+             cached_model = self.model_cache.get(cache_key)
+             if cached_model is not None:
+                 return cached_model
+
+             return self._load_model(model_name, cache_key)
+
+     @staticmethod
+     def detect_language(text: str) -> str:
+         """Detect text language"""
+         try:
+             detected = langdetect.detect(text)
+             language_mapping = {
+                 'zh-cn': 'zh',
+                 'zh-tw': 'zh'
+             }
+             detected = language_mapping.get(detected, detected)
+             return detected if detected in config.SUPPORTED_LANGUAGES else 'en'
+         except Exception:
+             return 'en'
+
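# Illustrative usage sketch (hypothetical, not part of models.py): ModelManager
# is a singleton, loads models lazily on the first get_model() call, and maps
# detected languages onto the 'zh' / 'multilingual' cache keys. It assumes the
# config module defines MODELS and SUPPORTED_LANGUAGES as referenced above, with
# 'zh' among the supported languages.
manager = ModelManager()
assert ModelManager() is manager                     # __new__ returns the same instance
lang = ModelManager.detect_language("这部电影很好看")  # expected to map zh-cn -> 'zh'
model, tokenizer = manager.get_model(lang)           # downloads and caches on first use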
+ # Core Sentiment Analysis Engine with Performance Optimizations
+ class SentimentEngine:
+     """Optimized multi-language sentiment analysis engine"""
+
+     def __init__(self):
+         self.model_manager = ModelManager()
+         self.executor = ThreadPoolExecutor(max_workers=4)
+
+     @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0})
+     def analyze_single(self, text: str, language: str = 'auto', preprocessing_options: Dict = None) -> Dict:
+         """Optimized single text analysis"""
+         if not text.strip():
+             raise ValueError("Empty text provided")
+
+         # Detect language
+         if language == 'auto':
+             detected_lang = self.model_manager.detect_language(text)
+         else:
+             detected_lang = language
+
+         # Get appropriate model
+         model, tokenizer = self.model_manager.get_model(detected_lang)
+
+         # Preprocessing
+         options = preprocessing_options or {}
+         processed_text = text
+         if options.get('clean_text', False) and not re.search(r'[\u4e00-\u9fff]', text):
+             from data_utils import TextProcessor
+             processed_text = TextProcessor.clean_text(
+                 text,
+                 options.get('remove_punctuation', True),
+                 options.get('remove_numbers', False)
+             )
+
+         # Tokenize and analyze with memory optimization
+         inputs = tokenizer(processed_text, return_tensors="pt", padding=True,
+                            truncation=True, max_length=config.MAX_TEXT_LENGTH).to(self.model_manager.device)
+
+         # Use no_grad for inference to save memory
+         with torch.no_grad():
+             outputs = model(**inputs)
+             probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
+
+         # Clear GPU cache after inference
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         # Handle different model outputs
+         if len(probs) == 3:  # negative, neutral, positive
+             sentiment_idx = np.argmax(probs)
+             sentiment_labels = ['Negative', 'Neutral', 'Positive']
+             sentiment = sentiment_labels[sentiment_idx]
+             confidence = float(probs[sentiment_idx])
+
+             result = {
+                 'sentiment': sentiment,
+                 'confidence': confidence,
+                 'neg_prob': float(probs[0]),
+                 'neu_prob': float(probs[1]),
+                 'pos_prob': float(probs[2]),
+                 'has_neutral': True
+             }
+         else:  # negative, positive
+             pred = np.argmax(probs)
+             sentiment = "Positive" if pred == 1 else "Negative"
+             confidence = float(probs[pred])
+
+             result = {
+                 'sentiment': sentiment,
+                 'confidence': confidence,
+                 'neg_prob': float(probs[0]),
+                 'pos_prob': float(probs[1]),
+                 'neu_prob': 0.0,
+                 'has_neutral': False
+             }
+
+         # Add metadata
+         result.update({
+             'language': detected_lang,
+             'word_count': len(text.split()),
+             'char_count': len(text)
+         })
+
+         return result
+
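# Illustrative usage sketch (hypothetical, not part of models.py): a single
# analysis call. The sample sentence and preprocessing flags are made up, the
# probabilities depend on the configured models, and clean_text=True relies on
# the data_utils.TextProcessor helper imported inside analyze_single above.
engine = SentimentEngine()
result = engine.analyze_single(
    "The battery life is great but the screen is too dim.",
    language='auto',
    preprocessing_options={'clean_text': True, 'remove_punctuation': True}
)
print(result['sentiment'], round(result['confidence'], 3), result['language'])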
+     def _analyze_text_batch(self, text: str, language: str, preprocessing_options: Dict, index: int) -> Dict:
+         """Single text analysis for batch processing"""
+         try:
+             result = self.analyze_single(text, language, preprocessing_options)
+             result['batch_index'] = index
+             result['text'] = text[:100] + '...' if len(text) > 100 else text
+             result['full_text'] = text
+             return result
+         except Exception as e:
+             return {
+                 'sentiment': 'Error',
+                 'confidence': 0.0,
+                 'error': str(e),
+                 'batch_index': index,
+                 'text': text[:100] + '...' if len(text) > 100 else text,
+                 'full_text': text
+             }
+
+     @handle_errors(default_return=[])
+     def analyze_batch(self, texts: List[str], language: str = 'auto',
+                       preprocessing_options: Dict = None, progress_callback=None) -> List[Dict]:
+         """Optimized parallel batch processing"""
+         if len(texts) > config.BATCH_SIZE_LIMIT:
+             texts = texts[:config.BATCH_SIZE_LIMIT]
+
+         if not texts:
+             return []
+
+         # Pre-load model to avoid race conditions
+         self.model_manager.get_model(language if language != 'auto' else 'en')
+
+         # Use ThreadPoolExecutor for parallel processing
+         with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor:
+             futures = []
+             for i, text in enumerate(texts):
+                 future = executor.submit(
+                     self._analyze_text_batch,
+                     text, language, preprocessing_options, i
+                 )
+                 futures.append(future)
+
+             results = []
+             for i, future in enumerate(futures):
+                 if progress_callback:
+                     progress_callback((i + 1) / len(futures))
+
+                 try:
+                     result = future.result(timeout=30)  # 30 second timeout per text
+                     results.append(result)
+                 except Exception as e:
+                     results.append({
+                         'sentiment': 'Error',
+                         'confidence': 0.0,
+                         'error': f"Timeout or error: {str(e)}",
+                         'batch_index': i,
+                         'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i],
+                         'full_text': texts[i]
+                     })
+
+         return results
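# Illustrative usage sketch (hypothetical, not part of models.py): batch
# analysis with a simple progress callback. Results come back in submission
# order and each dict carries its 'batch_index'; the review strings are made up.
engine = SentimentEngine()
reviews = ["Loved it!", "Terrible customer support.", "It works, I guess."]
results = engine.analyze_batch(
    reviews,
    language='auto',
    progress_callback=lambda done: print(f"{done:.0%} processed")
)
for r in results:
    print(r['batch_index'], r['sentiment'], round(r['confidence'], 3))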