# Redis-based rate limiting implementation import time import logging from functools import wraps from flask import request, jsonify, g import hashlib from typing import Optional, Dict, Tuple logger = logging.getLogger(__name__) class RedisRateLimiter: """Token bucket rate limiter using Redis for distributed rate limiting""" def __init__(self, redis_manager): self.redis = redis_manager # Default limits (can be overridden per endpoint) self.default_limits = { 'requests_per_minute': 30, 'requests_per_hour': 500, 'burst_size': 10, 'token_refresh_rate': 0.5 # tokens per second } # Endpoint-specific limits self.endpoint_limits = { '/transcribe': { 'requests_per_minute': 10, 'requests_per_hour': 100, 'burst_size': 3, 'token_refresh_rate': 0.167, 'max_request_size': 10 * 1024 * 1024 # 10MB }, '/translate': { 'requests_per_minute': 20, 'requests_per_hour': 300, 'burst_size': 5, 'token_refresh_rate': 0.333, 'max_request_size': 100 * 1024 # 100KB }, '/translate/stream': { 'requests_per_minute': 10, 'requests_per_hour': 150, 'burst_size': 3, 'token_refresh_rate': 0.167, 'max_request_size': 100 * 1024 # 100KB }, '/speak': { 'requests_per_minute': 15, 'requests_per_hour': 200, 'burst_size': 3, 'token_refresh_rate': 0.25, 'max_request_size': 50 * 1024 # 50KB } } # Global limits self.global_limits = { 'total_requests_per_minute': 1000, 'total_requests_per_hour': 10000, 'concurrent_requests': 50 } def get_client_id(self, req) -> str: """Get unique client identifier""" ip = req.remote_addr or 'unknown' user_agent = req.headers.get('User-Agent', '') # Handle proxied requests forwarded_for = req.headers.get('X-Forwarded-For') if forwarded_for: ip = forwarded_for.split(',')[0].strip() # Create unique identifier identifier = f"{ip}:{user_agent}" return hashlib.md5(identifier.encode()).hexdigest() def get_limits(self, endpoint: str) -> Dict: """Get rate limits for endpoint""" return self.endpoint_limits.get(endpoint, self.default_limits) def is_ip_blocked(self, ip: str) -> bool: """Check if IP is blocked""" # Check permanent blocks if self.redis.sismember('blocked_ips:permanent', ip): return True # Check temporary blocks block_key = f'blocked_ip:{ip}' if self.redis.exists(block_key): return True return False def block_ip_temporarily(self, ip: str, duration: int = 3600): """Block IP temporarily""" block_key = f'blocked_ip:{ip}' self.redis.set(block_key, 1, expire=duration) logger.warning(f"IP {ip} temporarily blocked for {duration} seconds") def check_global_limits(self) -> Tuple[bool, Optional[str]]: """Check global rate limits""" now = time.time() # Check requests per minute minute_key = 'global:requests:minute' allowed, remaining = self.redis.check_rate_limit( minute_key, self.global_limits['total_requests_per_minute'], 60 ) if not allowed: return False, "Global rate limit exceeded (per minute)" # Check requests per hour hour_key = 'global:requests:hour' allowed, remaining = self.redis.check_rate_limit( hour_key, self.global_limits['total_requests_per_hour'], 3600 ) if not allowed: return False, "Global rate limit exceeded (per hour)" # Check concurrent requests concurrent_key = 'global:concurrent' current_concurrent = self.redis.get(concurrent_key, 0) if current_concurrent >= self.global_limits['concurrent_requests']: return False, "Too many concurrent requests" return True, None def check_rate_limit(self, client_id: str, endpoint: str, request_size: int = 0) -> Tuple[bool, Optional[str], Optional[Dict]]: """Check if request should be allowed""" # Check global limits first global_ok, global_msg = self.check_global_limits() if not global_ok: return False, global_msg, None # Get limits for endpoint limits = self.get_limits(endpoint) # Check request size if applicable if request_size > 0 and 'max_request_size' in limits: if request_size > limits['max_request_size']: return False, "Request too large", None # Token bucket implementation using Redis bucket_key = f'bucket:{client_id}:{endpoint}' now = time.time() # Get current bucket state bucket_data = self.redis.hgetall(bucket_key) # Initialize bucket if empty if not bucket_data: bucket_data = { 'tokens': limits['burst_size'], 'last_update': now } else: # Update tokens based on time passed last_update = float(bucket_data.get('last_update', now)) time_passed = now - last_update new_tokens = time_passed * limits['token_refresh_rate'] current_tokens = float(bucket_data.get('tokens', 0)) bucket_data['tokens'] = min( limits['burst_size'], current_tokens + new_tokens ) bucket_data['last_update'] = now # Check sliding window limits minute_allowed, minute_remaining = self.redis.check_rate_limit( f'window:{client_id}:{endpoint}:minute', limits['requests_per_minute'], 60 ) if not minute_allowed: return False, "Rate limit exceeded (per minute)", { 'retry_after': 60, 'limit': limits['requests_per_minute'], 'remaining': 0, 'reset': int(now + 60) } hour_allowed, hour_remaining = self.redis.check_rate_limit( f'window:{client_id}:{endpoint}:hour', limits['requests_per_hour'], 3600 ) if not hour_allowed: return False, "Rate limit exceeded (per hour)", { 'retry_after': 3600, 'limit': limits['requests_per_hour'], 'remaining': 0, 'reset': int(now + 3600) } # Check token bucket if float(bucket_data['tokens']) < 1: retry_after = int(1 / limits['token_refresh_rate']) return False, "Rate limit exceeded (burst)", { 'retry_after': retry_after, 'limit': limits['burst_size'], 'remaining': 0, 'reset': int(now + retry_after) } # Request allowed - update bucket bucket_data['tokens'] = float(bucket_data['tokens']) - 1 # Save bucket state self.redis.hset(bucket_key, 'tokens', bucket_data['tokens']) self.redis.hset(bucket_key, 'last_update', bucket_data['last_update']) self.redis.expire(bucket_key, 86400) # Expire after 24 hours return True, None, { 'limit': limits['requests_per_minute'], 'remaining': minute_remaining, 'reset': int(now + 60) } def increment_concurrent(self): """Increment concurrent request counter""" self.redis.incr('global:concurrent') def decrement_concurrent(self): """Decrement concurrent request counter""" self.redis.decr('global:concurrent') def get_client_stats(self, client_id: str) -> Optional[Dict]: """Get statistics for a client""" stats = { 'requests_last_minute': 0, 'requests_last_hour': 0, 'buckets': {} } # Get request counts from all endpoints for endpoint in self.endpoint_limits.keys(): minute_key = f'window:{client_id}:{endpoint}:minute' hour_key = f'window:{client_id}:{endpoint}:hour' # This is approximate since we're using sliding windows minute_count = self.redis.scard(minute_key) hour_count = self.redis.scard(hour_key) stats['requests_last_minute'] += minute_count stats['requests_last_hour'] += hour_count # Get bucket info bucket_key = f'bucket:{client_id}:{endpoint}' bucket_data = self.redis.hgetall(bucket_key) if bucket_data: stats['buckets'][endpoint] = { 'tokens': float(bucket_data.get('tokens', 0)), 'last_update': float(bucket_data.get('last_update', 0)) } return stats def rate_limit(endpoint=None, requests_per_minute=None, requests_per_hour=None, burst_size=None, check_size=False): """ Rate limiting decorator for Flask routes using Redis """ def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): # Get Redis rate limiter from app from flask import current_app if not hasattr(current_app, 'redis_rate_limiter'): # No Redis rate limiter, execute function normally return f(*args, **kwargs) rate_limiter = current_app.redis_rate_limiter # Get client ID client_id = rate_limiter.get_client_id(request) ip = request.remote_addr # Check if IP is blocked if rate_limiter.is_ip_blocked(ip): return jsonify({ 'error': 'IP temporarily blocked due to excessive requests' }), 429 # Get endpoint endpoint_path = endpoint or request.endpoint # Override default limits if specified if any([requests_per_minute, requests_per_hour, burst_size]): limits = rate_limiter.get_limits(endpoint_path).copy() if requests_per_minute: limits['requests_per_minute'] = requests_per_minute if requests_per_hour: limits['requests_per_hour'] = requests_per_hour if burst_size: limits['burst_size'] = burst_size rate_limiter.endpoint_limits[endpoint_path] = limits # Check request size if needed request_size = 0 if check_size: request_size = request.content_length or 0 # Check rate limit allowed, message, headers = rate_limiter.check_rate_limit( client_id, endpoint_path, request_size ) if not allowed: # Log excessive requests logger.warning(f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}") # Check if we should temporarily block this IP stats = rate_limiter.get_client_stats(client_id) if stats and stats['requests_last_minute'] > 100: rate_limiter.block_ip_temporarily(ip, 3600) # 1 hour block response = jsonify({ 'error': message, 'retry_after': headers.get('retry_after') if headers else 60 }) response.status_code = 429 # Add rate limit headers if headers: response.headers['X-RateLimit-Limit'] = str(headers['limit']) response.headers['X-RateLimit-Remaining'] = str(headers['remaining']) response.headers['X-RateLimit-Reset'] = str(headers['reset']) response.headers['Retry-After'] = str(headers['retry_after']) return response # Track concurrent requests rate_limiter.increment_concurrent() try: # Add rate limit info to response g.rate_limit_headers = headers response = f(*args, **kwargs) # Add headers to successful response if headers and hasattr(response, 'headers'): response.headers['X-RateLimit-Limit'] = str(headers['limit']) response.headers['X-RateLimit-Remaining'] = str(headers['remaining']) response.headers['X-RateLimit-Reset'] = str(headers['reset']) return response finally: rate_limiter.decrement_concurrent() return decorated_function return decorator