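"""Caching utilities for AI model weights, generation results, and audio
transcriptions, with support for HuggingFace Spaces persistent storage
(/data) when available."""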
import hashlib
import json
import os
import pickle
import shutil
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union

class ModelCache:
    """Manages caching for AI models and generated content."""

    def __init__(self, cache_dir: Optional[Union[str, Path]] = None):
        if cache_dir is None:
            # Use HuggingFace Spaces persistent storage if available
            if os.path.exists("/data"):
                cache_dir = "/data/cache"
            else:
                cache_dir = Path.home() / ".cache" / "digipal"
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache subdirectories
        self.model_cache_dir = self.cache_dir / "models"
        self.generation_cache_dir = self.cache_dir / "generations"
        self.audio_cache_dir = self.cache_dir / "audio"
        for dir_path in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            dir_path.mkdir(exist_ok=True)

        # Cache settings
        self.max_cache_size_gb = 10  # Maximum cache size in GB
        self.cache_expiry_days = 7   # Cache expiry in days
        self.generation_cache_enabled = True

        # In-memory cache for fast access
        self.memory_cache = {}
        self.cache_stats = self._load_cache_stats()
    def cache_model_weights(self, model_id: str, model_data: Any) -> bool:
        """Cache model weights to disk."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            with open(cache_path, 'wb') as f:
                pickle.dump(model_data, f)
            # Update cache stats
            self._update_cache_stats('model', model_id, cache_path.stat().st_size)
            return True
        except Exception as e:
            print(f"Failed to cache model {model_id}: {e}")
            return False
    def get_cached_model(self, model_id: str) -> Optional[Any]:
        """Retrieve cached model weights."""
        try:
            model_hash = self._get_hash(model_id)
            cache_path = self.model_cache_dir / f"{model_hash}.pkl"
            if cache_path.exists():
                # Check if cache is still valid before unpickling
                if self._is_cache_valid(cache_path):
                    with open(cache_path, 'rb') as f:
                        # Note: pickle.load assumes the cache directory is trusted
                        return pickle.load(f)
            return None
        except Exception as e:
            print(f"Failed to load cached model {model_id}: {e}")
            return None
    def cache_generation(self, prompt: str, result: Dict[str, Any],
                         generation_type: str = "monster") -> str:
        """Cache generation results."""
        if not self.generation_cache_enabled:
            return ""
        try:
            # Create a unique key for this generation; shard directories by the
            # first two hex characters to keep per-directory file counts manageable
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_dir = self.generation_cache_dir / generation_type / cache_key[:2]
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_file = cache_dir / f"{cache_key}.json"

            # Prepare cache data; copy the result so the caller's dict is not
            # mutated when file paths are substituted below
            cache_data = {
                'prompt': prompt,
                'type': generation_type,
                'timestamp': datetime.now().isoformat(),
                'result': dict(result)
            }

            # Replace in-memory objects with file paths so the result is JSON-serializable
            if 'image' in result and hasattr(result['image'], 'save'):
                image_path = cache_dir / f"{cache_key}_image.png"
                result['image'].save(image_path)
                cache_data['result']['image'] = str(image_path)
            if 'model_3d' in result and isinstance(result['model_3d'], str):
                # Copy the 3D model file into the cache
                model_ext = Path(result['model_3d']).suffix
                model_cache_path = cache_dir / f"{cache_key}_model{model_ext}"
                shutil.copy2(result['model_3d'], model_cache_path)
                cache_data['result']['model_3d'] = str(model_cache_path)

            # Save cache data
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f, indent=2)

            # Update stats
            self._update_cache_stats('generation', cache_key, cache_file.stat().st_size)
            return cache_key
        except Exception as e:
            print(f"Failed to cache generation: {e}")
            return ""
    def get_cached_generation(self, prompt: str, generation_type: str = "monster") -> Optional[Dict[str, Any]]:
        """Retrieve a cached generation if available."""
        if not self.generation_cache_enabled:
            return None
        try:
            cache_key = self._get_generation_key(prompt, generation_type)
            cache_file = self.generation_cache_dir / generation_type / cache_key[:2] / f"{cache_key}.json"
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                # Load associated files (lazy import keeps PIL optional)
                result = cache_data['result']
                if 'image' in result and isinstance(result['image'], str):
                    from PIL import Image
                    if os.path.exists(result['image']):
                        result['image'] = Image.open(result['image'])
                return result
            return None
        except Exception as e:
            print(f"Failed to load cached generation: {e}")
            return None
    def cache_audio_transcription(self, audio_path: str, transcription: str) -> bool:
        """Cache audio transcription results."""
        try:
            # Key by content hash so the same audio hits the cache regardless of path
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            cache_data = {
                'audio_path': audio_path,
                'transcription': transcription,
                'timestamp': datetime.now().isoformat()
            }
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f)
            return True
        except Exception as e:
            print(f"Failed to cache audio transcription: {e}")
            return False
    def get_cached_transcription(self, audio_path: str) -> Optional[str]:
        """Get a cached audio transcription."""
        try:
            with open(audio_path, 'rb') as f:
                audio_hash = hashlib.md5(f.read()).hexdigest()
            cache_file = self.audio_cache_dir / f"{audio_hash}.json"
            if cache_file.exists() and self._is_cache_valid(cache_file):
                with open(cache_file, 'r') as f:
                    cache_data = json.load(f)
                return cache_data['transcription']
            return None
        except Exception as e:
            print(f"Failed to load cached transcription: {e}")
            return None
    def add_to_memory_cache(self, key: str, value: Any, ttl_seconds: int = 300):
        """Add an item to the in-memory cache with a TTL."""
        expiry_time = time.time() + ttl_seconds
        self.memory_cache[key] = {
            'value': value,
            'expiry': expiry_time
        }

    def get_from_memory_cache(self, key: str) -> Optional[Any]:
        """Get an item from the in-memory cache."""
        if key in self.memory_cache:
            cache_item = self.memory_cache[key]
            if time.time() < cache_item['expiry']:
                return cache_item['value']
            else:
                # Remove expired item
                del self.memory_cache[key]
        return None
    def clear_expired_cache(self) -> int:
        """Clear expired cache entries and return the number of bytes freed."""
        current_time = datetime.now()
        cleared_size = 0

        # Clear file cache
        for cache_type in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
            for file_path in cache_type.rglob('*'):
                if file_path.is_file():
                    file_age = current_time - datetime.fromtimestamp(file_path.stat().st_mtime)
                    if file_age > timedelta(days=self.cache_expiry_days):
                        file_size = file_path.stat().st_size
                        file_path.unlink()
                        cleared_size += file_size

        # Clear memory cache
        expired_keys = [
            key for key, item in self.memory_cache.items()
            if time.time() > item['expiry']
        ]
        for key in expired_keys:
            del self.memory_cache[key]

        # Record the cleanup time (persisted on the next stats write)
        self.cache_stats['last_cleanup'] = current_time.isoformat()

        print(f"Cleared {cleared_size / (1024**2):.2f} MB of expired cache")
        return cleared_size
    def get_cache_size(self) -> Dict[str, float]:
        """Get current cache sizes in MB."""
        sizes = {
            'models': 0,
            'generations': 0,
            'audio': 0,
            'total': 0
        }

        # Calculate directory sizes in bytes
        dir_map = {
            'models': self.model_cache_dir,
            'generations': self.generation_cache_dir,
            'audio': self.audio_cache_dir,
        }
        for name, cache_dir in dir_map.items():
            for file_path in cache_dir.rglob('*'):
                if file_path.is_file():
                    sizes[name] += file_path.stat().st_size

        # Convert to MB
        for key in sizes:
            sizes[key] = sizes[key] / (1024 ** 2)
        sizes['total'] = sizes['models'] + sizes['generations'] + sizes['audio']
        return sizes
    def enforce_size_limit(self):
        """Enforce the cache size limit by removing the oldest entries first."""
        cache_size = self.get_cache_size()
        if cache_size['total'] > self.max_cache_size_gb * 1024:  # Convert GB to MB
            # Gather all cache files with sizes and timestamps
            all_files = []
            for cache_dir in [self.model_cache_dir, self.generation_cache_dir, self.audio_cache_dir]:
                for file_path in cache_dir.rglob('*'):
                    if file_path.is_file():
                        all_files.append({
                            'path': file_path,
                            'size': file_path.stat().st_size,
                            'mtime': file_path.stat().st_mtime
                        })

            # Sort by modification time (oldest first)
            all_files.sort(key=lambda x: x['mtime'])

            # Remove files until the cache is below 80% of the limit
            current_size = cache_size['total'] * (1024 ** 2)  # Convert MB to bytes
            target_size = self.max_cache_size_gb * (1024 ** 3) * 0.8  # 80% of limit, in bytes
            for file_info in all_files:
                if current_size <= target_size:
                    break
                file_info['path'].unlink()
                current_size -= file_info['size']
                print(f"Removed {file_info['path'].name} to enforce cache limit")
    def _get_hash(self, text: str) -> str:
        """Get the MD5 hash of a string (used for cache keys, not security)."""
        return hashlib.md5(text.encode()).hexdigest()

    def _get_generation_key(self, prompt: str, generation_type: str) -> str:
        """Get a unique key for the generation cache."""
        combined = f"{generation_type}:{prompt}"
        return self._get_hash(combined)

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Check whether a cache file is still within the expiry window."""
        if not cache_path.exists():
            return False
        file_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return file_age < timedelta(days=self.cache_expiry_days)
    def _load_cache_stats(self) -> Dict[str, Any]:
        """Load cache statistics from disk, or initialize fresh stats."""
        stats_file = self.cache_dir / "cache_stats.json"
        if stats_file.exists():
            with open(stats_file, 'r') as f:
                return json.load(f)
        return {
            'total_hits': 0,
            'total_misses': 0,
            'last_cleanup': datetime.now().isoformat(),
            'entries': {}
        }

    def _update_cache_stats(self, cache_type: str, key: str, size: int):
        """Record a cache entry and persist statistics to disk."""
        self.cache_stats['entries'][key] = {
            'type': cache_type,
            'size': size,
            'timestamp': datetime.now().isoformat()
        }
        # Save stats
        stats_file = self.cache_dir / "cache_stats.json"
        with open(stats_file, 'w') as f:
            json.dump(self.cache_stats, f, indent=2)
    def get_cache_info(self) -> Dict[str, Any]:
        """Get cache information and statistics."""
        sizes = self.get_cache_size()
        return {
            'sizes': sizes,
            'stats': self.cache_stats,
            'memory_cache_items': len(self.memory_cache),
            'cache_dir': str(self.cache_dir),
            'max_size_gb': self.max_cache_size_gb,
            'expiry_days': self.cache_expiry_days
        }
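
# Illustrative usage sketch (not part of the original module): exercises the
# generation cache, the in-memory TTL cache, and the maintenance methods.
# Assumes a writable cache directory; the "fire dragon" prompt and the fake
# result dict are hypothetical placeholders.
if __name__ == "__main__":
    cache = ModelCache()

    # Cache and retrieve a (fake) generation result
    key = cache.cache_generation("fire dragon", {"stats": {"hp": 100}}, generation_type="monster")
    print(f"Cached generation under key: {key}")
    print(f"Cache hit: {cache.get_cached_generation('fire dragon') is not None}")

    # Short-lived in-memory caching
    cache.add_to_memory_cache("greeting", "hello", ttl_seconds=5)
    print(f"Memory cache: {cache.get_from_memory_cache('greeting')}")

    # Maintenance: report sizes, drop expired entries, enforce the size limit
    print(f"Cache sizes (MB): {cache.get_cache_info()['sizes']}")
    cache.clear_expired_cache()
    cache.enforce_size_limit()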