This commit introduces major enhancements to Talk2Me:

## Database Integration
- PostgreSQL support with SQLAlchemy ORM
- Redis integration for caching and real-time analytics
- Automated database initialization scripts
- Migration support infrastructure

## User Authentication System
- JWT-based API authentication
- Session-based web authentication
- API key authentication for programmatic access
- User roles and permissions (admin/user)
- Login history and session tracking
- Rate limiting per user with customizable limits

## Admin Dashboard
- Real-time analytics and monitoring
- User management interface (create, edit, delete users)
- System health monitoring
- Request/error tracking
- Language pair usage statistics
- Performance metrics visualization

## Key Features
- Dual authentication support (token + user accounts)
- Graceful fallback for missing services
- Non-blocking analytics middleware
- Comprehensive error handling
- Session management with security features

## Bug Fixes
- Fixed rate limiting bypass for admin routes
- Added missing email validation method
- Improved error handling for missing database tables
- Fixed session-based authentication for API endpoints

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
365 lines
13 KiB
Python
365 lines
13 KiB
Python
# Redis-based rate limiting implementation
|
|
import time
|
|
import logging
|
|
from functools import wraps
|
|
from flask import request, jsonify, g
|
|
import hashlib
|
|
from typing import Optional, Dict, Tuple
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class RedisRateLimiter:
    """Token bucket rate limiter using Redis for distributed rate limiting.

    Each request is checked, in this order, against:

    1. global per-minute / per-hour request counters and a global cap on
       concurrent requests (``check_global_limits``);
    2. the endpoint's ``max_request_size`` (bytes), when one is configured;
    3. per-client, per-endpoint sliding windows (per minute and per hour);
    4. a per-client, per-endpoint token bucket that bounds bursts.

    ``redis_manager`` is the project's Redis wrapper. This class calls its
    ``get``, ``set(..., expire=...)``, ``exists``, ``sismember``, ``incr``,
    ``decr``, ``hgetall``, ``hset``, ``expire``, ``scard`` and
    ``check_rate_limit(key, limit, window_seconds)`` methods; the last one is
    unpacked as ``(allowed, remaining)``.
    """

    def __init__(self, redis_manager):
        """Store the Redis wrapper and build the default limit tables."""
        self.redis = redis_manager

        # Default limits (used for any endpoint not listed in endpoint_limits)
        self.default_limits = {
            'requests_per_minute': 30,
            'requests_per_hour': 500,
            'burst_size': 10,
            'token_refresh_rate': 0.5  # tokens per second
        }

        # Endpoint-specific limits, keyed by URL path.
        # 'max_request_size' is compared against the request size in bytes.
        self.endpoint_limits = {
            '/transcribe': {
                'requests_per_minute': 10,
                'requests_per_hour': 100,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 10 * 1024 * 1024  # 10MB
            },
            '/translate': {
                'requests_per_minute': 20,
                'requests_per_hour': 300,
                'burst_size': 5,
                'token_refresh_rate': 0.333,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/translate/stream': {
                'requests_per_minute': 10,
                'requests_per_hour': 150,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/speak': {
                'requests_per_minute': 15,
                'requests_per_hour': 200,
                'burst_size': 3,
                'token_refresh_rate': 0.25,
                'max_request_size': 50 * 1024  # 50KB
            }
        }

        # Global limits shared by every client and endpoint
        self.global_limits = {
            'total_requests_per_minute': 1000,
            'total_requests_per_hour': 10000,
            'concurrent_requests': 50
        }

    def get_client_id(self, req) -> str:
        """Return an opaque identifier for the client that sent ``req``.

        The identifier is the MD5 hex digest of ``"<ip>:<user agent>"``. The IP
        is the first X-Forwarded-For entry when that header is present, else
        ``req.remote_addr`` (or ``'unknown'``).

        NOTE(review): X-Forwarded-For and User-Agent are client-supplied unless
        a trusted proxy rewrites them, so a client varying either one gets
        fresh buckets. Confirm the deployment sanitises X-Forwarded-For.
        """
        ip = req.remote_addr or 'unknown'
        user_agent = req.headers.get('User-Agent', '')

        # Handle proxied requests: by convention the left-most entry is the
        # originating client.
        forwarded_for = req.headers.get('X-Forwarded-For')
        if forwarded_for:
            ip = forwarded_for.split(',')[0].strip()

        identifier = f"{ip}:{user_agent}"
        return hashlib.md5(identifier.encode()).hexdigest()

    def get_limits(self, endpoint: str) -> Dict:
        """Return the limits dict for ``endpoint`` (the shared dict, not a copy)."""
        return self.endpoint_limits.get(endpoint, self.default_limits)

    def is_ip_blocked(self, ip: str) -> bool:
        """Return True if ``ip`` is permanently or temporarily blocked."""
        # Permanent blocks live in a Redis set
        if self.redis.sismember('blocked_ips:permanent', ip):
            return True

        # Temporary blocks are individual keys with an expiry
        block_key = f'blocked_ip:{ip}'
        if self.redis.exists(block_key):
            return True

        return False

    def block_ip_temporarily(self, ip: str, duration: int = 3600):
        """Block ``ip`` for ``duration`` seconds via an expiring Redis key."""
        block_key = f'blocked_ip:{ip}'
        self.redis.set(block_key, 1, expire=duration)
        logger.warning(f"IP {ip} temporarily blocked for {duration} seconds")

    def check_global_limits(self) -> Tuple[bool, Optional[str]]:
        """Check the limits shared by all clients.

        Returns:
            ``(True, None)`` when the request may proceed, otherwise
            ``(False, reason)``. The per-minute and per-hour checks go through
            the wrapper's ``check_rate_limit``, which presumably also records
            this request in those windows.
        """
        allowed, _ = self.redis.check_rate_limit(
            'global:requests:minute',
            self.global_limits['total_requests_per_minute'],
            60
        )
        if not allowed:
            return False, "Global rate limit exceeded (per minute)"

        allowed, _ = self.redis.check_rate_limit(
            'global:requests:hour',
            self.global_limits['total_requests_per_hour'],
            3600
        )
        if not allowed:
            return False, "Global rate limit exceeded (per hour)"

        # Coerce to int: depending on the wrapper the counter may come back as
        # str/bytes, and comparing that with an int would raise TypeError.
        current_concurrent = int(self.redis.get('global:concurrent', 0) or 0)
        if current_concurrent >= self.global_limits['concurrent_requests']:
            return False, "Too many concurrent requests"

        return True, None

    @staticmethod
    def _current_tokens(bucket_data, limits: Dict, now: float) -> float:
        """Return the bucket's token count refilled up to ``now``."""
        if not bucket_data:
            # A new (or expired) bucket starts full.
            return float(limits['burst_size'])

        last_update = float(bucket_data.get('last_update', now))
        current_tokens = float(bucket_data.get('tokens', 0))
        refill = (now - last_update) * limits['token_refresh_rate']
        return float(min(limits['burst_size'], current_tokens + refill))

    def check_rate_limit(self, client_id: str, endpoint: str,
                         request_size: int = 0) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Decide whether ``client_id`` may make a request to ``endpoint``.

        Args:
            client_id: Identifier, e.g. from ``get_client_id``.
            endpoint: Key into ``endpoint_limits``; unknown keys fall back to
                ``default_limits``.
            request_size: Request size in bytes; 0 skips the size check.

        Returns:
            ``(allowed, message, info)``. ``message`` is None when allowed.
            ``info`` is None for global-limit and size rejections. Otherwise it
            holds ``limit``, ``remaining`` and ``reset`` (epoch seconds), plus
            ``retry_after`` (seconds) on rejection.
        """
        import math  # local import keeps the module-level imports unchanged

        # Check global limits first
        global_ok, global_msg = self.check_global_limits()
        if not global_ok:
            return False, global_msg, None

        limits = self.get_limits(endpoint)

        max_size = limits.get('max_request_size')
        if request_size > 0 and max_size is not None and request_size > max_size:
            return False, "Request too large", None

        bucket_key = f'bucket:{client_id}:{endpoint}'
        now = time.time()
        # The bucket is read here and written back below in separate Redis
        # calls, so concurrent requests from one client can race on it
        # (unless the wrapper serialises them).
        tokens = self._current_tokens(self.redis.hgetall(bucket_key), limits, now)

        # Sliding windows are checked before the bucket, so a request that is
        # then rejected for bursting has presumably already been counted here.
        minute_allowed, minute_remaining = self.redis.check_rate_limit(
            f'window:{client_id}:{endpoint}:minute',
            limits['requests_per_minute'],
            60
        )
        if not minute_allowed:
            return False, "Rate limit exceeded (per minute)", {
                'retry_after': 60,
                'limit': limits['requests_per_minute'],
                'remaining': 0,
                'reset': int(now + 60)
            }

        hour_allowed, _ = self.redis.check_rate_limit(
            f'window:{client_id}:{endpoint}:hour',
            limits['requests_per_hour'],
            3600
        )
        if not hour_allowed:
            return False, "Rate limit exceeded (per hour)", {
                'retry_after': 3600,
                'limit': limits['requests_per_hour'],
                'remaining': 0,
                'reset': int(now + 3600)
            }

        if tokens < 1:
            # Seconds until the bucket holds one whole token again, rounded
            # up so a client honouring Retry-After is not rejected once more.
            retry_after = math.ceil((1 - tokens) / limits['token_refresh_rate'])
            return False, "Rate limit exceeded (burst)", {
                'retry_after': retry_after,
                'limit': limits['burst_size'],
                'remaining': 0,
                'reset': int(now + retry_after)
            }

        # Request allowed - spend one token and persist the bucket
        tokens -= 1
        self.redis.hset(bucket_key, 'tokens', tokens)
        self.redis.hset(bucket_key, 'last_update', now)
        self.redis.expire(bucket_key, 86400)  # Expire after 24 hours

        return True, None, {
            'limit': limits['requests_per_minute'],
            'remaining': minute_remaining,
            'reset': int(now + 60)
        }

    def increment_concurrent(self):
        """Increment the global concurrent-request counter."""
        self.redis.incr('global:concurrent')

    def decrement_concurrent(self):
        """Decrement the global concurrent-request counter."""
        self.redis.decr('global:concurrent')

    def get_client_stats(self, client_id: str) -> Optional[Dict]:
        """Summarise ``client_id``'s usage across the endpoints in ``endpoint_limits``.

        Returns a dict with summed ``requests_last_minute`` /
        ``requests_last_hour`` and per-endpoint ``buckets`` state. Only keys
        present in ``endpoint_limits`` are inspected.

        NOTE(review): the window keys are counted with SCARD; confirm that the
        wrapper's ``check_rate_limit`` stores them as plain Redis sets.
        """
        stats = {
            'requests_last_minute': 0,
            'requests_last_hour': 0,
            'buckets': {}
        }

        for endpoint in self.endpoint_limits:
            minute_key = f'window:{client_id}:{endpoint}:minute'
            hour_key = f'window:{client_id}:{endpoint}:hour'

            # This is approximate since we're using sliding windows
            stats['requests_last_minute'] += self.redis.scard(minute_key)
            stats['requests_last_hour'] += self.redis.scard(hour_key)

            bucket_data = self.redis.hgetall(f'bucket:{client_id}:{endpoint}')
            if bucket_data:
                stats['buckets'][endpoint] = {
                    'tokens': float(bucket_data.get('tokens', 0)),
                    'last_update': float(bucket_data.get('last_update', 0))
                }

        return stats
def rate_limit(endpoint=None,
               requests_per_minute=None,
               requests_per_hour=None,
               burst_size=None,
               check_size=False):
    """
    Rate limiting decorator for Flask routes using Redis.

    Args:
        endpoint: Key used to look up limits and to name the Redis buckets.
            When omitted, the matched URL rule (e.g. ``'/transcribe'``) is
            used, falling back to ``request.endpoint`` if no rule matched.
        requests_per_minute: Optional override of the per-minute window limit.
        requests_per_hour: Optional override of the per-hour window limit.
        burst_size: Optional override of the token-bucket capacity.
        check_size: When true, ``request.content_length`` is checked against
            the endpoint's ``max_request_size``.

    If the current app has no usable ``redis_rate_limiter``, the view runs
    unthrottled. Rejected requests get a JSON 429 response; allowed requests
    get X-RateLimit-* headers when the view's return value has ``headers``.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Get Redis rate limiter from app
            from flask import current_app

            rate_limiter = getattr(current_app, 'redis_rate_limiter', None)
            if rate_limiter is None:
                # No Redis rate limiter, execute function normally
                return f(*args, **kwargs)

            client_id = rate_limiter.get_client_id(request)
            # NOTE(review): blocking is keyed on remote_addr, while
            # get_client_id prefers X-Forwarded-For. Behind a reverse proxy
            # remote_addr is the proxy's address, so a block would affect
            # every client behind it - confirm the deployment.
            ip = request.remote_addr

            if rate_limiter.is_ip_blocked(ip):
                return jsonify({
                    'error': 'IP temporarily blocked due to excessive requests'
                }), 429

            # request.endpoint is the view's endpoint *name* (e.g.
            # 'transcribe'), but RedisRateLimiter.endpoint_limits is keyed by
            # URL path (e.g. '/transcribe'). Use the matched URL rule so the
            # configured per-endpoint limits apply when no key is given.
            url_rule = getattr(request, 'url_rule', None)
            endpoint_path = endpoint or (
                url_rule.rule if url_rule is not None else request.endpoint
            )

            # Override default limits if specified. This writes into the
            # shared limiter, so the override stays in effect for this key.
            if any([requests_per_minute, requests_per_hour, burst_size]):
                limits = rate_limiter.get_limits(endpoint_path).copy()
                if requests_per_minute:
                    limits['requests_per_minute'] = requests_per_minute
                if requests_per_hour:
                    limits['requests_per_hour'] = requests_per_hour
                if burst_size:
                    limits['burst_size'] = burst_size
                rate_limiter.endpoint_limits[endpoint_path] = limits

            request_size = (request.content_length or 0) if check_size else 0

            allowed, message, headers = rate_limiter.check_rate_limit(
                client_id, endpoint_path, request_size
            )

            if not allowed:
                logger.warning(f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}")

                # Escalate to a one-hour IP block for clients with more than
                # 100 requests counted in the last minute.
                stats = rate_limiter.get_client_stats(client_id)
                if stats and stats['requests_last_minute'] > 100:
                    rate_limiter.block_ip_temporarily(ip, 3600)  # 1 hour block

                response = jsonify({
                    'error': message,
                    'retry_after': headers.get('retry_after') if headers else 60
                })
                response.status_code = 429

                # Add rate limit headers
                if headers:
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])
                    response.headers['Retry-After'] = str(headers['retry_after'])

                return response

            # Track concurrent requests; the finally guarantees the decrement
            # even when the view raises.
            rate_limiter.increment_concurrent()
            try:
                # Stored on flask.g for later code in this request (presumably
                # an after_request hook - not visible here).
                g.rate_limit_headers = headers
                response = f(*args, **kwargs)

                # Only Response-like return values can carry headers; tuples
                # or plain strings returned by the view pass through unchanged.
                if headers and hasattr(response, 'headers'):
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])

                return response
            finally:
                rate_limiter.decrement_concurrent()

        return decorated_function

    return decorator