"""User-specific rate limiting that integrates with authentication."""

import logging
import math
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
from functools import wraps
from threading import Lock
from typing import Optional

from flask import g, jsonify, make_response, request

from auth import get_current_user
from auth_models import User
from database import db

logger = logging.getLogger(__name__)

# Sliding windows enforced per user: (suffix used in limit/info keys, seconds).
_WINDOWS = (('minute', 60), ('hour', 3600), ('day', 86400))


class UserRateLimiter:
    """Rate limiter that applies per-user limits on top of a default limiter.

    Each user gets an in-memory bucket holding a token count (burst control)
    and the timestamps of admitted requests (sliding minute/hour/day windows).
    All bucket access happens under ``self.lock``.
    """

    def __init__(self, default_limiter):
        """Wrap *default_limiter*, which is used for anonymous/IP-based checks.

        The limiter must provide ``get_limits(endpoint)``,
        ``get_client_id(request)`` and ``check_rate_limit(...)``.
        """
        self.default_limiter = default_limiter
        # No ``maxlen`` on the deque: a cap would silently drop history and
        # make any hourly/daily limit above the cap unenforceable. Growth is
        # still bounded because only admitted requests are recorded and
        # entries older than a day are pruned on every check.
        self.user_buckets = defaultdict(lambda: {
            'tokens': 0,
            'last_update': time.time(),
            'requests': deque(),
        })
        self.lock = Lock()

    @staticmethod
    def _window_usage(timestamps, now, window):
        """Return ``(count, wait)`` for one sliding window.

        ``count`` is the number of timestamps newer than ``now - window``.
        ``wait`` is the number of seconds until the oldest of them leaves the
        window, or ``window`` itself when the window is empty.

        Timestamps are appended in call order, so the scan walks from the
        newest entry and stops at the first one outside the window.
        """
        cutoff = now - window
        count = 0
        oldest = None
        for stamp in reversed(timestamps):
            if stamp <= cutoff:
                break
            count += 1
            oldest = stamp
        if oldest is None:
            return 0, float(window)
        return count, max(0.0, oldest + window - now)

    def get_user_limits(self, user: Optional[User], endpoint: str):
        """Return the limits dict that applies to *user* on *endpoint*.

        Without a user the default limiter's limits are returned unchanged.
        Otherwise the per-user limits stored on the user object are used, the
        burst size comes from the base limits (default 10), and admins get
        10x the request limits and 5x the burst size.

        Returns a dict with ``requests_per_minute``, ``requests_per_hour``,
        ``requests_per_day``, ``burst_size`` and ``token_refresh_rate``
        (tokens per second).
        """
        base_limits = self.default_limiter.get_limits(endpoint)

        if not user:
            return base_limits

        def resolve(value, key, fallback):
            # A user without a configured limit inherits the base limit.
            # NOTE(review): assumes base_limits uses the same key names as
            # the dict returned here - confirm against the default limiter.
            return base_limits.get(key, fallback) if value is None else value

        per_minute = resolve(user.rate_limit_per_minute, 'requests_per_minute', 60)
        per_hour = resolve(user.rate_limit_per_hour, 'requests_per_hour', 1000)
        per_day = resolve(user.rate_limit_per_day, 'requests_per_day', 10000)
        burst_size = base_limits.get('burst_size', 10)

        if user.is_admin:
            per_minute *= 10
            per_hour *= 10
            per_day *= 10
            burst_size *= 5

        return {
            'requests_per_minute': per_minute,
            'requests_per_hour': per_hour,
            'requests_per_day': per_day,
            'burst_size': burst_size,
            # Derived after the admin multiplier so the token bucket refills
            # fast enough for the raised per-minute limit to be reachable.
            'token_refresh_rate': per_minute / 60.0,
        }

    def check_user_rate_limit(self, user: Optional[User], endpoint: str,
                              request_size: int = 0):
        """Check, and on success record, one request for *user*.

        Without a user the check is delegated to the default (IP-based)
        limiter. ``request_size`` is only forwarded on that path; the
        per-user path does not use it.

        Returns ``(allowed, message, info)``. ``message`` is ``None`` when the
        request is allowed. For user checks ``info`` always has ``scope``,
        ``reset`` and ``limit_*`` / ``remaining_*`` for minute, hour and day;
        rejections additionally carry ``retry_after``, ``limit`` and
        ``remaining`` for the limit that was hit.
        """
        if not user:
            # Fall back to IP-based limiting
            client_id = self.default_limiter.get_client_id(request)
            return self.default_limiter.check_rate_limit(client_id, endpoint, request_size)

        with self.lock:
            user_id = str(user.id)
            limits = self.get_user_limits(user, endpoint)
            now = time.time()

            is_new = user_id not in self.user_buckets
            bucket = self.user_buckets[user_id]
            if is_new:
                # Start with a full bucket; an empty one would reject the
                # user's very first request as a burst violation.
                bucket['tokens'] = limits['burst_size']
                bucket['last_update'] = now

            # Refill tokens for the time elapsed since the last check.
            # max() guards against the wall clock stepping backwards.
            elapsed = max(0.0, now - bucket['last_update'])
            bucket['tokens'] = min(
                limits['burst_size'],
                bucket['tokens'] + elapsed * limits['token_refresh_rate'],
            )
            bucket['last_update'] = now

            # Drop timestamps that have left the largest (day) window.
            timestamps = bucket['requests']
            day_cutoff = now - 86400
            while timestamps and timestamps[0] <= day_cutoff:
                timestamps.popleft()

            usage = {
                'minute': self._window_usage(timestamps, now, 60),
                'hour': self._window_usage(timestamps, now, 3600),
                # Everything left after pruning is inside the day window.
                'day': (
                    len(timestamps),
                    (timestamps[0] + 86400 - now) if timestamps else 86400.0,
                ),
            }

            info = {'scope': 'user'}
            for name, _seconds in _WINDOWS:
                limit = limits[f'requests_per_{name}']
                info[f'limit_{name}'] = limit
                info[f'remaining_{name}'] = max(0, limit - usage[name][0])

            # Sliding-window checks, smallest window first.
            for name, _seconds in _WINDOWS:
                count, wait = usage[name]
                limit = limits[f'requests_per_{name}']
                if count >= limit:
                    # The client can retry once the oldest request in the
                    # window has aged out.
                    retry_after = max(1, math.ceil(wait))
                    return False, f"Rate limit exceeded (per {name})", {
                        **info,
                        'retry_after': retry_after,
                        'limit': limit,
                        'remaining': 0,
                        'reset': int(now) + retry_after,
                    }

            # Token bucket (burst) check
            if bucket['tokens'] < 1:
                rate = limits['token_refresh_rate']
                if rate > 0:
                    deficit = 1 - bucket['tokens']
                    retry_after = max(1, math.ceil(deficit / rate))
                else:
                    # Tokens never refill at a zero rate; report one minute.
                    retry_after = 60
                return False, "Rate limit exceeded (burst)", {
                    **info,
                    'retry_after': retry_after,
                    'limit': limits['burst_size'],
                    'remaining': 0,
                    'reset': int(now) + retry_after,
                }

            # Request allowed: consume a token and record the timestamp.
            bucket['tokens'] -= 1
            timestamps.append(now)

            for name, _seconds in _WINDOWS:
                info[f'remaining_{name}'] -= 1
            # The minute window frees a slot when its oldest entry expires;
            # when it was empty, that entry is the request just recorded.
            info['reset'] = int(now + usage['minute'][1])
            return True, None, info

    def get_user_usage_stats(self, user: Optional[User]):
        """Return current usage counters for *user*.

        Returns ``None`` without a user. Otherwise returns a dict with
        ``requests_last_minute``, ``requests_last_hour``,
        ``requests_last_day``, ``tokens_available`` and ``last_request``.
        ``last_request`` is the bucket's last check time (updated on rejected
        checks too), or ``None`` if the user has no bucket yet.
        """
        if not user:
            return None

        with self.lock:
            user_id = str(user.id)
            # Membership test first: indexing the defaultdict would create
            # a bucket as a side effect.
            if user_id not in self.user_buckets:
                return {
                    'requests_last_minute': 0,
                    'requests_last_hour': 0,
                    'requests_last_day': 0,
                    'tokens_available': 0,
                    'last_request': None,
                }

            bucket = self.user_buckets[user_id]
            now = time.time()
            stats = {
                f'requests_last_{name}': self._window_usage(
                    bucket['requests'], now, seconds
                )[0]
                for name, seconds in _WINDOWS
            }
            stats['tokens_available'] = bucket['tokens']
            stats['last_request'] = bucket['last_update']
            return stats

    def reset_user_limits(self, user: Optional[User]):
        """Discard *user*'s bucket (admin action).

        Returns ``True`` if a bucket existed and was removed, else ``False``.
        """
        if not user:
            return False

        with self.lock:
            return self.user_buckets.pop(str(user.id), None) is not None


# Global user rate limiter instance.
# NOTE(review): this import is kept below the class as in the original;
# it may be positioned here to avoid an import cycle - confirm before moving.
from rate_limiter import rate_limiter as default_rate_limiter  # noqa: E402

user_rate_limiter = UserRateLimiter(default_rate_limiter)


def _apply_rate_limit_headers(response, info):
    """Copy rate-limit details from *info* onto ``response.headers``."""
    if not info:
        return response

    if info.get('scope') == 'user':
        response.headers['X-RateLimit-Limit'] = str(
            info.get('limit_minute', info.get('limit', 60)))
        response.headers['X-RateLimit-Remaining'] = str(
            info.get('remaining_minute', info.get('remaining', 0)))
        response.headers['X-RateLimit-Limit-Hour'] = str(info.get('limit_hour', 1000))
        response.headers['X-RateLimit-Remaining-Hour'] = str(info.get('remaining_hour', 0))
        response.headers['X-RateLimit-Limit-Day'] = str(info.get('limit_day', 10000))
        response.headers['X-RateLimit-Remaining-Day'] = str(info.get('remaining_day', 0))
    else:
        response.headers['X-RateLimit-Limit'] = str(info.get('limit', 60))
        response.headers['X-RateLimit-Remaining'] = str(info.get('remaining', 0))

    if 'reset' in info:
        response.headers['X-RateLimit-Reset'] = str(info['reset'])
    return response


def user_aware_rate_limit(endpoint=None,
                          requests_per_minute=None,
                          requests_per_hour=None,
                          requests_per_day=None,
                          burst_size=None,
                          check_size=False,
                          require_auth=False):
    """
    Enhanced rate limiting decorator that considers user authentication.

    Authenticated requests are checked against the per-user limiter;
    anonymous requests fall back to the default IP-based limiter. Rejected
    requests get a 429 JSON response with rate-limit headers; allowed
    requests get the headers added to the view's response.

    Args:
        endpoint: Name used to look up limits; defaults to ``request.endpoint``.
        check_size: Pass ``request.content_length`` to the limiter.
        require_auth: Return 401 when there is no authenticated user.
        requests_per_minute, requests_per_hour, requests_per_day, burst_size:
            NOTE(review): accepted for API compatibility but not currently
            applied to either limiter - confirm the intended semantics
            before wiring them in.

    Usage:
        @app.route('/api/endpoint')
        @user_aware_rate_limit(requests_per_minute=10)
        def endpoint():
            return jsonify({'status': 'ok'})
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Get current user (if authenticated)
            user = get_current_user()

            # If auth is required but no user, return 401
            if require_auth and not user:
                return jsonify({
                    'success': False,
                    'error': 'Authentication required',
                    'code': 'auth_required'
                }), 401

            endpoint_path = endpoint or request.endpoint
            request_size = (request.content_length or 0) if check_size else 0

            if user:
                # User-specific rate limiting
                allowed, message, headers = user_rate_limiter.check_user_rate_limit(
                    user, endpoint_path, request_size
                )
            else:
                # Fall back to IP-based rate limiting
                client_id = default_rate_limiter.get_client_id(request)
                allowed, message, headers = default_rate_limiter.check_rate_limit(
                    client_id, endpoint_path, request_size
                )

            if not allowed:
                identifier = f"user:{user.username}" if user else f"ip:{request.remote_addr}"
                logger.warning("Rate limit exceeded for %s on %s: %s",
                               identifier, endpoint_path, message)

                # Update user stats if authenticated. This is best-effort:
                # a failed write must not turn the 429 into a 500.
                if user:
                    try:
                        user.last_active_at = datetime.utcnow()
                        db.session.commit()
                    except Exception:
                        db.session.rollback()
                        logger.exception("Failed to update last_active_at for %s",
                                         identifier)

                retry_after = headers.get('retry_after', 60) if headers else 60
                response = jsonify({
                    'success': False,
                    'error': message,
                    'retry_after': retry_after
                })
                response.status_code = 429
                _apply_rate_limit_headers(response, headers)
                response.headers['Retry-After'] = str(retry_after)
                return response

            # Track concurrent requests
            default_rate_limiter.increment_concurrent()

            try:
                # Store user in g if authenticated
                if user:
                    g.current_user = user

                # Expose rate limit info to the view
                g.rate_limit_headers = headers

                response = f(*args, **kwargs)

                if headers:
                    # Views may return a tuple, dict or string; normalise so
                    # the rate-limit headers are not silently dropped.
                    if not hasattr(response, 'headers'):
                        response = make_response(response)
                    _apply_rate_limit_headers(response, headers)

                return response

            finally:
                default_rate_limiter.decrement_concurrent()

        return decorated_function
    return decorator


def get_user_rate_limit_status(user: Optional[User] = None):
    """Get current rate limit status for a user or IP.

    When *user* is not given the current authenticated user is looked up.
    Returns a dict with ``type`` (``'user'`` or ``'ip'``), ``identifier``,
    ``limits`` and ``usage``. Reads the Flask ``request`` object.
    """
    if not user:
        user = get_current_user()

    if user:
        stats = user_rate_limiter.get_user_usage_stats(user)
        limits = user_rate_limiter.get_user_limits(user, request.endpoint or '/')

        return {
            'type': 'user',
            'identifier': user.username,
            'limits': {
                'per_minute': limits['requests_per_minute'],
                'per_hour': limits['requests_per_hour'],
                'per_day': limits['requests_per_day']
            },
            'usage': stats
        }

    # IP-based stats
    client_id = default_rate_limiter.get_client_id(request)
    stats = default_rate_limiter.get_client_stats(client_id)

    return {
        'type': 'ip',
        'identifier': request.remote_addr,
        'limits': default_rate_limiter.default_limits,
        'usage': stats
    }