talk2me/memory_manager.py
Adolfo Delorenzo 1b9ad03400 Fix potential memory leaks in audio handling - Can crash server after extended use
This fix addresses memory leaks in both the backend and the frontend that could cause server crashes after extended use.

Backend fixes (a minimal initialization sketch follows the list):
- MemoryManager class monitors process and GPU memory usage
- Automatic cleanup when thresholds exceeded (4GB process, 2GB GPU)
- Whisper model reloading to clear GPU memory fragmentation
- Aggressive temporary file cleanup based on age
- Context manager for audio processing with guaranteed cleanup
- Integration with session manager for resource tracking
- Background monitoring thread runs every 30 seconds
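
A minimal initialization sketch (illustrative: assumes the module is importable as `memory_manager` and a Flask `app` object; the values shown mirror the defaults):

    from memory_manager import init_memory_management

    manager = init_memory_management(
        app,
        memory_threshold_mb=4096,      # process-memory ceiling before cleanup
        gpu_memory_threshold_mb=2048,  # GPU-memory ceiling before cleanup
        cleanup_interval=30,           # monitor-thread wake-up period (seconds)
    )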

Frontend fixes:
- MemoryManager singleton tracks all browser resources
- SafeMediaRecorder wrapper ensures stream cleanup
- AudioBlobHandler manages blob lifecycle and object URLs
- Automatic cleanup of closed AudioContexts
- Proper MediaStream track stopping
- Periodic cleanup of orphaned resources
- Cleanup on page unload

Admin features (a minimal route sketch follows the list):
- GET /admin/memory - View memory statistics
- POST /admin/memory/cleanup - Trigger manual cleanup
- Real-time metrics including GPU usage and temp files
- Model reload tracking
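
A sketch of how these endpoints map onto the manager's API (`get_metrics()` and `cleanup_memory()` are defined in the module below; the Flask wiring itself is illustrative):

    from flask import jsonify

    @app.route('/admin/memory', methods=['GET'])
    def admin_memory():
        # Real-time metrics: process/GPU memory, temp files, model reloads
        return jsonify(app.memory_manager.get_metrics())

    @app.route('/admin/memory/cleanup', methods=['POST'])
    def admin_memory_cleanup():
        # Manual aggressive cleanup, then report the resulting metrics
        app.memory_manager.cleanup_memory(aggressive=True)
        return jsonify(app.memory_manager.get_metrics())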

Key improvements:
- AudioContext properly closed after use
- Object URLs revoked after use
- MediaRecorder streams properly stopped
- Audio chunks cleared after processing
- GPU cache cleared after each transcription
- Temp files tracked and cleaned aggressively

This prevents the gradual memory increase that could lead to out-of-memory errors or performance degradation after hours of use.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-03 08:37:13 -06:00


# Memory management system to prevent leaks and monitor usage
import functools
import gc
import logging
import os
import tempfile
import threading
import time
import weakref
from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, Dict, Optional

import psutil
import torch

logger = logging.getLogger(__name__)

@dataclass
class MemoryStats:
    """Current memory statistics"""
    timestamp: float = field(default_factory=time.time)
    process_memory_mb: float = 0.0
    system_memory_percent: float = 0.0
    gpu_memory_mb: float = 0.0
    gpu_memory_percent: float = 0.0
    temp_files_count: int = 0
    temp_files_size_mb: float = 0.0
    active_sessions: int = 0
    gc_collections: Dict[int, int] = field(default_factory=dict)

class MemoryManager:
    """
    Comprehensive memory management system to prevent leaks
    """

    def __init__(self, app=None, config=None):
        self.config = config or {}
        self.app = app
        self._cleanup_callbacks = []
        self._resource_registry = weakref.WeakValueDictionary()
        self._monitoring_thread = None
        self._shutdown = False

        # Memory thresholds
        self.memory_threshold_mb = self.config.get('memory_threshold_mb', 4096)  # 4GB
        self.gpu_memory_threshold_mb = self.config.get('gpu_memory_threshold_mb', 2048)  # 2GB
        self.cleanup_interval = self.config.get('cleanup_interval', 30)  # 30 seconds

        # Whisper model reference
        self.whisper_model = None
        self.model_reload_count = 0
        self.last_model_reload = time.time()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Initialize memory management for Flask app"""
        self.app = app
        app.memory_manager = self

        # Start monitoring thread
        self._start_monitoring()

        # Register cleanup on shutdown
        import atexit
        atexit.register(self.shutdown)

        logger.info("Memory manager initialized")

    def set_whisper_model(self, model):
        """Register the Whisper model for management"""
        self.whisper_model = model
        logger.info("Whisper model registered with memory manager")

    def _start_monitoring(self):
        """Start background memory monitoring"""
        self._monitoring_thread = threading.Thread(
            target=self._monitor_memory,
            daemon=True
        )
        self._monitoring_thread.start()

    def _monitor_memory(self):
        """Background thread to monitor and manage memory"""
        logger.info("Memory monitoring thread started")
        last_stats_log = 0.0

        while not self._shutdown:
            try:
                # Collect memory statistics
                stats = self.get_memory_stats()

                # Check if we need to free memory
                if self._should_cleanup(stats):
                    logger.warning(f"Memory threshold exceeded - Process: {stats.process_memory_mb:.1f}MB, "
                                   f"GPU: {stats.gpu_memory_mb:.1f}MB")
                    self.cleanup_memory(aggressive=True)

                # Log stats roughly every 5 minutes
                if time.time() - last_stats_log >= 300:
                    last_stats_log = time.time()
                    logger.info(f"Memory stats - Process: {stats.process_memory_mb:.1f}MB, "
                                f"System: {stats.system_memory_percent:.1f}%, "
                                f"GPU: {stats.gpu_memory_mb:.1f}MB")
            except Exception as e:
                logger.error(f"Error in memory monitoring: {e}")

            time.sleep(self.cleanup_interval)

    def _should_cleanup(self, stats: MemoryStats) -> bool:
        """Determine if memory cleanup is needed"""
        # Check process memory
        if stats.process_memory_mb > self.memory_threshold_mb:
            return True

        # Check system memory
        if stats.system_memory_percent > 85:
            return True

        # Check GPU memory
        if stats.gpu_memory_mb > self.gpu_memory_threshold_mb:
            return True

        return False

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics"""
        stats = MemoryStats()

        try:
            # Process memory
            process = psutil.Process()
            memory_info = process.memory_info()
            stats.process_memory_mb = memory_info.rss / 1024 / 1024

            # System memory
            system_memory = psutil.virtual_memory()
            stats.system_memory_percent = system_memory.percent

            # GPU memory if available
            if torch.cuda.is_available():
                stats.gpu_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                stats.gpu_memory_percent = (torch.cuda.memory_allocated() /
                                            torch.cuda.get_device_properties(0).total_memory * 100)

            # Temp files
            temp_dir = (self.app.config.get('UPLOAD_FOLDER', tempfile.gettempdir())
                        if self.app else tempfile.gettempdir())
            if os.path.exists(temp_dir):
                temp_files = os.listdir(temp_dir)
                stats.temp_files_count = len(temp_files)
                stats.temp_files_size_mb = sum(
                    os.path.getsize(os.path.join(temp_dir, f))
                    for f in temp_files if os.path.isfile(os.path.join(temp_dir, f))
                ) / 1024 / 1024

            # Session count
            if hasattr(self.app, 'session_manager'):
                stats.active_sessions = len(self.app.session_manager.sessions)

            # GC stats: collection counts per generation
            for generation, gen_stats in enumerate(gc.get_stats()):
                stats.gc_collections[generation] = gen_stats.get('collections', 0)
        except Exception as e:
            logger.error(f"Error collecting memory stats: {e}")

        return stats

    def cleanup_memory(self, aggressive=False):
        """Perform memory cleanup"""
        logger.info(f"Starting memory cleanup (aggressive={aggressive})")
        freed_mb = 0

        try:
            # 1. Force garbage collection
            gc.collect()
            if aggressive:
                gc.collect(2)  # second full pass to catch newly unreachable cycles

            # 2. Clear GPU memory cache (measure the caching allocator's
            # reservation, which is what empty_cache() actually releases)
            if torch.cuda.is_available():
                before_gpu = torch.cuda.memory_reserved() / 1024 / 1024
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                after_gpu = torch.cuda.memory_reserved() / 1024 / 1024
                freed_mb += (before_gpu - after_gpu)
                logger.info(f"Freed {before_gpu - after_gpu:.1f}MB of cached GPU memory")

            # 3. Clean old temporary files
            if hasattr(self.app, 'config'):
                temp_dir = self.app.config.get('UPLOAD_FOLDER')
                if temp_dir and os.path.exists(temp_dir):
                    freed_mb += self._cleanup_temp_files(temp_dir, aggressive)

            # 4. Trigger session cleanup
            if hasattr(self.app, 'session_manager'):
                self.app.session_manager.cleanup_expired_sessions()
                if aggressive:
                    self.app.session_manager.cleanup_idle_sessions()

            # 5. Run registered cleanup callbacks
            for callback in self._cleanup_callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.error(f"Cleanup callback error: {e}")

            # 6. Reload Whisper model if needed (aggressive mode only)
            if aggressive and self.whisper_model and torch.cuda.is_available():
                current_gpu_mb = torch.cuda.memory_allocated() / 1024 / 1024
                if current_gpu_mb > self.gpu_memory_threshold_mb * 0.8:
                    self._reload_whisper_model()

            logger.info(f"Memory cleanup completed - freed approximately {freed_mb:.1f}MB")
        except Exception as e:
            logger.error(f"Error during memory cleanup: {e}")

    def _cleanup_temp_files(self, temp_dir: str, aggressive: bool) -> float:
        """Clean up temporary files"""
        freed_mb = 0
        current_time = time.time()
        max_age = 60 if aggressive else 300  # 1 minute when aggressive, else 5 minutes

        try:
            for filename in os.listdir(temp_dir):
                filepath = os.path.join(temp_dir, filename)
                if os.path.isfile(filepath):
                    file_age = current_time - os.path.getmtime(filepath)
                    if file_age > max_age:
                        file_size = os.path.getsize(filepath) / 1024 / 1024
                        try:
                            os.remove(filepath)
                            freed_mb += file_size
                            logger.debug(f"Removed old temp file: {filename}")
                        except Exception as e:
                            logger.error(f"Failed to remove {filepath}: {e}")
        except Exception as e:
            logger.error(f"Error cleaning temp files: {e}")

        return freed_mb

    def _reload_whisper_model(self):
        """Reload Whisper model to clear GPU memory fragmentation"""
        if not self.whisper_model:
            return

        # Don't reload too frequently
        if time.time() - self.last_model_reload < 300:  # 5 minutes
            return

        try:
            logger.info("Reloading Whisper model to clear GPU memory")

            # Get model info before dropping the reference
            import whisper
            model_size = getattr(self.whisper_model, 'model_size', 'base')
            device = next(self.whisper_model.parameters()).device

            # Drop all references to the old model so its GPU memory can
            # actually be released; keep the attributes defined so a failed
            # reload leaves the manager in a consistent state
            old_model = self.whisper_model
            self.whisper_model = None
            if hasattr(self.app, 'whisper_model'):
                self.app.whisper_model = None
            del old_model
            gc.collect()
            torch.cuda.empty_cache()

            # Reload model
            self.whisper_model = whisper.load_model(model_size, device=device)
            self.model_reload_count += 1
            self.last_model_reload = time.time()

            # Update app reference
            if hasattr(self.app, 'whisper_model'):
                self.app.whisper_model = self.whisper_model

            logger.info(f"Whisper model reloaded successfully (reload #{self.model_reload_count})")
        except Exception as e:
            logger.error(f"Failed to reload Whisper model: {e}")

    def register_cleanup_callback(self, callback: Callable):
        """Register a callback to be called during cleanup"""
        self._cleanup_callbacks.append(callback)

    def register_resource(self, resource, name: Optional[str] = None):
        """Register a resource for tracking"""
        if name:
            self._resource_registry[name] = resource

    def release_resource(self, name: str):
        """Release a tracked resource"""
        if name in self._resource_registry:
            del self._resource_registry[name]

    def get_metrics(self) -> Dict:
        """Get memory management metrics"""
        stats = self.get_memory_stats()

        return {
            'memory': {
                'process_mb': round(stats.process_memory_mb, 1),
                'system_percent': round(stats.system_memory_percent, 1),
                'gpu_mb': round(stats.gpu_memory_mb, 1),
                'gpu_percent': round(stats.gpu_memory_percent, 1)
            },
            'temp_files': {
                'count': stats.temp_files_count,
                'size_mb': round(stats.temp_files_size_mb, 1)
            },
            'sessions': {
                'active': stats.active_sessions
            },
            'model': {
                'reload_count': self.model_reload_count,
                'last_reload': datetime.fromtimestamp(self.last_model_reload).isoformat()
            },
            'thresholds': {
                'memory_mb': self.memory_threshold_mb,
                'gpu_mb': self.gpu_memory_threshold_mb
            }
        }

    def shutdown(self):
        """Shutdown memory manager"""
        logger.info("Shutting down memory manager")
        self._shutdown = True

        # Final cleanup
        self.cleanup_memory(aggressive=True)

        # Wait for monitoring thread
        if self._monitoring_thread:
            self._monitoring_thread.join(timeout=5)

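# Example: registering an application-specific cleanup callback and reading
# metrics. Illustrative sketch; `translation_cache` is a hypothetical cache
# owned by the application, not something defined in this module.
#
#     def drop_translation_cache():
#         translation_cache.clear()  # hypothetical application cache
#
#     app.memory_manager.register_cleanup_callback(drop_translation_cache)
#     app.memory_manager.get_metrics()['memory']  # current process/GPU usage
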
# Context manager for audio processing
class AudioProcessingContext:
    """Context manager to ensure audio resources are cleaned up"""

    def __init__(self, memory_manager: MemoryManager, name: Optional[str] = None):
        self.memory_manager = memory_manager
        self.name = name or f"audio_{int(time.time() * 1000)}"
        self.temp_files = []
        self.start_time = None
        self.start_memory = None

    def __enter__(self):
        self.start_time = time.time()
        if torch.cuda.is_available():
            self.start_memory = torch.cuda.memory_allocated()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Clean up temp files
        for filepath in self.temp_files:
            try:
                if os.path.exists(filepath):
                    os.remove(filepath)
            except Exception as e:
                logger.error(f"Failed to remove temp file {filepath}: {e}")

        # Clear GPU cache if used
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Log memory usage
        if self.start_memory is not None:
            memory_used = torch.cuda.memory_allocated() - self.start_memory
            duration = time.time() - self.start_time
            logger.debug(f"Audio processing '{self.name}' - Duration: {duration:.2f}s, "
                         f"GPU memory: {memory_used / 1024 / 1024:.1f}MB")

        # Force garbage collection if there was an error
        if exc_type is not None:
            gc.collect()

    def add_temp_file(self, filepath: str):
        """Register a temporary file for cleanup"""
        self.temp_files.append(filepath)

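# Example: wrapping one transcription request in the context manager so the
# temp file is always removed and the GPU cache cleared, even on error.
# Illustrative sketch; `save_upload()` stands in for the app's real upload
# helper, and `app.whisper_model` is assumed to be a loaded Whisper model.
#
#     with AudioProcessingContext(app.memory_manager, name='transcribe') as ctx:
#         wav_path = save_upload(request.files['audio'])  # hypothetical helper
#         ctx.add_temp_file(wav_path)  # removed in __exit__
#         result = app.whisper_model.transcribe(wav_path)
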
# Utility functions
def with_memory_management(func):
    """Decorator to add memory management to functions"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Get memory manager from app context
        from flask import current_app
        memory_manager = getattr(current_app, 'memory_manager', None)

        if memory_manager:
            with AudioProcessingContext(memory_manager, name=func.__name__):
                return func(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return wrapper

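# Example: applying the decorator to a Flask view so the whole handler runs
# inside an AudioProcessingContext named after the function. Route and body
# are illustrative.
#
#     @app.route('/transcribe', methods=['POST'])
#     @with_memory_management
#     def transcribe():
#         ...  # audio handling; temp files and GPU cache cleaned on exit
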
def init_memory_management(app, **kwargs):
    """Initialize memory management for the application"""
    config = {
        'memory_threshold_mb': kwargs.get('memory_threshold_mb', 4096),
        'gpu_memory_threshold_mb': kwargs.get('gpu_memory_threshold_mb', 2048),
        'cleanup_interval': kwargs.get('cleanup_interval', 30)
    }

    memory_manager = MemoryManager(app, config)
    return memory_manager
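
# Example wiring at application startup (illustrative; assumes a Flask `app`
# and the openai-whisper package). Registering the model lets the manager
# reload it when GPU usage stays above ~80% of its threshold.
#
#     import whisper
#
#     app.whisper_model = whisper.load_model('base')
#     manager = init_memory_management(app)
#     manager.set_whisper_model(app.whisper_model)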