talk2me/rate_limiter.py
Adolfo Delorenzo a4ef775731 Implement comprehensive rate limiting to protect against DoS attacks
- Add token bucket rate limiter with sliding window algorithm
- Implement per-endpoint configurable rate limits
- Add automatic IP blocking for excessive requests
- Implement global request limits and concurrent request throttling
- Add request size validation for all endpoints
- Create admin endpoints for rate limit management
- Add rate limit headers to responses
- Implement cleanup thread for old rate limit buckets
- Create detailed rate limiting documentation

Rate limits:
- Transcription: 10/min, 100/hour, max 10MB
- Translation: 20/min, 300/hour, max 100KB
- Streaming: 10/min, 150/hour, max 100KB
- TTS: 15/min, 200/hour, max 50KB
- Global: 1000/min, 10000/hour, 50 concurrent

Security features:
- Automatic temporary IP blocking (1 hour) for abuse
- Manual IP blocking via admin endpoint
- Request size validation to prevent large payload attacks
- Burst control to limit sudden traffic spikes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-03 00:14:05 -06:00
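
Given the headers this module attaches to responses (X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, Retry-After), a client could consume them roughly as sketched below; the host, port, and request payload are illustrative assumptions, not part of this commit.

# Hypothetical client-side handling of the rate limit headers (sketch only)
import time
import requests

resp = requests.post(
    "http://localhost:5000/translate",            # host/port are assumptions
    json={"text": "hola", "target_lang": "en"},   # payload shape is an assumption
)
print(resp.headers.get("X-RateLimit-Limit"))      # e.g. "20" for /translate
print(resp.headers.get("X-RateLimit-Remaining"))
if resp.status_code == 429:
    time.sleep(int(resp.headers.get("Retry-After", "60")))  # back off before retrying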

408 lines
15 KiB
Python

# Rate limiting implementation for Flask
import time
import logging
from functools import wraps
from collections import defaultdict, deque
from threading import Lock
from flask import request, jsonify, g
import hashlib
logger = logging.getLogger(__name__)
class RateLimiter:
"""
Token bucket rate limiter with sliding window and multiple strategies
"""
def __init__(self):
        self.buckets = defaultdict(lambda: {
            # Start new clients with a full default bucket so their first request
            # is not rejected; check_rate_limit() clamps this to the endpoint's
            # own burst_size.
            'tokens': self.default_limits['burst_size'],
            'last_update': time.time(),
            'requests': deque(maxlen=1000)  # Track last 1000 requests
        })
self.lock = Lock()
# Default limits (can be overridden per endpoint)
self.default_limits = {
'requests_per_minute': 30,
'requests_per_hour': 500,
'burst_size': 10,
'token_refresh_rate': 0.5 # tokens per second
}
# Endpoint-specific limits
self.endpoint_limits = {
'/transcribe': {
'requests_per_minute': 10,
'requests_per_hour': 100,
'burst_size': 3,
'token_refresh_rate': 0.167, # 1 token per 6 seconds
'max_request_size': 10 * 1024 * 1024 # 10MB
},
'/translate': {
'requests_per_minute': 20,
'requests_per_hour': 300,
'burst_size': 5,
'token_refresh_rate': 0.333, # 1 token per 3 seconds
'max_request_size': 100 * 1024 # 100KB
},
'/translate/stream': {
'requests_per_minute': 10,
'requests_per_hour': 150,
'burst_size': 3,
'token_refresh_rate': 0.167,
'max_request_size': 100 * 1024 # 100KB
},
'/speak': {
'requests_per_minute': 15,
'requests_per_hour': 200,
'burst_size': 3,
'token_refresh_rate': 0.25, # 1 token per 4 seconds
'max_request_size': 50 * 1024 # 50KB
}
}
# IP-based blocking
self.blocked_ips = set()
self.temp_blocked_ips = {} # IP -> unblock_time
# Global limits
self.global_limits = {
'total_requests_per_minute': 1000,
'total_requests_per_hour': 10000,
'concurrent_requests': 50
}
self.global_requests = deque(maxlen=10000)
self.concurrent_requests = 0
def get_client_id(self, request):
"""Get unique client identifier"""
# Use IP address + user agent for better identification
ip = request.remote_addr or 'unknown'
user_agent = request.headers.get('User-Agent', '')
# Handle proxied requests
forwarded_for = request.headers.get('X-Forwarded-For')
if forwarded_for:
ip = forwarded_for.split(',')[0].strip()
# Create unique identifier
identifier = f"{ip}:{user_agent}"
return hashlib.md5(identifier.encode()).hexdigest()
def get_limits(self, endpoint):
"""Get rate limits for endpoint"""
return self.endpoint_limits.get(endpoint, self.default_limits)
def is_ip_blocked(self, ip):
"""Check if IP is blocked"""
# Check permanent blocks
if ip in self.blocked_ips:
return True
# Check temporary blocks
if ip in self.temp_blocked_ips:
if time.time() < self.temp_blocked_ips[ip]:
return True
else:
# Unblock if time expired
del self.temp_blocked_ips[ip]
return False
def block_ip_temporarily(self, ip, duration=3600):
"""Block IP temporarily (default 1 hour)"""
self.temp_blocked_ips[ip] = time.time() + duration
logger.warning(f"IP {ip} temporarily blocked for {duration} seconds")
def check_global_limits(self):
"""Check global rate limits"""
now = time.time()
# Clean old requests
minute_ago = now - 60
hour_ago = now - 3600
self.global_requests = deque(
(t for t in self.global_requests if t > hour_ago),
maxlen=10000
)
# Count requests
requests_last_minute = sum(1 for t in self.global_requests if t > minute_ago)
requests_last_hour = len(self.global_requests)
# Check limits
if requests_last_minute >= self.global_limits['total_requests_per_minute']:
return False, "Global rate limit exceeded (per minute)"
if requests_last_hour >= self.global_limits['total_requests_per_hour']:
return False, "Global rate limit exceeded (per hour)"
if self.concurrent_requests >= self.global_limits['concurrent_requests']:
return False, "Too many concurrent requests"
return True, None
def check_rate_limit(self, client_id, endpoint, request_size=0):
"""Check if request should be allowed"""
with self.lock:
# Check global limits first
global_ok, global_msg = self.check_global_limits()
if not global_ok:
return False, global_msg, None
# Get limits for endpoint
limits = self.get_limits(endpoint)
# Check request size if applicable
if request_size > 0 and 'max_request_size' in limits:
if request_size > limits['max_request_size']:
return False, "Request too large", None
# Get or create bucket
bucket = self.buckets[client_id]
now = time.time()
# Update tokens based on time passed
time_passed = now - bucket['last_update']
new_tokens = time_passed * limits['token_refresh_rate']
bucket['tokens'] = min(
limits['burst_size'],
bucket['tokens'] + new_tokens
)
bucket['last_update'] = now
# Clean old requests from sliding window
minute_ago = now - 60
hour_ago = now - 3600
bucket['requests'] = deque(
(t for t in bucket['requests'] if t > hour_ago),
maxlen=1000
)
# Count requests in windows
requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
requests_last_hour = len(bucket['requests'])
# Check sliding window limits
if requests_last_minute >= limits['requests_per_minute']:
return False, "Rate limit exceeded (per minute)", {
'retry_after': 60,
'limit': limits['requests_per_minute'],
'remaining': 0,
'reset': int(minute_ago + 60)
}
if requests_last_hour >= limits['requests_per_hour']:
return False, "Rate limit exceeded (per hour)", {
'retry_after': 3600,
'limit': limits['requests_per_hour'],
'remaining': 0,
'reset': int(hour_ago + 3600)
}
# Check token bucket
if bucket['tokens'] < 1:
retry_after = int(1 / limits['token_refresh_rate'])
return False, "Rate limit exceeded (burst)", {
'retry_after': retry_after,
'limit': limits['burst_size'],
'remaining': 0,
'reset': int(now + retry_after)
}
# Request allowed - consume token and record
bucket['tokens'] -= 1
bucket['requests'].append(now)
self.global_requests.append(now)
# Calculate remaining
remaining_minute = limits['requests_per_minute'] - requests_last_minute - 1
remaining_hour = limits['requests_per_hour'] - requests_last_hour - 1
return True, None, {
'limit': limits['requests_per_minute'],
'remaining': remaining_minute,
'reset': int(minute_ago + 60)
}
def increment_concurrent(self):
"""Increment concurrent request counter"""
with self.lock:
self.concurrent_requests += 1
def decrement_concurrent(self):
"""Decrement concurrent request counter"""
with self.lock:
self.concurrent_requests = max(0, self.concurrent_requests - 1)
def get_client_stats(self, client_id):
"""Get statistics for a client"""
with self.lock:
if client_id not in self.buckets:
return None
bucket = self.buckets[client_id]
now = time.time()
minute_ago = now - 60
hour_ago = now - 3600
requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
requests_last_hour = len([t for t in bucket['requests'] if t > hour_ago])
return {
'requests_last_minute': requests_last_minute,
'requests_last_hour': requests_last_hour,
'tokens_available': bucket['tokens'],
'last_request': bucket['last_update']
}
def cleanup_old_buckets(self, max_age=86400):
"""Clean up old unused buckets (default 24 hours)"""
with self.lock:
now = time.time()
to_remove = []
for client_id, bucket in self.buckets.items():
if now - bucket['last_update'] > max_age:
to_remove.append(client_id)
for client_id in to_remove:
del self.buckets[client_id]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} old rate limit buckets")
# Global rate limiter instance
rate_limiter = RateLimiter()
def rate_limit(endpoint=None,
requests_per_minute=None,
requests_per_hour=None,
burst_size=None,
check_size=False):
"""
Rate limiting decorator for Flask routes
Usage:
@app.route('/api/endpoint')
@rate_limit(requests_per_minute=10, check_size=True)
def endpoint():
return jsonify({'status': 'ok'})
"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# Get client ID
client_id = rate_limiter.get_client_id(request)
ip = request.remote_addr
# Check if IP is blocked
if rate_limiter.is_ip_blocked(ip):
return jsonify({
'error': 'IP temporarily blocked due to excessive requests'
}), 429
            # Resolve the limits key; endpoint_limits is keyed by URL path
            # (e.g. '/translate'), so fall back to request.path rather than
            # the Flask endpoint name
            endpoint_path = endpoint or request.path
# Override default limits if specified
if any([requests_per_minute, requests_per_hour, burst_size]):
limits = rate_limiter.get_limits(endpoint_path).copy()
if requests_per_minute:
limits['requests_per_minute'] = requests_per_minute
if requests_per_hour:
limits['requests_per_hour'] = requests_per_hour
if burst_size:
limits['burst_size'] = burst_size
rate_limiter.endpoint_limits[endpoint_path] = limits
# Check request size if needed
request_size = 0
if check_size:
request_size = request.content_length or 0
# Check rate limit
allowed, message, headers = rate_limiter.check_rate_limit(
client_id, endpoint_path, request_size
)
if not allowed:
# Log excessive requests
logger.warning(f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}")
# Check if we should temporarily block this IP
stats = rate_limiter.get_client_stats(client_id)
if stats and stats['requests_last_minute'] > 100:
rate_limiter.block_ip_temporarily(ip, 3600) # 1 hour block
response = jsonify({
'error': message,
'retry_after': headers.get('retry_after') if headers else 60
})
response.status_code = 429
# Add rate limit headers
if headers:
response.headers['X-RateLimit-Limit'] = str(headers['limit'])
response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
response.headers['X-RateLimit-Reset'] = str(headers['reset'])
response.headers['Retry-After'] = str(headers['retry_after'])
return response
# Track concurrent requests
rate_limiter.increment_concurrent()
try:
# Add rate limit info to response
g.rate_limit_headers = headers
response = f(*args, **kwargs)
# Add headers to successful response
if headers and hasattr(response, 'headers'):
response.headers['X-RateLimit-Limit'] = str(headers['limit'])
response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
response.headers['X-RateLimit-Reset'] = str(headers['reset'])
return response
finally:
rate_limiter.decrement_concurrent()
return decorated_function
return decorator
def cleanup_rate_limiter():
"""Cleanup function to be called periodically"""
rate_limiter.cleanup_old_buckets()
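# Hedged sketch (not part of the original module): the commit message mentions a
# cleanup thread for old rate limit buckets, which lives outside this file. One
# minimal way to drive cleanup_rate_limiter() periodically could look like the
# helper below; the name, interval, and daemon-thread approach are assumptions.
def start_cleanup_thread(interval_seconds=3600):
    """Illustrative helper: call cleanup_rate_limiter() every interval_seconds."""
    import threading

    def _loop():
        while True:
            time.sleep(interval_seconds)
            cleanup_rate_limiter()

    t = threading.Thread(target=_loop, daemon=True, name='rate-limit-cleanup')
    t.start()
    return t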
# IP whitelist/blacklist management
class IPFilter:
def __init__(self):
self.whitelist = set()
self.blacklist = set()
def add_to_whitelist(self, ip):
self.whitelist.add(ip)
self.blacklist.discard(ip)
def add_to_blacklist(self, ip):
self.blacklist.add(ip)
self.whitelist.discard(ip)
def is_allowed(self, ip):
if ip in self.blacklist:
return False
if self.whitelist and ip not in self.whitelist:
return False
return True
ip_filter = IPFilter()
def ip_filter_check():
"""Middleware to check IP filtering"""
ip = request.remote_addr
if not ip_filter.is_allowed(ip):
return jsonify({'error': 'Access denied'}), 403
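
# Hedged usage sketch (illustrative, not part of this module): one way the main
# Flask app might wire these pieces together. The route and handler below are
# assumptions; in a separate app module they would be imported from this file
# (e.g. "from rate_limiter import rate_limit, ip_filter_check" -- import path assumed).
from flask import Flask

app = Flask(__name__)

# Reject filtered IPs before any route handler runs.
app.before_request(ip_filter_check)

@app.route('/translate', methods=['POST'])
@rate_limit(check_size=True)  # uses the '/translate' limits configured above
def translate():
    # ... perform the translation ...
    return jsonify({'status': 'ok'})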