"""Memory management for Talk2Me application"""
import gc
import logging
import psutil
import torch
import os
import time
from contextlib import contextmanager
from functools import wraps
from dataclasses import dataclass
from typing import Optional, Dict, Any
import threading

logger = logging.getLogger(__name__)


@dataclass
class MemoryStats:
    """Memory statistics"""
    process_memory_mb: float        # resident set size of this process, in MB
    available_memory_mb: float      # system-wide available RAM, in MB
    memory_percent: float           # this process's share of total RAM (percent)
    gpu_memory_mb: float = 0.0      # CUDA memory allocated via torch (0 when no GPU)
    gpu_memory_percent: float = 0.0 # allocated CUDA memory as percent of device 0 total


class MemoryManager:
    """Manage memory usage for the application.

    Periodically samples process/system/GPU memory from a background
    thread and triggers cleanup (garbage collection, CUDA cache release,
    Whisper model cache clearing) when configured thresholds are exceeded.
    """

    def __init__(self, app=None, config: Optional[Dict[str, Any]] = None):
        """Create the manager.

        Args:
            app: optional Flask app; if given, ``init_app`` runs immediately.
            config: optional dict with keys ``memory_threshold_mb``,
                ``gpu_memory_threshold_mb``, ``system_memory_threshold_percent``
                and ``cleanup_interval`` (seconds between background sweeps).
        """
        self.app = app
        self.config = config or {}
        # Thresholds above which cleanup_memory(aggressive=True) is triggered.
        self.memory_threshold_mb = self.config.get('memory_threshold_mb', 4096)
        self.gpu_memory_threshold_mb = self.config.get('gpu_memory_threshold_mb', 2048)
        # Generalized from the previously hard-coded 80% in check_memory_pressure;
        # default preserves the old behavior.
        self.system_memory_threshold_percent = self.config.get(
            'system_memory_threshold_percent', 80)
        self.cleanup_interval = self.config.get('cleanup_interval', 30)
        self.whisper_model = None
        self._cleanup_thread = None
        self._stop_cleanup = threading.Event()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Initialize with Flask app"""
        self.app = app
        app.memory_manager = self

        # Start cleanup thread
        self._start_cleanup_thread()
        logger.info(f"Memory manager initialized with thresholds: "
                    f"Process={self.memory_threshold_mb}MB, "
                    f"GPU={self.gpu_memory_threshold_mb}MB")

    def set_whisper_model(self, model):
        """Set reference to Whisper model for memory management"""
        self.whisper_model = model

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics"""
        process = psutil.Process()
        memory_info = process.memory_info()

        stats = MemoryStats(
            process_memory_mb=memory_info.rss / 1024 / 1024,
            available_memory_mb=psutil.virtual_memory().available / 1024 / 1024,
            memory_percent=process.memory_percent()
        )

        # Check GPU memory if available
        if torch.cuda.is_available():
            try:
                stats.gpu_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                stats.gpu_memory_percent = (torch.cuda.memory_allocated() /
                                            torch.cuda.get_device_properties(0).total_memory * 100)
            except Exception as e:
                # Best-effort: GPU stats stay at their 0.0 defaults on failure.
                logger.error(f"Error getting GPU memory stats: {e}")

        return stats

    def check_memory_pressure(self) -> bool:
        """Check if system is under memory pressure.

        Returns True (after logging the reason) when any configured
        threshold is exceeded; False otherwise.
        """
        stats = self.get_memory_stats()

        # Check process memory
        if stats.process_memory_mb > self.memory_threshold_mb:
            logger.warning(f"High process memory usage: {stats.process_memory_mb:.1f}MB")
            return True

        # Check system memory
        # NOTE(review): memory_percent is process.memory_percent() (this
        # process's share of RAM), not system-wide usage, although the log
        # message says "system" — confirm intent before renaming.
        if stats.memory_percent > self.system_memory_threshold_percent:
            logger.warning(f"High system memory usage: {stats.memory_percent:.1f}%")
            return True

        # Check GPU memory
        if stats.gpu_memory_mb > self.gpu_memory_threshold_mb:
            logger.warning(f"High GPU memory usage: {stats.gpu_memory_mb:.1f}MB")
            return True

        return False

    def cleanup_memory(self, aggressive: bool = False):
        """Clean up memory.

        Args:
            aggressive: when True, also force a full collection of every
                GC generation and clear the Whisper model cache.
        """
        logger.info("Starting memory cleanup...")

        # Run garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collector: collected {collected} objects")

        # Clear GPU cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            logger.info("Cleared GPU cache")

        if aggressive:
            # Force garbage collection of all generations.
            # gc.collect(2) collects generation 2 AND all younger
            # generations, so the old per-generation loop was redundant.
            gc.collect(2)

            # Clear Whisper model cache if needed
            if self.whisper_model and hasattr(self.whisper_model, 'clear_cache'):
                self.whisper_model.clear_cache()
                logger.info("Cleared Whisper model cache")

    def _cleanup_worker(self):
        """Background cleanup worker.

        Wakes every ``cleanup_interval`` seconds until ``stop`` sets the
        event; does aggressive cleanup under pressure, light cleanup otherwise.
        """
        while not self._stop_cleanup.wait(self.cleanup_interval):
            try:
                if self.check_memory_pressure():
                    self.cleanup_memory(aggressive=True)
                else:
                    # Light cleanup: youngest GC generation + CUDA cache only.
                    gc.collect(0)
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            except Exception as e:
                # Never let the daemon thread die on a transient error.
                logger.error(f"Error in memory cleanup worker: {e}")

    def _start_cleanup_thread(self):
        """Start background cleanup thread"""
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            return

        self._stop_cleanup.clear()
        self._cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
        self._cleanup_thread.start()
        logger.info("Started memory cleanup thread")

    def stop(self):
        """Stop memory manager"""
        self._stop_cleanup.set()
        if self._cleanup_thread:
            self._cleanup_thread.join(timeout=5)

    def get_metrics(self) -> Dict[str, Any]:
        """Get memory metrics for monitoring"""
        stats = self.get_memory_stats()
        return {
            'process_memory_mb': round(stats.process_memory_mb, 2),
            'available_memory_mb': round(stats.available_memory_mb, 2),
            'memory_percent': round(stats.memory_percent, 2),
            'gpu_memory_mb': round(stats.gpu_memory_mb, 2),
            'gpu_memory_percent': round(stats.gpu_memory_percent, 2),
            'thresholds': {
                'process_mb': self.memory_threshold_mb,
                'gpu_mb': self.gpu_memory_threshold_mb
            },
            'under_pressure': self.check_memory_pressure()
        }


class AudioProcessingContext:
    """Context manager for audio processing with memory management.

    Checks for memory pressure on entry, and on exit removes any
    registered temp files, runs a GC pass, clears the CUDA cache and
    logs the elapsed time.
    """

    def __init__(self, memory_manager: MemoryManager, name: str = "audio_processing"):
        self.memory_manager = memory_manager
        self.name = name
        self.temp_files = []   # paths registered via add_temp_file, removed on exit
        self.start_time = None # set on __enter__; None until then

    def __enter__(self):
        self.start_time = time.time()
        # Check memory before processing
        if self.memory_manager and self.memory_manager.check_memory_pressure():
            logger.warning(f"Memory pressure detected before {self.name}")
            self.memory_manager.cleanup_memory()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Clean up temporary files
        for temp_file in self.temp_files:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except Exception as e:
                logger.error(f"Failed to remove temp file {temp_file}: {e}")

        # Clean up memory after processing
        if self.memory_manager:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Guard against __exit__ running without __enter__ (start_time=None),
        # which previously raised TypeError here.
        duration = time.time() - self.start_time if self.start_time is not None else 0.0
        logger.info(f"{self.name} completed in {duration:.2f}s")

    def add_temp_file(self, filepath: str):
        """Add a temporary file to be cleaned up"""
        self.temp_files.append(filepath)


def with_memory_management(func):
    """Decorator to add memory management to functions"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Get memory manager from app context
        from flask import current_app
        memory_manager = getattr(current_app, 'memory_manager', None)

        if memory_manager:
            # Check memory before
            if memory_manager.check_memory_pressure():
                logger.warning(f"Memory pressure before {func.__name__}")
                memory_manager.cleanup_memory()

        try:
            result = func(*args, **kwargs)
            return result
        finally:
            # Light cleanup after
            gc.collect(0)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return wrapper