This commit introduces major enhancements to Talk2Me:

## Database Integration
- PostgreSQL support with SQLAlchemy ORM
- Redis integration for caching and real-time analytics
- Automated database initialization scripts
- Migration support infrastructure

## User Authentication System
- JWT-based API authentication
- Session-based web authentication
- API key authentication for programmatic access
- User roles and permissions (admin/user)
- Login history and session tracking
- Rate limiting per user with customizable limits

## Admin Dashboard
- Real-time analytics and monitoring
- User management interface (create, edit, delete users)
- System health monitoring
- Request/error tracking
- Language pair usage statistics
- Performance metrics visualization

## Key Features
- Dual authentication support (token + user accounts)
- Graceful fallback for missing services
- Non-blocking analytics middleware
- Comprehensive error handling
- Session management with security features

## Bug Fixes
- Fixed rate limiting bypass for admin routes
- Added missing email validation method
- Improved error handling for missing database tables
- Fixed session-based authentication for API endpoints

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
352 lines
14 KiB
Python
352 lines
14 KiB
Python
"""User-specific rate limiting that integrates with authentication"""
|
|
import time
|
|
import logging
|
|
from functools import wraps
|
|
from flask import request, jsonify, g
|
|
from collections import defaultdict, deque
|
|
from threading import Lock
|
|
from datetime import datetime, timedelta
|
|
|
|
from auth import get_current_user
|
|
from auth_models import User
|
|
from database import db
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UserRateLimiter:
    """Per-user rate limiter layered on top of a default limiter.

    Every authenticated user gets an in-memory bucket holding a token count
    (burst control) and the timestamps of accepted requests (sliding
    minute/hour/day windows). All bucket state lives in this process and is
    read and written under ``self.lock``.
    """

    # Sliding-window lengths in seconds.
    _MINUTE = 60
    _HOUR = 3600
    _DAY = 86400

    def __init__(self, default_limiter):
        """Store the fallback limiter and create the empty bucket table.

        Args:
            default_limiter: Limiter used for unauthenticated requests and as
                the source of per-endpoint base limits (``get_limits``).
        """
        self.default_limiter = default_limiter
        # The history deque is deliberately unbounded: only accepted requests
        # are recorded and entries older than a day are pruned, so its length
        # never exceeds the user's per-day limit. A fixed maxlen would make
        # any hour/day limit above that cap impossible to enforce.
        self.user_buckets = defaultdict(lambda: {
            'tokens': 0,
            'last_update': time.time(),
            'requests': deque(),
        })
        self.lock = Lock()

    def get_user_limits(self, user: "User", endpoint: str):
        """Return the effective limits for *user* on *endpoint*.

        Without a user the default limiter's limits for the endpoint are
        returned unchanged. Otherwise the user's own per-minute/hour/day
        limits are used; admins get 10x window limits and 5x burst size.

        Returns:
            dict with ``requests_per_minute``, ``requests_per_hour``,
            ``requests_per_day``, ``burst_size`` and ``token_refresh_rate``
            (tokens per second) when a user is given.
        """
        base_limits = self.default_limiter.get_limits(endpoint)
        if not user:
            return base_limits

        window_factor = 10 if user.is_admin else 1
        burst_factor = 5 if user.is_admin else 1
        per_minute = user.rate_limit_per_minute * window_factor

        return {
            'requests_per_minute': per_minute,
            'requests_per_hour': user.rate_limit_per_hour * window_factor,
            'requests_per_day': user.rate_limit_per_day * window_factor,
            'burst_size': base_limits.get('burst_size', 10) * burst_factor,
            # Derived from the *effective* per-minute limit so the token
            # bucket refills at the same rate the minute window allows
            # (admins included).
            'token_refresh_rate': per_minute / 60.0,
        }

    @staticmethod
    def _window_snapshot(limits, used, consumed):
        """Build the limit_*/remaining_* entries for all three windows.

        *consumed* is the number of requests being charged right now
        (1 for an accepted request, 0 for a rejected one).
        """
        snapshot = {}
        for window in ('minute', 'hour', 'day'):
            limit = limits[f'requests_per_{window}']
            snapshot[f'limit_{window}'] = limit
            snapshot[f'remaining_{window}'] = max(
                0, limit - used[window] - consumed
            )
        return snapshot

    def _rejected(self, message, limit, retry_after, now, limits, used):
        """Return the ``(False, message, details)`` tuple for a rejection."""
        details = {
            'retry_after': retry_after,
            'limit': limit,
            'remaining': 0,
            # The earliest moment the caller is told to retry.
            'reset': int(now + retry_after),
            'scope': 'user',
        }
        # Also expose the per-window figures so callers that render
        # limit_minute/limit_hour/limit_day see real values on rejection.
        details.update(self._window_snapshot(limits, used, consumed=0))
        return False, message, details

    def check_user_rate_limit(self, user: "User", endpoint: str, request_size: int = 0):
        """Decide whether *user* may make a request to *endpoint* now.

        Falls back to the default limiter (keyed by its own client id) when
        no user is given. For a user, the minute, hour and day windows are
        checked first, then the token bucket; an accepted request consumes
        one token and is recorded in the history.

        Returns:
            ``(allowed, message, details)`` where *message* is ``None`` on
            success and *details* always carries ``scope == 'user'`` for the
            per-user path.
        """
        if not user:
            client_id = self.default_limiter.get_client_id(request)
            return self.default_limiter.check_rate_limit(client_id, endpoint, request_size)

        with self.lock:
            user_id = str(user.id)
            limits = self.get_user_limits(user, endpoint)

            first_seen = user_id not in self.user_buckets
            bucket = self.user_buckets[user_id]
            now = time.time()
            if first_seen:
                # Start with a full bucket; with zero tokens and no elapsed
                # refill time the very first request would be rejected.
                bucket['tokens'] = limits['burst_size']

            # Refill tokens for the time elapsed since the last check
            # (never negative, in case the wall clock stepped backwards).
            elapsed = max(0.0, now - bucket['last_update'])
            bucket['tokens'] = min(
                limits['burst_size'],
                bucket['tokens'] + elapsed * limits['token_refresh_rate'],
            )
            bucket['last_update'] = now

            # Drop history that has left the largest (daily) window.
            day_ago = now - self._DAY
            bucket['requests'] = deque(
                t for t in bucket['requests'] if t > day_ago
            )
            history = bucket['requests']

            minute_ago = now - self._MINUTE
            hour_ago = now - self._HOUR
            used = {
                'minute': sum(1 for t in history if t > minute_ago),
                'hour': sum(1 for t in history if t > hour_ago),
                'day': len(history),
            }

            # Sliding windows, smallest first.
            for window, seconds in (
                ('minute', self._MINUTE),
                ('hour', self._HOUR),
                ('day', self._DAY),
            ):
                limit = limits[f'requests_per_{window}']
                if used[window] >= limit:
                    return self._rejected(
                        f"Rate limit exceeded (per {window})",
                        limit, seconds, now, limits, used,
                    )

            # Token bucket (burst control). The refresh rate is positive
            # here: passing the per-minute check implies a per-minute limit
            # greater than zero.
            if bucket['tokens'] < 1:
                retry_after = max(1, int(1 / limits['token_refresh_rate']))
                return self._rejected(
                    "Rate limit exceeded (burst)",
                    limits['burst_size'], retry_after, now, limits, used,
                )

            # Request allowed: charge it.
            bucket['tokens'] -= 1
            history.append(now)

            details = self._window_snapshot(limits, used, consumed=1)
            details['reset'] = int(now + self._MINUTE)
            details['scope'] = 'user'
            return True, None, details

    def get_user_usage_stats(self, user: "User"):
        """Return request counts and token balance for *user*.

        Returns ``None`` without a user. A user with no bucket yet gets
        all-zero counters; this lookup never creates a bucket.
        """
        if not user:
            return None

        with self.lock:
            user_id = str(user.id)
            if user_id not in self.user_buckets:
                return {
                    'requests_last_minute': 0,
                    'requests_last_hour': 0,
                    'requests_last_day': 0,
                    'tokens_available': 0,
                    'last_request': None,
                }

            bucket = self.user_buckets[user_id]
            now = time.time()
            history = bucket['requests']

            def count_since(seconds):
                cutoff = now - seconds
                return sum(1 for t in history if t > cutoff)

            return {
                'requests_last_minute': count_since(self._MINUTE),
                'requests_last_hour': count_since(self._HOUR),
                'requests_last_day': count_since(self._DAY),
                'tokens_available': bucket['tokens'],
                # Time of the last rate-limit check for this user.
                'last_request': bucket['last_update'],
            }

    def reset_user_limits(self, user: "User"):
        """Forget all rate-limit state for *user* (admin action).

        Returns ``True`` if a bucket existed and was removed, else ``False``.
        """
        if not user:
            return False

        with self.lock:
            return self.user_buckets.pop(str(user.id), None) is not None
|
|
|
|
|
|
# Module-wide UserRateLimiter shared by the decorator and helpers below; it
# wraps the project's default limiter.
# NOTE(review): this import sits mid-file rather than at the top —
# presumably to avoid an import cycle; confirm before moving it.
from rate_limiter import rate_limiter as default_rate_limiter

user_rate_limiter = UserRateLimiter(default_limiter=default_rate_limiter)
|
|
|
|
|
|
def _apply_rate_limit_headers(response, headers, include_retry_after=False):
    """Copy rate-limit metadata from *headers* onto ``response.headers``.

    User-scoped results carry per-window keys (``limit_minute`` ...);
    anything else is treated as the default limiter's flat
    ``limit``/``remaining`` shape. Missing keys fall back to defaults or are
    skipped instead of raising ``KeyError``.
    """
    if not headers:
        return

    out = response.headers
    if headers.get('scope') == 'user':
        out['X-RateLimit-Limit'] = str(headers.get('limit_minute', 60))
        out['X-RateLimit-Remaining'] = str(headers.get('remaining_minute', 0))
        out['X-RateLimit-Limit-Hour'] = str(headers.get('limit_hour', 1000))
        out['X-RateLimit-Remaining-Hour'] = str(headers.get('remaining_hour', 0))
        out['X-RateLimit-Limit-Day'] = str(headers.get('limit_day', 10000))
        out['X-RateLimit-Remaining-Day'] = str(headers.get('remaining_day', 0))
    else:
        out['X-RateLimit-Limit'] = str(headers.get('limit', 60))
        out['X-RateLimit-Remaining'] = str(headers.get('remaining', 0))

    if 'reset' in headers:
        out['X-RateLimit-Reset'] = str(headers['reset'])
    if include_retry_after and 'retry_after' in headers:
        out['Retry-After'] = str(headers['retry_after'])


def user_aware_rate_limit(endpoint=None,
                          requests_per_minute=None,
                          requests_per_hour=None,
                          requests_per_day=None,
                          burst_size=None,
                          check_size=False,
                          require_auth=False):
    """
    Enhanced rate limiting decorator that considers user authentication.

    Authenticated requests are checked against the per-user limiter;
    anonymous requests fall back to the default limiter keyed by its client
    id. A rejected request gets a 429 JSON response with rate-limit headers;
    an accepted one runs the view and has the headers attached to its
    response when the response object exposes ``headers``.

    Args:
        endpoint: Name used for limit lookup; defaults to ``request.endpoint``.
        requests_per_minute, requests_per_hour, requests_per_day, burst_size:
            Accepted for API compatibility.
            NOTE(review): these values are not consulted anywhere in this
            decorator, so passing them currently has no effect.
        check_size: When true, ``request.content_length`` is passed to the
            limiter as the request size.
        require_auth: When true, unauthenticated requests get a 401.

    Usage:
        @app.route('/api/endpoint')
        @user_aware_rate_limit(requests_per_minute=10)
        def endpoint():
            return jsonify({'status': 'ok'})
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # Current user, if the request is authenticated.
            user = get_current_user()

            if require_auth and not user:
                return jsonify({
                    'success': False,
                    'error': 'Authentication required',
                    'code': 'auth_required'
                }), 401

            endpoint_path = endpoint or request.endpoint
            request_size = (request.content_length or 0) if check_size else 0

            if user:
                allowed, message, headers = user_rate_limiter.check_user_rate_limit(
                    user, endpoint_path, request_size
                )
            else:
                # No user: rate limit by the default limiter's client id.
                client_id = default_rate_limiter.get_client_id(request)
                allowed, message, headers = default_rate_limiter.check_rate_limit(
                    client_id, endpoint_path, request_size
                )

            if not allowed:
                identifier = f"user:{user.username}" if user else f"ip:{request.remote_addr}"
                logger.warning(f"Rate limit exceeded for {identifier} on {endpoint_path}: {message}")

                if user:
                    # Best-effort activity stamp: a database failure here
                    # must not replace the 429 with a 500.
                    try:
                        user.last_active_at = datetime.utcnow()
                        db.session.commit()
                    except Exception:
                        logger.exception("Failed to update last_active_at for rate-limited user")
                        db.session.rollback()

                response = jsonify({
                    'success': False,
                    'error': message,
                    'retry_after': headers.get('retry_after') if headers else 60
                })
                response.status_code = 429
                _apply_rate_limit_headers(response, headers, include_retry_after=True)
                return response

            # Track concurrent requests for the duration of the view call.
            default_rate_limiter.increment_concurrent()
            try:
                if user:
                    g.current_user = user

                # Expose the limiter result to the view via flask.g.
                g.rate_limit_headers = headers
                response = f(*args, **kwargs)

                # Views may return tuples/dicts/strings; only decorate
                # objects that actually carry headers.
                if hasattr(response, 'headers'):
                    _apply_rate_limit_headers(response, headers)

                return response
            finally:
                default_rate_limiter.decrement_concurrent()

        return decorated_function
    return decorator
|
|
|
|
|
|
def get_user_rate_limit_status(user: User = None):
    """Describe the rate-limit state of a user, or of the caller's IP.

    When *user* is not given, the current authenticated user (if any) is
    used. With a user the result has ``type == 'user'`` and the user's
    effective limits for the current endpoint; without one it has
    ``type == 'ip'`` and the default limiter's limits and client stats.
    """
    subject = user or get_current_user()

    if not subject:
        # Anonymous caller: report what the default limiter knows about it.
        client_id = default_rate_limiter.get_client_id(request)
        ip_usage = default_rate_limiter.get_client_stats(client_id)
        return {
            'type': 'ip',
            'identifier': request.remote_addr,
            'limits': default_rate_limiter.default_limits,
            'usage': ip_usage
        }

    usage = user_rate_limiter.get_user_usage_stats(subject)
    effective = user_rate_limiter.get_user_limits(subject, request.endpoint or '/')
    per_window = {
        'per_minute': effective['requests_per_minute'],
        'per_hour': effective['requests_per_hour'],
        'per_day': effective['requests_per_day']
    }
    return {
        'type': 'user',
        'identifier': subject.username,
        'limits': per_window,
        'usage': usage
    }