# Memory management system to prevent leaks and monitor usage

import gc
import os
import logging
import threading
import time
import weakref
import tempfile
from typing import Dict, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps

import psutil
import torch

logger = logging.getLogger(__name__)


@dataclass
class MemoryStats:
    """Point-in-time snapshot of process, system, and GPU memory usage."""
    timestamp: float = field(default_factory=time.time)
    process_memory_mb: float = 0.0
    system_memory_percent: float = 0.0
    gpu_memory_mb: float = 0.0
    gpu_memory_percent: float = 0.0
    temp_files_count: int = 0
    temp_files_size_mb: float = 0.0
    active_sessions: int = 0
    gc_collections: Dict[int, int] = field(default_factory=dict)
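

# A minimal sketch of how a caller might render a MemoryStats snapshot,
# e.g. for a health endpoint or a log line; `summarize_stats` is a
# hypothetical helper, not part of the manager's API.
def summarize_stats(stats: MemoryStats) -> str:
    """Format a MemoryStats snapshot as a one-line, human-readable summary."""
    return (
        f"proc={stats.process_memory_mb:.0f}MB "
        f"sys={stats.system_memory_percent:.0f}% "
        f"gpu={stats.gpu_memory_mb:.0f}MB ({stats.gpu_memory_percent:.0f}%) "
        f"temp={stats.temp_files_count} files / {stats.temp_files_size_mb:.1f}MB "
        f"sessions={stats.active_sessions}"
    )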


class MemoryManager:
    """
    Comprehensive memory management system to prevent leaks.

    Monitors process, system, and GPU memory in a background thread and
    triggers cleanup (garbage collection, GPU cache clearing, temp-file
    removal, session expiry) when configured thresholds are exceeded.
    """

    def __init__(self, app=None, config=None):
        self.config = config or {}
        self.app = app
        self._cleanup_callbacks = []
        self._resource_registry = weakref.WeakValueDictionary()
        self._monitoring_thread = None
        self._shutdown = False

        # Memory thresholds
        self.memory_threshold_mb = self.config.get('memory_threshold_mb', 4096)  # 4 GB
        self.gpu_memory_threshold_mb = self.config.get('gpu_memory_threshold_mb', 2048)  # 2 GB
        self.cleanup_interval = self.config.get('cleanup_interval', 30)  # seconds

        # Whisper model reference
        self.whisper_model = None
        self.model_reload_count = 0
        self.last_model_reload = time.time()
        self._last_stats_log = time.time()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Initialize memory management for a Flask app."""
        self.app = app
        app.memory_manager = self

        # Start monitoring thread
        self._start_monitoring()

        # Register cleanup on interpreter shutdown
        import atexit
        atexit.register(self.shutdown)

        logger.info("Memory manager initialized")

    def set_whisper_model(self, model):
        """Register the Whisper model for management."""
        self.whisper_model = model
        logger.info("Whisper model registered with memory manager")

    def _start_monitoring(self):
        """Start background memory monitoring."""
        self._monitoring_thread = threading.Thread(
            target=self._monitor_memory,
            daemon=True
        )
        self._monitoring_thread.start()

    def _monitor_memory(self):
        """Background thread that monitors and manages memory."""
        logger.info("Memory monitoring thread started")

        while not self._shutdown:
            try:
                # Collect memory statistics
                stats = self.get_memory_stats()

                # Check whether we need to free memory
                if self._should_cleanup(stats):
                    logger.warning(
                        f"Memory threshold exceeded - Process: {stats.process_memory_mb:.1f}MB, "
                        f"GPU: {stats.gpu_memory_mb:.1f}MB"
                    )
                    self.cleanup_memory(aggressive=True)

                # Log stats roughly every 5 minutes. (A modulo check against
                # time.time() would only fire if the loop happened to wake on
                # an exact multiple of 300, so track the last log time instead.)
                if time.time() - self._last_stats_log >= 300:
                    self._last_stats_log = time.time()
                    logger.info(
                        f"Memory stats - Process: {stats.process_memory_mb:.1f}MB, "
                        f"System: {stats.system_memory_percent:.1f}%, "
                        f"GPU: {stats.gpu_memory_mb:.1f}MB"
                    )

            except Exception as e:
                logger.error(f"Error in memory monitoring: {e}")

            time.sleep(self.cleanup_interval)

    def _should_cleanup(self, stats: MemoryStats) -> bool:
        """Determine whether memory cleanup is needed."""
        # Check process memory
        if stats.process_memory_mb > self.memory_threshold_mb:
            return True

        # Check system memory
        if stats.system_memory_percent > 85:
            return True

        # Check GPU memory
        if stats.gpu_memory_mb > self.gpu_memory_threshold_mb:
            return True

        return False

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics."""
        stats = MemoryStats()

        try:
            # Process memory
            process = psutil.Process()
            memory_info = process.memory_info()
            stats.process_memory_mb = memory_info.rss / 1024 / 1024

            # System memory
            system_memory = psutil.virtual_memory()
            stats.system_memory_percent = system_memory.percent

            # GPU memory, if available
            if torch.cuda.is_available():
                stats.gpu_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                stats.gpu_memory_percent = (
                    torch.cuda.memory_allocated()
                    / torch.cuda.get_device_properties(0).total_memory * 100
                )

            # Temp files (the manager may be constructed without an app,
            # so fall back to the system temp dir)
            temp_dir = tempfile.gettempdir()
            if self.app is not None:
                temp_dir = self.app.config.get('UPLOAD_FOLDER', temp_dir)
            if os.path.exists(temp_dir):
                temp_files = os.listdir(temp_dir)
                stats.temp_files_count = len(temp_files)
                stats.temp_files_size_mb = sum(
                    os.path.getsize(os.path.join(temp_dir, f))
                    for f in temp_files
                    if os.path.isfile(os.path.join(temp_dir, f))
                ) / 1024 / 1024

            # Session count
            if hasattr(self.app, 'session_manager'):
                stats.active_sessions = len(self.app.session_manager.sessions)

            # GC stats
            for i, stat in enumerate(gc.get_stats()):
                if isinstance(stat, dict):
                    stats.gc_collections[i] = stat.get('collections', 0)

        except Exception as e:
            logger.error(f"Error collecting memory stats: {e}")

        return stats

    def cleanup_memory(self, aggressive=False):
        """Perform memory cleanup."""
        logger.info(f"Starting memory cleanup (aggressive={aggressive})")
        freed_mb = 0

        try:
            # 1. Force garbage collection. gc.collect() already runs a full
            #    generation-2 pass; the second pass in aggressive mode picks
            #    up objects released by finalizers during the first.
            gc.collect()
            if aggressive:
                gc.collect()

            # 2. Clear GPU memory cache
            if torch.cuda.is_available():
                before_gpu = torch.cuda.memory_allocated() / 1024 / 1024
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                after_gpu = torch.cuda.memory_allocated() / 1024 / 1024
                freed_mb += before_gpu - after_gpu
                logger.info(f"Freed {before_gpu - after_gpu:.1f}MB GPU memory")

            # 3. Clean old temporary files
            if hasattr(self.app, 'config'):
                temp_dir = self.app.config.get('UPLOAD_FOLDER')
                if temp_dir and os.path.exists(temp_dir):
                    freed_mb += self._cleanup_temp_files(temp_dir, aggressive)

            # 4. Trigger session cleanup
            if hasattr(self.app, 'session_manager'):
                self.app.session_manager.cleanup_expired_sessions()
                if aggressive:
                    self.app.session_manager.cleanup_idle_sessions()

            # 5. Run registered cleanup callbacks
            for callback in self._cleanup_callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.error(f"Cleanup callback error: {e}")

            # 6. Reload the Whisper model if needed (aggressive mode only)
            if aggressive and self.whisper_model and torch.cuda.is_available():
                current_gpu_mb = torch.cuda.memory_allocated() / 1024 / 1024
                if current_gpu_mb > self.gpu_memory_threshold_mb * 0.8:
                    self._reload_whisper_model()

            logger.info(f"Memory cleanup completed - freed approximately {freed_mb:.1f}MB")

        except Exception as e:
            logger.error(f"Error during memory cleanup: {e}")

    def _cleanup_temp_files(self, temp_dir: str, aggressive: bool) -> float:
        """Remove old temporary files; returns the space freed in MB."""
        freed_mb = 0
        current_time = time.time()
        max_age = 60 if aggressive else 300  # 1 minute when aggressive, else 5 minutes

        try:
            for filename in os.listdir(temp_dir):
                filepath = os.path.join(temp_dir, filename)
                if os.path.isfile(filepath):
                    file_age = current_time - os.path.getmtime(filepath)
                    if file_age > max_age:
                        file_size = os.path.getsize(filepath) / 1024 / 1024
                        try:
                            os.remove(filepath)
                            freed_mb += file_size
                            logger.debug(f"Removed old temp file: {filename}")
                        except Exception as e:
                            logger.error(f"Failed to remove {filepath}: {e}")
        except Exception as e:
            logger.error(f"Error cleaning temp files: {e}")

        return freed_mb

    def _reload_whisper_model(self):
        """Reload the Whisper model to clear GPU memory fragmentation."""
        if not self.whisper_model:
            return

        # Don't reload too frequently
        if time.time() - self.last_model_reload < 300:  # 5 minutes
            return

        try:
            logger.info("Reloading Whisper model to clear GPU memory")

            # Capture model info before dropping the reference. The host app
            # is expected to have stashed 'model_size' on the model object;
            # fall back to 'base' otherwise.
            import whisper
            model_size = getattr(self.whisper_model, 'model_size', 'base')
            device = next(self.whisper_model.parameters()).device

            # Drop the old model. Assigning None (rather than del) keeps the
            # attribute defined even if the reload below fails.
            self.whisper_model = None
            gc.collect()
            torch.cuda.empty_cache()

            # Reload model
            self.whisper_model = whisper.load_model(model_size, device=device)
            self.model_reload_count += 1
            self.last_model_reload = time.time()

            # Update app reference
            if hasattr(self.app, 'whisper_model'):
                self.app.whisper_model = self.whisper_model

            logger.info(f"Whisper model reloaded successfully (reload #{self.model_reload_count})")

        except Exception as e:
            logger.error(f"Failed to reload Whisper model: {e}")

    def register_cleanup_callback(self, callback: Callable):
        """Register a callback to be invoked during cleanup."""
        self._cleanup_callbacks.append(callback)

    def register_resource(self, resource, name: Optional[str] = None):
        """Register a resource for weak-reference tracking."""
        if name:
            self._resource_registry[name] = resource

    def release_resource(self, name: str):
        """Release a tracked resource."""
        if name in self._resource_registry:
            del self._resource_registry[name]

    def get_metrics(self) -> Dict:
        """Get memory management metrics."""
        stats = self.get_memory_stats()
        return {
            'memory': {
                'process_mb': round(stats.process_memory_mb, 1),
                'system_percent': round(stats.system_memory_percent, 1),
                'gpu_mb': round(stats.gpu_memory_mb, 1),
                'gpu_percent': round(stats.gpu_memory_percent, 1)
            },
            'temp_files': {
                'count': stats.temp_files_count,
                'size_mb': round(stats.temp_files_size_mb, 1)
            },
            'sessions': {
                'active': stats.active_sessions
            },
            'model': {
                'reload_count': self.model_reload_count,
                'last_reload': datetime.fromtimestamp(self.last_model_reload).isoformat()
            },
            'thresholds': {
                'memory_mb': self.memory_threshold_mb,
                'gpu_mb': self.gpu_memory_threshold_mb
            }
        }

    def shutdown(self):
        """Shut down the memory manager."""
        logger.info("Shutting down memory manager")
        self._shutdown = True

        # Final cleanup
        self.cleanup_memory(aggressive=True)

        # Wait for the monitoring thread
        if self._monitoring_thread:
            self._monitoring_thread.join(timeout=5)
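

# A minimal wiring sketch for the class above, assuming a Flask app object;
# the threshold values and the cache-clearing callback are illustrative,
# not part of the manager's API.
def _example_manager_setup():
    from flask import Flask
    app = Flask(__name__)
    manager = MemoryManager(app, {
        'memory_threshold_mb': 2048,      # trigger cleanup at 2 GB RSS
        'gpu_memory_threshold_mb': 1024,  # and at 1 GB of allocated GPU memory
        'cleanup_interval': 15,           # check every 15 seconds
    })
    # Hook application-specific caches into the cleanup cycle
    manager.register_cleanup_callback(lambda: logger.debug("feature cache cleared"))
    return manager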


# Context manager for audio processing
class AudioProcessingContext:
    """Context manager that ensures audio resources are cleaned up."""

    def __init__(self, memory_manager: MemoryManager, name: Optional[str] = None):
        self.memory_manager = memory_manager
        self.name = name or f"audio_{int(time.time() * 1000)}"
        self.temp_files = []
        self.start_time = None
        self.start_memory = None

    def __enter__(self):
        self.start_time = time.time()
        if torch.cuda.is_available():
            self.start_memory = torch.cuda.memory_allocated()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Clean up temp files
        for filepath in self.temp_files:
            try:
                if os.path.exists(filepath):
                    os.remove(filepath)
            except Exception as e:
                logger.error(f"Failed to remove temp file {filepath}: {e}")

        # Clear GPU cache if it was used
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Log memory usage
        if self.start_memory is not None:
            memory_used = torch.cuda.memory_allocated() - self.start_memory
            duration = time.time() - self.start_time
            logger.debug(
                f"Audio processing '{self.name}' - Duration: {duration:.2f}s, "
                f"GPU memory: {memory_used / 1024 / 1024:.1f}MB"
            )

        # Force garbage collection if there was an error
        if exc_type is not None:
            gc.collect()

    def add_temp_file(self, filepath: str):
        """Register a temporary file for cleanup on exit."""
        self.temp_files.append(filepath)
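

# A minimal sketch of the context manager in use, assuming a Whisper-style
# model object; `_example_transcribe` is a hypothetical handler and the
# model call is illustrative.
def _example_transcribe(memory_manager: MemoryManager, model, audio_path: str):
    with AudioProcessingContext(memory_manager, name="transcribe") as ctx:
        ctx.add_temp_file(audio_path)  # deleted on exit, even if transcription fails
        return model.transcribe(audio_path)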


# Utility functions

def with_memory_management(func):
    """Decorator that runs a function inside an AudioProcessingContext."""
    @wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        # Get the memory manager from the app context
        from flask import current_app
        memory_manager = getattr(current_app, 'memory_manager', None)

        if memory_manager:
            with AudioProcessingContext(memory_manager, name=func.__name__):
                return func(*args, **kwargs)
        return func(*args, **kwargs)

    return wrapper


def init_memory_management(app, **kwargs):
    """Initialize memory management for the application."""
    config = {
        'memory_threshold_mb': kwargs.get('memory_threshold_mb', 4096),
        'gpu_memory_threshold_mb': kwargs.get('gpu_memory_threshold_mb', 2048),
        'cleanup_interval': kwargs.get('cleanup_interval', 30)
    }
    return MemoryManager(app, config)
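

# A minimal end-to-end sketch, assuming a Flask application factory; the
# `/transcribe` route and its body are hypothetical.
def _example_create_app():
    from flask import Flask
    app = Flask(__name__)
    init_memory_management(app, memory_threshold_mb=2048, cleanup_interval=15)

    @app.route('/transcribe', methods=['POST'])
    @with_memory_management  # wraps the handler in an AudioProcessingContext
    def transcribe():
        return {'status': 'ok'}  # a real handler would invoke the model here

    return app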