talk2me/redis_rate_limiter.py
Commit fa951c3141 (Adolfo Delorenzo): Add comprehensive database integration, authentication, and admin dashboard
This commit introduces major enhancements to Talk2Me:

## Database Integration
- PostgreSQL support with SQLAlchemy ORM
- Redis integration for caching and real-time analytics
- Automated database initialization scripts
- Migration support infrastructure

## User Authentication System
- JWT-based API authentication
- Session-based web authentication
- API key authentication for programmatic access
- User roles and permissions (admin/user)
- Login history and session tracking
- Rate limiting per user with customizable limits

## Admin Dashboard
- Real-time analytics and monitoring
- User management interface (create, edit, delete users)
- System health monitoring
- Request/error tracking
- Language pair usage statistics
- Performance metrics visualization

## Key Features
- Dual authentication support (token + user accounts)
- Graceful fallback for missing services
- Non-blocking analytics middleware
- Comprehensive error handling
- Session management with security features

## Bug Fixes
- Fixed rate limiting bypass for admin routes
- Added missing email validation method
- Improved error handling for missing database tables
- Fixed session-based authentication for API endpoints

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

# Redis-based rate limiting implementation
import time
import logging
import hashlib
from functools import wraps
from typing import Optional, Dict, Tuple

from flask import request, jsonify, g

logger = logging.getLogger(__name__)


class RedisRateLimiter:
    """Token bucket rate limiter using Redis for distributed rate limiting"""

    def __init__(self, redis_manager):
        self.redis = redis_manager

        # Default limits (can be overridden per endpoint)
        self.default_limits = {
            'requests_per_minute': 30,
            'requests_per_hour': 500,
            'burst_size': 10,
            'token_refresh_rate': 0.5  # tokens per second
        }

        # Endpoint-specific limits
        self.endpoint_limits = {
            '/transcribe': {
                'requests_per_minute': 10,
                'requests_per_hour': 100,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 10 * 1024 * 1024  # 10MB
            },
            '/translate': {
                'requests_per_minute': 20,
                'requests_per_hour': 300,
                'burst_size': 5,
                'token_refresh_rate': 0.333,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/translate/stream': {
                'requests_per_minute': 10,
                'requests_per_hour': 150,
                'burst_size': 3,
                'token_refresh_rate': 0.167,
                'max_request_size': 100 * 1024  # 100KB
            },
            '/speak': {
                'requests_per_minute': 15,
                'requests_per_hour': 200,
                'burst_size': 3,
                'token_refresh_rate': 0.25,
                'max_request_size': 50 * 1024  # 50KB
            }
        }

        # Global limits
        self.global_limits = {
            'total_requests_per_minute': 1000,
            'total_requests_per_hour': 10000,
            'concurrent_requests': 50
        }
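
    # Note on the numbers above: token_refresh_rate is expressed in tokens per
    # second, so 0.167 tokens/s refills roughly 10 tokens per minute and
    # 0.333 tokens/s roughly 20 per minute, matching each endpoint's
    # requests_per_minute value (0.25 -> ~15/min, 0.5 -> ~30/min).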

    def get_client_id(self, req) -> str:
        """Get unique client identifier"""
        ip = req.remote_addr or 'unknown'
        user_agent = req.headers.get('User-Agent', '')

        # Handle proxied requests
        forwarded_for = req.headers.get('X-Forwarded-For')
        if forwarded_for:
            ip = forwarded_for.split(',')[0].strip()

        # Create unique identifier
        identifier = f"{ip}:{user_agent}"
        return hashlib.md5(identifier.encode()).hexdigest()
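
    # MD5 here only derives a compact, fixed-length Redis key from the
    # IP + User-Agent pair; it is not relied on for any security property.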

    def get_limits(self, endpoint: str) -> Dict:
        """Get rate limits for endpoint"""
        return self.endpoint_limits.get(endpoint, self.default_limits)

    def is_ip_blocked(self, ip: str) -> bool:
        """Check if IP is blocked"""
        # Check permanent blocks
        if self.redis.sismember('blocked_ips:permanent', ip):
            return True

        # Check temporary blocks
        block_key = f'blocked_ip:{ip}'
        if self.redis.exists(block_key):
            return True

        return False

    def block_ip_temporarily(self, ip: str, duration: int = 3600):
        """Block IP temporarily"""
        block_key = f'blocked_ip:{ip}'
        self.redis.set(block_key, 1, expire=duration)
        logger.warning(f"IP {ip} temporarily blocked for {duration} seconds")

    def check_global_limits(self) -> Tuple[bool, Optional[str]]:
        """Check global rate limits"""
        # Check requests per minute
        minute_key = 'global:requests:minute'
        allowed, remaining = self.redis.check_rate_limit(
            minute_key,
            self.global_limits['total_requests_per_minute'],
            60
        )
        if not allowed:
            return False, "Global rate limit exceeded (per minute)"

        # Check requests per hour
        hour_key = 'global:requests:hour'
        allowed, remaining = self.redis.check_rate_limit(
            hour_key,
            self.global_limits['total_requests_per_hour'],
            3600
        )
        if not allowed:
            return False, "Global rate limit exceeded (per hour)"

        # Check concurrent requests
        concurrent_key = 'global:concurrent'
        current_concurrent = self.redis.get(concurrent_key, 0)
        if current_concurrent >= self.global_limits['concurrent_requests']:
            return False, "Too many concurrent requests"

        return True, None

    def check_rate_limit(self, client_id: str, endpoint: str,
                         request_size: int = 0) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Check if request should be allowed"""
        # Check global limits first
        global_ok, global_msg = self.check_global_limits()
        if not global_ok:
            return False, global_msg, None

        # Get limits for endpoint
        limits = self.get_limits(endpoint)

        # Check request size if applicable
        if request_size > 0 and 'max_request_size' in limits:
            if request_size > limits['max_request_size']:
                return False, "Request too large", None

        # Token bucket implementation using Redis
        bucket_key = f'bucket:{client_id}:{endpoint}'
        now = time.time()

        # Get current bucket state
        bucket_data = self.redis.hgetall(bucket_key)

        # Initialize bucket if empty
        if not bucket_data:
            bucket_data = {
                'tokens': limits['burst_size'],
                'last_update': now
            }
        else:
            # Update tokens based on time passed
            last_update = float(bucket_data.get('last_update', now))
            time_passed = now - last_update
            new_tokens = time_passed * limits['token_refresh_rate']
            current_tokens = float(bucket_data.get('tokens', 0))
            bucket_data['tokens'] = min(
                limits['burst_size'],
                current_tokens + new_tokens
            )
            bucket_data['last_update'] = now

        # Check sliding window limits
        minute_allowed, minute_remaining = self.redis.check_rate_limit(
            f'window:{client_id}:{endpoint}:minute',
            limits['requests_per_minute'],
            60
        )
        if not minute_allowed:
            return False, "Rate limit exceeded (per minute)", {
                'retry_after': 60,
                'limit': limits['requests_per_minute'],
                'remaining': 0,
                'reset': int(now + 60)
            }

        hour_allowed, hour_remaining = self.redis.check_rate_limit(
            f'window:{client_id}:{endpoint}:hour',
            limits['requests_per_hour'],
            3600
        )
        if not hour_allowed:
            return False, "Rate limit exceeded (per hour)", {
                'retry_after': 3600,
                'limit': limits['requests_per_hour'],
                'remaining': 0,
                'reset': int(now + 3600)
            }

        # Check token bucket
        if float(bucket_data['tokens']) < 1:
            retry_after = int(1 / limits['token_refresh_rate'])
            return False, "Rate limit exceeded (burst)", {
                'retry_after': retry_after,
                'limit': limits['burst_size'],
                'remaining': 0,
                'reset': int(now + retry_after)
            }

        # Request allowed - update bucket
        bucket_data['tokens'] = float(bucket_data['tokens']) - 1

        # Save bucket state
        self.redis.hset(bucket_key, 'tokens', bucket_data['tokens'])
        self.redis.hset(bucket_key, 'last_update', bucket_data['last_update'])
        self.redis.expire(bucket_key, 86400)  # Expire after 24 hours

        return True, None, {
            'limit': limits['requests_per_minute'],
            'remaining': minute_remaining,
            'reset': int(now + 60)
        }
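
    # check_rate_limit applies the stricter of two mechanisms: the sliding
    # window counters cap sustained traffic per minute and per hour, while the
    # token bucket (burst_size tokens, refilled at token_refresh_rate per
    # second) throttles short bursts even when the windows still have headroom.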

    def increment_concurrent(self):
        """Increment concurrent request counter"""
        self.redis.incr('global:concurrent')

    def decrement_concurrent(self):
        """Decrement concurrent request counter"""
        self.redis.decr('global:concurrent')

    def get_client_stats(self, client_id: str) -> Optional[Dict]:
        """Get statistics for a client"""
        stats = {
            'requests_last_minute': 0,
            'requests_last_hour': 0,
            'buckets': {}
        }

        # Get request counts from all endpoints
        for endpoint in self.endpoint_limits.keys():
            minute_key = f'window:{client_id}:{endpoint}:minute'
            hour_key = f'window:{client_id}:{endpoint}:hour'

            # This is approximate since we're using sliding windows
            minute_count = self.redis.scard(minute_key)
            hour_count = self.redis.scard(hour_key)

            stats['requests_last_minute'] += minute_count
            stats['requests_last_hour'] += hour_count

            # Get bucket info
            bucket_key = f'bucket:{client_id}:{endpoint}'
            bucket_data = self.redis.hgetall(bucket_key)
            if bucket_data:
                stats['buckets'][endpoint] = {
                    'tokens': float(bucket_data.get('tokens', 0)),
                    'last_update': float(bucket_data.get('last_update', 0))
                }

        return stats
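

# The redis_manager passed to RedisRateLimiter is expected to be the project's
# Redis wrapper rather than a raw redis-py client; the methods relied on above
# are get(key, default), set(key, value, expire=...), exists, sismember, incr,
# decr, expire, hgetall, hset, scard, and check_rate_limit(key, limit, window),
# where check_rate_limit returns an (allowed, remaining) tuple.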


def rate_limit(endpoint=None,
               requests_per_minute=None,
               requests_per_hour=None,
               burst_size=None,
               check_size=False):
    """
    Rate limiting decorator for Flask routes using Redis.

    Any of requests_per_minute, requests_per_hour or burst_size overrides the
    configured limits for the endpoint; check_size enables request size checks
    against the endpoint's max_request_size.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Get Redis rate limiter from app
            from flask import current_app
            if not hasattr(current_app, 'redis_rate_limiter'):
                # No Redis rate limiter, execute function normally
                return f(*args, **kwargs)

            rate_limiter = current_app.redis_rate_limiter

            # Get client ID
            client_id = rate_limiter.get_client_id(request)
            ip = request.remote_addr

            # Check if IP is blocked
            if rate_limiter.is_ip_blocked(ip):
                return jsonify({
                    'error': 'IP temporarily blocked due to excessive requests'
                }), 429

            # Get endpoint (limits are keyed by URL path, e.g. '/translate')
            endpoint_path = endpoint or request.path

            # Override default limits if specified
            if any([requests_per_minute, requests_per_hour, burst_size]):
                limits = rate_limiter.get_limits(endpoint_path).copy()
                if requests_per_minute:
                    limits['requests_per_minute'] = requests_per_minute
                if requests_per_hour:
                    limits['requests_per_hour'] = requests_per_hour
                if burst_size:
                    limits['burst_size'] = burst_size
                rate_limiter.endpoint_limits[endpoint_path] = limits

            # Check request size if needed
            request_size = 0
            if check_size:
                request_size = request.content_length or 0

            # Check rate limit
            allowed, message, headers = rate_limiter.check_rate_limit(
                client_id, endpoint_path, request_size
            )

            if not allowed:
                # Log excessive requests
                logger.warning(f"Rate limit exceeded for {client_id} on {endpoint_path}: {message}")

                # Check if we should temporarily block this IP
                stats = rate_limiter.get_client_stats(client_id)
                if stats and stats['requests_last_minute'] > 100:
                    rate_limiter.block_ip_temporarily(ip, 3600)  # 1 hour block

                response = jsonify({
                    'error': message,
                    'retry_after': headers.get('retry_after') if headers else 60
                })
                response.status_code = 429

                # Add rate limit headers
                if headers:
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])
                    response.headers['Retry-After'] = str(headers['retry_after'])

                return response

            # Track concurrent requests
            rate_limiter.increment_concurrent()

            try:
                # Add rate limit info to response
                g.rate_limit_headers = headers
                response = f(*args, **kwargs)

                # Add headers to successful response
                if headers and hasattr(response, 'headers'):
                    response.headers['X-RateLimit-Limit'] = str(headers['limit'])
                    response.headers['X-RateLimit-Remaining'] = str(headers['remaining'])
                    response.headers['X-RateLimit-Reset'] = str(headers['reset'])

                return response
            finally:
                rate_limiter.decrement_concurrent()

        return decorated_function
    return decorator
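

# Example wiring (illustrative sketch only; RedisManager, the import path and
# the route below are assumptions about the surrounding app, not part of this
# module):
#
#   from flask import Flask
#   from redis_manager import RedisManager              # hypothetical wrapper
#   from redis_rate_limiter import RedisRateLimiter, rate_limit
#
#   app = Flask(__name__)
#   app.redis_rate_limiter = RedisRateLimiter(RedisManager())
#
#   @app.route('/translate', methods=['POST'])
#   @rate_limit('/translate', check_size=True)
#   def translate():
#       ...
#
# If app.redis_rate_limiter is not set, the decorator falls back to calling
# the wrapped view unmodified, as handled at the top of decorated_function.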