# Rate limiting implementation for Flask
import time
import logging
from functools import wraps
from collections import defaultdict, deque
from threading import Lock
from flask import request, jsonify, g
from datetime import datetime, timedelta
import hashlib
import json

logger = logging.getLogger(__name__)


class RateLimiter:
    """
    Token bucket rate limiter with sliding window and multiple strategies.

    Three layers of throttling are applied, in order:
      1. global limits (per-minute / per-hour / concurrent caps for the whole app),
      2. per-client sliding-window limits (per minute and per hour),
      3. a per-client token bucket to smooth bursts.
    All mutable shared state is guarded by ``self.lock``.
    """

    def __init__(self):
        # Per-client state. NOTE: new entries are explicitly initialised in
        # check_rate_limit with a full burst of tokens; the defaultdict
        # factory is kept only for backward compatibility with direct access.
        self.buckets = defaultdict(lambda: {
            'tokens': 0,
            'last_update': time.time(),
            'requests': deque(maxlen=1000)  # Track last 1000 requests
        })
        self.lock = Lock()

        # Default limits (can be overridden per endpoint)
        self.default_limits = {
            'requests_per_minute': 30,
            'requests_per_hour': 500,
            'burst_size': 10,
            'token_refresh_rate': 0.5  # tokens per second
        }

        # Endpoint-specific limits, keyed by request path.
        self.endpoint_limits = {
            '/transcribe': {
                'requests_per_minute': 10,
                'requests_per_hour': 100,
                'burst_size': 3,
                'token_refresh_rate': 0.167,  # 1 token per 6 seconds
                'max_request_size': 10 * 1024 * 1024  # 10MB
            },
            '/translate': {
                'requests_per_minute': 20,
                'requests_per_hour': 300,
                'burst_size': 5,
                'token_refresh_rate': 0.333,  # 1 token per 3 seconds
                'max_request_size': 100 * 1024  # 100KB
            },
            '/translate/stream': {
                'requests_per_minute': 10,
                'requests_per_hour': 150,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/speak': {
                'requests_per_minute': 15,
                'requests_per_hour': 200,
                'burst_size': 3,
                'token_refresh_rate': 0.25,  # 1 token per 4 seconds
                'max_request_size': 50 * 1024  # 50KB
            }
        }

        # IP-based blocking
        self.blocked_ips = set()
        self.temp_blocked_ips = {}  # IP -> unblock_time

        # Global limits
        self.global_limits = {
            'total_requests_per_minute': 1000,
            'total_requests_per_hour': 10000,
            'concurrent_requests': 50
        }
        self.global_requests = deque(maxlen=10000)
        self.concurrent_requests = 0

    def get_client_id(self, request):
        """Return a stable anonymised identifier for the requesting client.

        Combines the client IP (honouring X-Forwarded-For when proxied) with
        the User-Agent string, hashed so raw IPs are not kept as dict keys.
        """
        ip = request.remote_addr or 'unknown'
        user_agent = request.headers.get('User-Agent', '')

        # Handle proxied requests: the first X-Forwarded-For entry is the
        # originating client.
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            ip = forwarded_for.split(',')[0].strip()

        identifier = f"{ip}:{user_agent}"
        # MD5 is used only as a compact identifier, not a security boundary.
        return hashlib.md5(identifier.encode()).hexdigest()

    def get_limits(self, endpoint):
        """Return the limit dict for *endpoint*, falling back to defaults."""
        return self.endpoint_limits.get(endpoint, self.default_limits)

    def is_ip_blocked(self, ip):
        """Check whether *ip* is permanently or temporarily blocked.

        Expired temporary blocks are removed as a side effect.
        """
        # BUGFIX: temp_blocked_ips is read *and* mutated here; take the lock
        # like every other accessor (never called while the lock is held).
        with self.lock:
            # Check permanent blocks
            if ip in self.blocked_ips:
                return True
            # Check temporary blocks
            if ip in self.temp_blocked_ips:
                if time.time() < self.temp_blocked_ips[ip]:
                    return True
                # Unblock if time expired
                del self.temp_blocked_ips[ip]
            return False

    def block_ip_temporarily(self, ip, duration=3600):
        """Block *ip* temporarily for *duration* seconds (default 1 hour)."""
        with self.lock:  # BUGFIX: guard the shared dict like other mutators
            self.temp_blocked_ips[ip] = time.time() + duration
        logger.warning(f"IP {ip} temporarily blocked for {duration} seconds")

    def check_global_limits(self):
        """Check global (app-wide) rate limits.

        Must be called with ``self.lock`` held. Returns ``(ok, message)``.
        """
        now = time.time()
        minute_ago = now - 60
        hour_ago = now - 3600

        # Drop requests older than an hour.
        self.global_requests = deque(
            (t for t in self.global_requests if t > hour_ago),
            maxlen=10000
        )

        requests_last_minute = sum(1 for t in self.global_requests if t > minute_ago)
        requests_last_hour = len(self.global_requests)

        if requests_last_minute >= self.global_limits['total_requests_per_minute']:
            return False, "Global rate limit exceeded (per minute)"
        if requests_last_hour >= self.global_limits['total_requests_per_hour']:
            return False, "Global rate limit exceeded (per hour)"
        if self.concurrent_requests >= self.global_limits['concurrent_requests']:
            return False, "Too many concurrent requests"
        return True, None

    def check_rate_limit(self, client_id, endpoint, request_size=0):
        """Decide whether a request should be allowed.

        Returns ``(allowed, message, headers)`` where *headers* is a dict of
        rate-limit metadata (limit/remaining/reset, plus retry_after on
        rejection), or None when no metadata applies.
        """
        with self.lock:
            # Check global limits first
            global_ok, global_msg = self.check_global_limits()
            if not global_ok:
                return False, global_msg, None

            limits = self.get_limits(endpoint)

            # Reject oversized payloads before any accounting.
            if request_size > 0 and 'max_request_size' in limits:
                if request_size > limits['max_request_size']:
                    return False, "Request too large", None

            now = time.time()

            # BUGFIX: new clients must start with a full burst allowance.
            # With the defaultdict factory (tokens=0, last_update=now) the
            # very first request from any client was always rejected by the
            # token-bucket check below.
            if client_id not in self.buckets:
                self.buckets[client_id] = {
                    'tokens': limits['burst_size'],
                    'last_update': now,
                    'requests': deque(maxlen=1000),
                }
            bucket = self.buckets[client_id]

            # Refill tokens proportionally to elapsed time, capped at burst size.
            time_passed = now - bucket['last_update']
            new_tokens = time_passed * limits['token_refresh_rate']
            bucket['tokens'] = min(
                limits['burst_size'],
                bucket['tokens'] + new_tokens
            )
            bucket['last_update'] = now

            # Clean old requests from the sliding window.
            minute_ago = now - 60
            hour_ago = now - 3600
            bucket['requests'] = deque(
                (t for t in bucket['requests'] if t > hour_ago),
                maxlen=1000
            )

            # Count requests in both windows.
            requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
            requests_last_hour = len(bucket['requests'])

            # Sliding-window limits.
            if requests_last_minute >= limits['requests_per_minute']:
                return False, "Rate limit exceeded (per minute)", {
                    'retry_after': 60,
                    'limit': limits['requests_per_minute'],
                    'remaining': 0,
                    'reset': int(minute_ago + 60)
                }
            if requests_last_hour >= limits['requests_per_hour']:
                return False, "Rate limit exceeded (per hour)", {
                    'retry_after': 3600,
                    'limit': limits['requests_per_hour'],
                    'remaining': 0,
                    'reset': int(hour_ago + 3600)
                }

            # Token bucket (burst) limit.
            if bucket['tokens'] < 1:
                retry_after = int(1 / limits['token_refresh_rate'])
                return False, "Rate limit exceeded (burst)", {
                    'retry_after': retry_after,
                    'limit': limits['burst_size'],
                    'remaining': 0,
                    'reset': int(now + retry_after)
                }

            # Request allowed - consume a token and record the request.
            bucket['tokens'] -= 1
            bucket['requests'].append(now)
            self.global_requests.append(now)

            remaining_minute = limits['requests_per_minute'] - requests_last_minute - 1

            return True, None, {
                'limit': limits['requests_per_minute'],
                'remaining': remaining_minute,
                'reset': int(minute_ago + 60)
            }

    def increment_concurrent(self):
        """Increment concurrent request counter."""
        with self.lock:
            self.concurrent_requests += 1

    def decrement_concurrent(self):
        """Decrement concurrent request counter (never below zero)."""
        with self.lock:
            self.concurrent_requests = max(0, self.concurrent_requests - 1)

    def get_client_stats(self, client_id):
        """Return request statistics for *client_id*, or None if unknown."""
        with self.lock:
            if client_id not in self.buckets:
                return None
            bucket = self.buckets[client_id]
            now = time.time()
            minute_ago = now - 60
            hour_ago = now - 3600
            requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
            requests_last_hour = sum(1 for t in bucket['requests'] if t > hour_ago)
            return {
                'requests_last_minute': requests_last_minute,
                'requests_last_hour': requests_last_hour,
                'tokens_available': bucket['tokens'],
                'last_request': bucket['last_update']
            }

    def cleanup_old_buckets(self, max_age=86400):
        """Drop buckets idle for longer than *max_age* seconds (default 24h)."""
        with self.lock:
            now = time.time()
            to_remove = [
                client_id for client_id, bucket in self.buckets.items()
                if now - bucket['last_update'] > max_age
            ]
            for client_id in to_remove:
                del self.buckets[client_id]
            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} old rate limit buckets")


# Global rate limiter instance
rate_limiter = RateLimiter()


def rate_limit(endpoint=None, requests_per_minute=None, requests_per_hour=None,
               burst_size=None, check_size=False):
    """
    Rate limiting decorator for Flask routes

    Usage:
        @app.route('/api/endpoint')
        @rate_limit(requests_per_minute=10, check_size=True)
        def endpoint():
            return jsonify({'status': 'ok'})
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Get client ID
            client_id = rate_limiter.get_client_id(request)
            ip = request.remote_addr

            # Check if IP is blocked
            if rate_limiter.is_ip_blocked(ip):
                return jsonify({
                    'error': 'IP temporarily blocked due to excessive requests'
                }), 429

            # BUGFIX: endpoint_limits is keyed by URL path ('/transcribe'),
            # but request.endpoint is Flask's view *name* ('transcribe'), so
            # per-endpoint limits never matched. Use request.path instead.
            endpoint_path = endpoint or request.path

            # Override default limits if specified
            if any([requests_per_minute, requests_per_hour, burst_size]):
                limits = rate_limiter.get_limits(endpoint_path).copy()
                if requests_per_minute:
                    limits['requests_per_minute'] = requests_per_minute
                if requests_per_hour:
                    limits['requests_per_hour'] = requests_per_hour
                if burst_size:
                    limits['burst_size'] = burst_size
                rate_limiter.endpoint_limits[endpoint_path] = limits

            # Check request size if needed
            request_size = 0
            if check_size:
                request_size = request.content_length or 0

            # Check rate limit
            allowed, message, headers = rate_limiter.check_rate_limit(
                client_id, endpoint_path, request_size
            )

            if not allowed:
                # Log excessive requests
                logger.warning(
                    f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}"
                )

                # Aggressive clients (>100 req/min) get a temporary IP block.
                stats = rate_limiter.get_client_stats(client_id)
                if stats and stats['requests_last_minute'] > 100:
                    rate_limiter.block_ip_temporarily(ip, 3600)  # 1 hour block

                response = jsonify({
                    'error': message,
                    'retry_after': headers.get('retry_after') if headers else 60
                })
                response.status_code = 429

                # Add rate limit headers
                if headers:
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])
                    response.headers['Retry-After'] = str(headers['retry_after'])
                return response

            # Track concurrent requests; always released in the finally block.
            rate_limiter.increment_concurrent()
            try:
                # Expose rate limit info to the view via the request context.
                g.rate_limit_headers = headers
                response = f(*args, **kwargs)

                # Add headers to successful response
                if headers and hasattr(response, 'headers'):
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])
                return response
            finally:
                rate_limiter.decrement_concurrent()

        return decorated_function
    return decorator


def cleanup_rate_limiter():
    """Cleanup function to be called periodically"""
    rate_limiter.cleanup_old_buckets()


# IP whitelist/blacklist management
class IPFilter:
    """Simple IP allow/deny lists; an IP cannot be on both lists at once."""

    def __init__(self):
        self.whitelist = set()
        self.blacklist = set()

    def add_to_whitelist(self, ip):
        """Allow *ip*, removing it from the blacklist if present."""
        self.whitelist.add(ip)
        self.blacklist.discard(ip)

    def add_to_blacklist(self, ip):
        """Deny *ip*, removing it from the whitelist if present."""
        self.blacklist.add(ip)
        self.whitelist.discard(ip)

    def is_allowed(self, ip):
        """Blacklist always wins; if a whitelist exists, only members pass."""
        if ip in self.blacklist:
            return False
        if self.whitelist and ip not in self.whitelist:
            return False
        return True


ip_filter = IPFilter()


def ip_filter_check():
    """Flask before_request hook: reject requests from filtered IPs."""
    ip = request.remote_addr
    if not ip_filter.is_allowed(ip):
        return jsonify({'error': 'Access denied'}), 403