import os
import json
import time
import hashlib
import pickle
import shutil
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union


class ModelCache:
    """Manages on-disk and in-memory caching for AI models and generated content."""

    def __init__(self, cache_dir: Optional[Union[str, Path]] = None):
        if cache_dir is None:
            # Use HuggingFace Spaces persistent storage if available
            if os.path.exists("/data"):
                cache_dir = "/data/cache"
            else:
                cache_dir = Path.home() / ".cache" / "digipal"

        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache subdirectories
        self.model_cache_dir = self.cache_dir / "models"
        self.generation_cache_dir = self.cache_dir / "generations"
        self.audio_cache_dir = self.cache_dir / "audio"
        for dir_path in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            dir_path.mkdir(exist_ok=True)

        # Cache settings
        self.max_cache_size_gb = 10   # Maximum cache size in GB
        self.cache_expiry_days = 7    # Cache expiry in days
        self.generation_cache_enabled = True

        # In-memory cache for fast access
        self.memory_cache: Dict[str, Dict[str, Any]] = {}
        self.cache_stats = self._load_cache_stats()

    def cache_model_weights(self, model_id: str, model_data: Any) -> bool:
        """Cache model weights to disk."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            # NOTE: pickle is only safe for trusted data; never load cache
            # files written by an untrusted party.
            with open(cache_path, 'wb') as f:
                pickle.dump(model_data, f)
            self._update_cache_stats('model', model_id, cache_path.stat().st_size)
            return True
        except Exception as e:
            print(f"Failed to cache model {model_id}: {e}")
            return False

    def get_cached_model(self, model_id: str) -> Optional[Any]:
        """Retrieve cached model weights."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            if cache_path.exists() and self._is_cache_valid(cache_path):
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)
            return None
        except Exception as e:
            print(f"Failed to load cached model {model_id}: {e}")
            return None

    def cache_generation(self, prompt: str, result: Dict[str, Any],
                         generation_type: str = "monster") -> str:
        """Cache generation results and return the cache key ('' on failure)."""
        if not self.generation_cache_enabled:
            return ""
        try:
            # Create a unique key and a sharded directory layout
            # (first two hex chars of the key) for this generation.
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_dir = self.generation_cache_dir / generation_type / cache_key[:2]
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_file = cache_dir / f"{cache_key}.json"

            # Shallow-copy the result so the caller's dict is not mutated
            # when in-memory objects are replaced with file paths below.
            cache_data = {
                'prompt': prompt,
                'type': generation_type,
                'timestamp': datetime.now().isoformat(),
                'result': dict(result)
            }

            # Persist non-JSON-serializable payloads as files.
            if 'image' in result and hasattr(result['image'], 'save'):
                image_path = cache_dir / f"{cache_key}_image.png"
                result['image'].save(image_path)
                cache_data['result']['image'] = str(image_path)

            if 'model_3d' in result and isinstance(result['model_3d'], str):
                # Copy the 3D model file into the cache
                model_ext = Path(result['model_3d']).suffix
                model_cache_path = cache_dir / f"{cache_key}_model{model_ext}"
                shutil.copy2(result['model_3d'], model_cache_path)
                cache_data['result']['model_3d'] = str(model_cache_path)

            with open(cache_file, 'w') as f:
                json.dump(cache_data, f, indent=2)

            self._update_cache_stats('generation', cache_key, cache_file.stat().st_size)
            return cache_key
        except Exception as e:
            print(f"Failed to cache generation: {e}")
            return ""

    def get_cached_generation(self, prompt: str,
                              generation_type: str = "monster") -> Optional[Dict[str, Any]]:
        """Retrieve a cached generation if available."""
        if not self.generation_cache_enabled:
            return None
        try:
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_file = (self.generation_cache_dir / generation_type /
                          cache_key[:2] / f"{cache_key}.json")
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                # Re-hydrate associated files
                result = cache_data['result']
                if 'image' in result and isinstance(result['image'], str):
                    from PIL import Image  # lazy import: only needed for image results
                    if os.path.exists(result['image']):
                        result['image'] = Image.open(result['image'])
                return result
            return None
        except Exception as e:
            print(f"Failed to load cached generation: {e}")
            return None

    def cache_audio_transcription(self, audio_path: str, transcription: str) -> bool:
        """Cache an audio transcription, keyed by the audio file's content hash."""
        try:
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            cache_data = {
                'audio_path': audio_path,
                'transcription': transcription,
                'timestamp': datetime.now().isoformat()
            }
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f)
            return True
        except Exception as e:
            print(f"Failed to cache audio transcription: {e}")
            return False

    def get_cached_transcription(self, audio_path: str) -> Optional[str]:
        """Get a cached audio transcription."""
        try:
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                return cache_data['transcription']
            return None
        except Exception as e:
            print(f"Failed to load cached transcription: {e}")
            return None

    def add_to_memory_cache(self, key: str, value: Any, ttl_seconds: int = 300) -> None:
        """Add an item to the in-memory cache with a TTL."""
        self.memory_cache[key] = {
            'value': value,
            'expiry': time.time() + ttl_seconds
        }

    def get_from_memory_cache(self, key: str) -> Optional[Any]:
        """Get an item from the in-memory cache, evicting it if expired."""
        if key in self.memory_cache:
            cache_item = self.memory_cache[key]
            if time.time() < cache_item['expiry']:
                return cache_item['value']
            # Remove expired item
            del self.memory_cache[key]
        return None

    def clear_expired_cache(self) -> int:
        """Clear expired cache entries; returns the number of bytes freed."""
        current_time = datetime.now()
        cleared_size = 0

        # Clear file caches
        for cache_type in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            for file_path in cache_type.rglob('*'):
                if file_path.is_file():
                    file_age = current_time - datetime.fromtimestamp(file_path.stat().st_mtime)
                    if file_age > timedelta(days=self.cache_expiry_days):
                        file_size = file_path.stat().st_size
                        file_path.unlink()
                        cleared_size += file_size

        # Clear memory cache (collect keys first; don't mutate while iterating)
        expired_keys = [
            key for key, item in self.memory_cache.items()
            if time.time() > item['expiry']
        ]
        for key in expired_keys:
            del self.memory_cache[key]

        print(f"Cleared {cleared_size / (1024 ** 2):.2f} MB of expired cache")
        return cleared_size

    def get_cache_size(self) -> Dict[str, float]:
        """Get current cache sizes in MB."""
        sizes = {'models': 0.0, 'generations': 0.0, 'audio': 0.0, 'total': 0.0}

        # Sum file sizes (in bytes) per subdirectory
        for file_path in self.model_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['models'] += file_path.stat().st_size
        for file_path in self.generation_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['generations'] += file_path.stat().st_size
        for file_path in self.audio_cache_dir.rglob('*'):
            if file_path.is_file():
                sizes['audio'] += file_path.stat().st_size

        # Convert to MB
        for key in sizes:
            sizes[key] = sizes[key] / (1024 ** 2)
        sizes['total'] = sizes['models'] + sizes['generations'] + sizes['audio']
        return sizes

    def enforce_size_limit(self) -> None:
        """Enforce the cache size limit by removing the oldest entries first."""
        cache_size = self.get_cache_size()
        if cache_size['total'] > self.max_cache_size_gb * 1024:  # both sides in MB
            # Collect all cache files with size and modification time
            all_files = []
            for cache_dir in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
                for file_path in cache_dir.rglob('*'):
                    if file_path.is_file():
                        all_files.append({
                            'path': file_path,
                            'size': file_path.stat().st_size,
                            'mtime': file_path.stat().st_mtime
                        })

            # Sort by modification time (oldest first)
            all_files.sort(key=lambda x: x['mtime'])

            # Remove files until the cache is back at 80% of the limit
            current_size = cache_size['total'] * (1024 ** 2)          # MB -> bytes
            target_size = self.max_cache_size_gb * (1024 ** 3) * 0.8  # 80% of limit, in bytes
            for file_info in all_files:
                if current_size <= target_size:
                    break
                file_info['path'].unlink()
                current_size -= file_info['size']
                print(f"Removed {file_info['path'].name} to enforce cache limit")

    def _get_hash(self, text: str) -> str:
        """Get the MD5 hash of text (used only as a cache key, not for security)."""
        return hashlib.md5(text.encode()).hexdigest()

    def _get_generation_key(self, prompt: str, generation_type: str) -> str:
        """Get a unique key for the generation cache."""
        return self._get_hash(f"{generation_type}:{prompt}")

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Check whether a cache file is still within its expiry window."""
        if not cache_path.exists():
            return False
        file_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return file_age < timedelta(days=self.cache_expiry_days)

    def _load_cache_stats(self) -> Dict[str, Any]:
        """Load cache statistics, falling back to defaults if missing or corrupt."""
        stats_file = self.cache_dir / "cache_stats.json"
        if stats_file.exists():
            try:
                with open(stats_file, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError):
                pass  # fall through to a fresh stats dict
        return {
            'total_hits': 0,
            'total_misses': 0,
            'last_cleanup': datetime.now().isoformat(),
            'entries': {}
        }

    def _update_cache_stats(self, cache_type: str, key: str, size: int) -> None:
        """Record a cache entry and persist the statistics file."""
        self.cache_stats['entries'][key] = {
            'type': cache_type,
            'size': size,
            'timestamp': datetime.now().isoformat()
        }
        stats_file = self.cache_dir / "cache_stats.json"
        with open(stats_file, 'w') as f:
            json.dump(self.cache_stats, f, indent=2)

    def get_cache_info(self) -> Dict[str, Any]:
        """Get cache information and statistics."""
        return {
            'sizes': self.get_cache_size(),
            'stats': self.cache_stats,
            'memory_cache_items': len(self.memory_cache),
            'cache_dir': str(self.cache_dir),
            'max_size_gb': self.max_cache_size_gb,
            'expiry_days': self.cache_expiry_days
        }
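

# ---------------------------------------------------------------------------
# Minimal usage sketch. This demo block is illustrative, not part of any real
# pipeline: the cache directory, prompt, and result payload below are
# hypothetical, and a real caller would pass actual generation results
# (e.g. a PIL image under 'image' or a mesh path under 'model_3d').
if __name__ == "__main__":
    cache = ModelCache(cache_dir="/tmp/digipal_cache_demo")  # hypothetical path

    # Generation cache round trip: store a result, then fetch it by prompt.
    prompt = "a friendly fire-type monster"
    cache.cache_generation(prompt, {"name": "Embermite", "stats": {"hp": 20}})
    print(cache.get_cached_generation(prompt))

    # In-memory cache with a short TTL.
    cache.add_to_memory_cache("last_prompt", prompt, ttl_seconds=5)
    print(cache.get_from_memory_cache("last_prompt"))

    # Housekeeping: drop expired entries and report the cache state.
    cache.clear_expired_cache()
    print(json.dumps(cache.get_cache_info(), indent=2, default=str))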