This commit introduces major enhancements to Talk2Me:

## Database Integration
- PostgreSQL support with SQLAlchemy ORM
- Redis integration for caching and real-time analytics
- Automated database initialization scripts
- Migration support infrastructure

## User Authentication System
- JWT-based API authentication
- Session-based web authentication
- API key authentication for programmatic access
- User roles and permissions (admin/user)
- Login history and session tracking
- Rate limiting per user with customizable limits

## Admin Dashboard
- Real-time analytics and monitoring
- User management interface (create, edit, delete users)
- System health monitoring
- Request/error tracking
- Language pair usage statistics
- Performance metrics visualization

## Key Features
- Dual authentication support (token + user accounts)
- Graceful fallback for missing services
- Non-blocking analytics middleware
- Comprehensive error handling
- Session management with security features

## Bug Fixes
- Fixed rate limiting bypass for admin routes
- Added missing email validation method
- Improved error handling for missing database tables
- Fixed session-based authentication for API endpoints

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
# Rate limiting implementation for Flask

import time
import logging
import hashlib
from functools import wraps
from collections import defaultdict, deque
from threading import Lock

from flask import request, jsonify, g

logger = logging.getLogger(__name__)


class RateLimiter:
    """
    Token bucket rate limiter with sliding window and multiple strategies
    """

    def __init__(self):
        self.buckets = defaultdict(lambda: {
            'tokens': 5,  # Start with some tokens to avoid immediate burst errors
            'last_update': time.time(),
            'requests': deque(maxlen=1000)  # Track last 1000 requests
        })
        self.lock = Lock()

        # Default limits (can be overridden per endpoint)
        self.default_limits = {
            'requests_per_minute': 30,
            'requests_per_hour': 500,
            'burst_size': 10,
            'token_refresh_rate': 0.5  # tokens per second
        }

        # Endpoint-specific limits
        self.endpoint_limits = {
            '/transcribe': {
                'requests_per_minute': 10,
                'requests_per_hour': 100,
                'burst_size': 3,
                'token_refresh_rate': 0.167,  # 1 token per 6 seconds
                'max_request_size': 10 * 1024 * 1024  # 10MB
            },
            '/translate': {
                'requests_per_minute': 20,
                'requests_per_hour': 300,
                'burst_size': 5,
                'token_refresh_rate': 0.333,  # 1 token per 3 seconds
                'max_request_size': 100 * 1024  # 100KB
            },
            '/translate/stream': {
                'requests_per_minute': 10,
                'requests_per_hour': 150,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/speak': {
                'requests_per_minute': 15,
                'requests_per_hour': 200,
                'burst_size': 3,
                'token_refresh_rate': 0.25,  # 1 token per 4 seconds
                'max_request_size': 50 * 1024  # 50KB
            }
        }

        # IP-based blocking
        self.blocked_ips = set()
        self.temp_blocked_ips = {}  # IP -> unblock_time

        # Global limits
        self.global_limits = {
            'total_requests_per_minute': 1000,
            'total_requests_per_hour': 10000,
            'concurrent_requests': 50
        }
        self.global_requests = deque(maxlen=10000)
        self.concurrent_requests = 0

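    # Refill arithmetic, worked through with the default numbers above:
    # at token_refresh_rate=0.5 and burst_size=10, a client idle for 8
    # seconds regains 8 * 0.5 = 4 tokens, capped at 10. Each allowed
    # request consumes exactly 1 token, so a full bucket permits a burst
    # of 10 back-to-back requests before throttling to the refill rate.
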
    def get_client_id(self, request):
        """Get unique client identifier"""
        # Use IP address + user agent for better identification
        ip = request.remote_addr or 'unknown'
        user_agent = request.headers.get('User-Agent', '')

        # Handle proxied requests
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            ip = forwarded_for.split(',')[0].strip()

        # Create unique identifier
        identifier = f"{ip}:{user_agent}"
        return hashlib.md5(identifier.encode()).hexdigest()

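    # Illustrative header value (an assumption, not from the original):
    # a request behind a load balancer might carry
    #   X-Forwarded-For: 203.0.113.7, 10.0.0.1
    # in which case the client IP used above is "203.0.113.7", the
    # left-most entry. Note this header is client-controlled unless it is
    # set by a trusted proxy; MD5 here serves only as a compact bucket
    # key, not for any security purpose.
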
    def get_limits(self, endpoint):
        """Get rate limits for endpoint"""
        return self.endpoint_limits.get(endpoint, self.default_limits)

    def is_ip_blocked(self, ip):
        """Check if IP is blocked"""
        # Check permanent blocks
        if ip in self.blocked_ips:
            return True

        # Check temporary blocks
        if ip in self.temp_blocked_ips:
            if time.time() < self.temp_blocked_ips[ip]:
                return True
            else:
                # Unblock if time expired
                del self.temp_blocked_ips[ip]

        return False

    def block_ip_temporarily(self, ip, duration=3600):
        """Block IP temporarily (default 1 hour)"""
        self.temp_blocked_ips[ip] = time.time() + duration
        logger.warning(f"IP {ip} temporarily blocked for {duration} seconds")

    def check_global_limits(self):
        """Check global rate limits"""
        now = time.time()

        # Clean old requests
        minute_ago = now - 60
        hour_ago = now - 3600

        self.global_requests = deque(
            (t for t in self.global_requests if t > hour_ago),
            maxlen=10000
        )

        # Count requests
        requests_last_minute = sum(1 for t in self.global_requests if t > minute_ago)
        requests_last_hour = len(self.global_requests)

        # Check limits
        if requests_last_minute >= self.global_limits['total_requests_per_minute']:
            return False, "Global rate limit exceeded (per minute)"

        if requests_last_hour >= self.global_limits['total_requests_per_hour']:
            return False, "Global rate limit exceeded (per hour)"

        if self.concurrent_requests >= self.global_limits['concurrent_requests']:
            return False, "Too many concurrent requests"

        return True, None

    def is_exempt_path(self, path):
        """Check if path is exempt from rate limiting"""
        # Handle both path strings and endpoint names
        if path.startswith('admin.'):
            return True

        exempt_paths = ['/admin', '/health', '/static']
        return any(path.startswith(p) for p in exempt_paths)

    def check_rate_limit(self, client_id, endpoint, request_size=0):
        """Check if request should be allowed"""
        # Log what endpoint we're checking
        logger.debug(f"Checking rate limit for endpoint: {endpoint}")

        # Skip rate limiting for exempt paths before any processing
        if self.is_exempt_path(endpoint):
            logger.debug(f"Endpoint {endpoint} is exempt from rate limiting")
            return True, None, None

        with self.lock:
            # Check global limits first
            global_ok, global_msg = self.check_global_limits()
            if not global_ok:
                return False, global_msg, None

            # Get limits for endpoint
            limits = self.get_limits(endpoint)

            # Check request size if applicable
            if request_size > 0 and 'max_request_size' in limits:
                if request_size > limits['max_request_size']:
                    return False, "Request too large", None

            # Get or create bucket
            bucket = self.buckets[client_id]
            now = time.time()

            # Update tokens based on time passed
            time_passed = now - bucket['last_update']
            new_tokens = time_passed * limits['token_refresh_rate']
            bucket['tokens'] = min(
                limits['burst_size'],
                bucket['tokens'] + new_tokens
            )
            bucket['last_update'] = now

            # Clean old requests from sliding window
            minute_ago = now - 60
            hour_ago = now - 3600
            bucket['requests'] = deque(
                (t for t in bucket['requests'] if t > hour_ago),
                maxlen=1000
            )

            # Count requests in windows
            requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
            requests_last_hour = len(bucket['requests'])

            # Check sliding window limits
            if requests_last_minute >= limits['requests_per_minute']:
                return False, "Rate limit exceeded (per minute)", {
                    'retry_after': 60,
                    'limit': limits['requests_per_minute'],
                    'remaining': 0,
                    # A sliding window has no fixed reset; report a
                    # conservative bound one window ahead (the previous
                    # minute_ago + 60 evaluated to "now", contradicting
                    # the 429 response)
                    'reset': int(now + 60)
                }

            if requests_last_hour >= limits['requests_per_hour']:
                return False, "Rate limit exceeded (per hour)", {
                    'retry_after': 3600,
                    'limit': limits['requests_per_hour'],
                    'remaining': 0,
                    'reset': int(now + 3600)
                }

            # Check token bucket
            if bucket['tokens'] < 1:
                retry_after = int(1 / limits['token_refresh_rate'])
                return False, "Rate limit exceeded (burst)", {
                    'retry_after': retry_after,
                    'limit': limits['burst_size'],
                    'remaining': 0,
                    'reset': int(now + retry_after)
                }

            # Request allowed - consume token and record
            bucket['tokens'] -= 1
            bucket['requests'].append(now)
            self.global_requests.append(now)

            # Calculate remaining
            remaining_minute = limits['requests_per_minute'] - requests_last_minute - 1
            remaining_hour = limits['requests_per_hour'] - requests_last_hour - 1

            return True, None, {
                'limit': limits['requests_per_minute'],
                'remaining': remaining_minute,
                'reset': int(now + 60)
            }

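    # The third element of the returned tuple feeds the X-RateLimit-*
    # response headers set by the decorator below. An allowed call against
    # the default limits might yield (illustrative values only):
    #   {'limit': 30, 'remaining': 24, 'reset': 1700000060}
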
    def increment_concurrent(self):
        """Increment concurrent request counter"""
        with self.lock:
            self.concurrent_requests += 1

    def decrement_concurrent(self):
        """Decrement concurrent request counter"""
        with self.lock:
            self.concurrent_requests = max(0, self.concurrent_requests - 1)

    def get_client_stats(self, client_id):
        """Get statistics for a client"""
        with self.lock:
            if client_id not in self.buckets:
                return None

            bucket = self.buckets[client_id]
            now = time.time()
            minute_ago = now - 60
            hour_ago = now - 3600

            requests_last_minute = sum(1 for t in bucket['requests'] if t > minute_ago)
            requests_last_hour = sum(1 for t in bucket['requests'] if t > hour_ago)

            return {
                'requests_last_minute': requests_last_minute,
                'requests_last_hour': requests_last_hour,
                'tokens_available': bucket['tokens'],
                'last_request': bucket['last_update']
            }

    def cleanup_old_buckets(self, max_age=86400):
        """Clean up old unused buckets (default 24 hours)"""
        with self.lock:
            now = time.time()
            to_remove = []

            for client_id, bucket in self.buckets.items():
                if now - bucket['last_update'] > max_age:
                    to_remove.append(client_id)

            for client_id in to_remove:
                del self.buckets[client_id]

            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} old rate limit buckets")


# Global rate limiter instance
rate_limiter = RateLimiter()


def rate_limit(endpoint=None,
               requests_per_minute=None,
               requests_per_hour=None,
               burst_size=None,
               check_size=False):
    """
    Rate limiting decorator for Flask routes

    Usage:
        @app.route('/api/endpoint')
        @rate_limit(requests_per_minute=10, check_size=True)
        def endpoint():
            return jsonify({'status': 'ok'})
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Skip rate limiting for admin routes - check both path and endpoint
            if request.path.startswith('/admin'):
                return f(*args, **kwargs)

            # Also check endpoint name
            if request.endpoint and request.endpoint.startswith('admin.'):
                return f(*args, **kwargs)

            # Get client ID
            client_id = rate_limiter.get_client_id(request)
            ip = request.remote_addr

            # Check if IP is blocked
            if rate_limiter.is_ip_blocked(ip):
                return jsonify({
                    'error': 'IP temporarily blocked due to excessive requests'
                }), 429

            # Use the request path (not the Flask endpoint name) so lookups
            # match the path keys in endpoint_limits ('/transcribe', etc.)
            endpoint_path = endpoint or request.path

            # Override default limits if specified
            if any([requests_per_minute, requests_per_hour, burst_size]):
                limits = rate_limiter.get_limits(endpoint_path).copy()
                if requests_per_minute:
                    limits['requests_per_minute'] = requests_per_minute
                if requests_per_hour:
                    limits['requests_per_hour'] = requests_per_hour
                if burst_size:
                    limits['burst_size'] = burst_size
                rate_limiter.endpoint_limits[endpoint_path] = limits

            # Check request size if needed
            request_size = 0
            if check_size:
                request_size = request.content_length or 0

            # Check rate limit
            allowed, message, headers = rate_limiter.check_rate_limit(
                client_id, endpoint_path, request_size
            )

            if not allowed:
                # Log excessive requests
                logger.warning(f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}")

                # Check if we should temporarily block this IP
                stats = rate_limiter.get_client_stats(client_id)
                if stats and stats['requests_last_minute'] > 100:
                    rate_limiter.block_ip_temporarily(ip, 3600)  # 1 hour block

                response = jsonify({
                    'error': message,
                    'retry_after': headers.get('retry_after') if headers else 60
                })
                response.status_code = 429

                # Add rate limit headers
                if headers:
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])
                    response.headers['Retry-After'] = str(headers['retry_after'])

                return response

            # Track concurrent requests
            rate_limiter.increment_concurrent()

            try:
                # Add rate limit info to response
                g.rate_limit_headers = headers
                response = f(*args, **kwargs)

                # Add headers to successful response
                if headers and hasattr(response, 'headers'):
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])

                return response
            finally:
                rate_limiter.decrement_concurrent()

        return decorated_function
    return decorator


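# Note on overrides: passing keyword limits to @rate_limit writes the merged
# values into the shared endpoint_limits table at request time, so subsequent
# requests to that path see the overridden limits. A stricter route might
# look like this (hypothetical route name, for illustration only):
#
#   @app.route('/upload', methods=['POST'])
#   @rate_limit(requests_per_minute=5, burst_size=2, check_size=True)
#   def upload():
#       return jsonify({'status': 'ok'})

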
def cleanup_rate_limiter():
    """Cleanup function to be called periodically"""
    rate_limiter.cleanup_old_buckets()


# IP whitelist/blacklist management
class IPFilter:
    def __init__(self):
        self.whitelist = set()
        self.blacklist = set()

    def add_to_whitelist(self, ip):
        self.whitelist.add(ip)
        self.blacklist.discard(ip)

    def add_to_blacklist(self, ip):
        self.blacklist.add(ip)
        self.whitelist.discard(ip)

    def is_allowed(self, ip):
        if ip in self.blacklist:
            return False
        if self.whitelist and ip not in self.whitelist:
            return False
        return True


ip_filter = IPFilter()

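# Semantics sketch: a blacklisted IP is always denied, and a non-empty
# whitelist switches the filter to allow-list mode. For example
# (illustrative IPs, not from the original):
#   ip_filter.add_to_blacklist('198.51.100.9')   # always denied
#   ip_filter.add_to_whitelist('203.0.113.7')    # now ONLY this IP passes
#   ip_filter.is_allowed('192.0.2.1')            # -> False (not whitelisted)
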
def ip_filter_check():
    """Middleware to check IP filtering"""
    # Skip IP filtering for admin routes
    if request.path.startswith('/admin'):
        return None

    ip = request.remote_addr
    if not ip_filter.is_allowed(ip):
        return jsonify({'error': 'Access denied'}), 403
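

# ---------------------------------------------------------------------------
# Wiring sketch (not part of the original module; `app` and the route below
# are hypothetical names for illustration): register the IP filter as a
# before-request hook, decorate routes, and run cleanup periodically.
#
#   from flask import Flask
#   app = Flask(__name__)
#   app.before_request(ip_filter_check)
#
#   @app.route('/translate', methods=['POST'])
#   @rate_limit(check_size=True)
#   def translate():
#       return jsonify({'status': 'ok'})
#
#   # e.g. from a background thread or scheduler, once per hour:
#   #   cleanup_rate_limiter()
# ---------------------------------------------------------------------------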