"""Authentication and authorization utilities for Talk2Me"""
import os
import uuid
import functools
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Callable, Union, List

from flask import request, jsonify, g, current_app
from flask_jwt_extended import (
    JWTManager, create_access_token, create_refresh_token,
    get_jwt_identity, jwt_required, get_jwt, verify_jwt_in_request
)
from werkzeug.exceptions import Unauthorized
from sqlalchemy.exc import IntegrityError

from database import db
from auth_models import User, LoginHistory, UserSession, RevokedToken, bcrypt
from error_logger import log_exception

# Initialize JWT Manager
jwt = JWTManager()

# Placeholder secret used only when JWT_SECRET_KEY is configured nowhere.
# It is visible in source, so tokens signed with it are forgeable.
_INSECURE_DEFAULT_JWT_SECRET = 'your-secret-key-change-in-production'


def init_auth(app):
    """Initialize the authentication system for ``app``.

    Configures flask-jwt-extended (secret, token lifetimes, algorithm),
    initializes the JWT manager and bcrypt extensions, and registers JWT
    callbacks that consult the RevokedToken table and turn expired, invalid,
    missing, or revoked tokens into JSON 401 responses.

    Args:
        app: The Flask application instance.
    """
    # Configure JWT. Precedence: value already in app.config, then the
    # JWT_SECRET_KEY environment variable, then an insecure placeholder.
    app.config['JWT_SECRET_KEY'] = app.config.get(
        'JWT_SECRET_KEY',
        os.environ.get('JWT_SECRET_KEY', _INSECURE_DEFAULT_JWT_SECRET)
    )
    if app.config['JWT_SECRET_KEY'] == _INSECURE_DEFAULT_JWT_SECRET:
        # Keep the historical fallback so existing setups still boot, but
        # make the misconfiguration visible instead of silent.
        app.logger.warning(
            'JWT_SECRET_KEY is not configured; using an insecure default. '
            'Set JWT_SECRET_KEY before deploying to production.'
        )
    app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(hours=1)
    app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(days=30)
    app.config['JWT_ALGORITHM'] = 'HS256'
    # NOTE(review): JWT_BLACKLIST_* look like flask-jwt-extended 3.x option
    # names; in 4.x (which provides the token_in_blocklist_loader used below)
    # blocklist checks are driven by that callback. Kept for compatibility —
    # confirm against the installed version before removing.
    app.config['JWT_BLACKLIST_ENABLED'] = True
    app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh']

    # Initialize JWT manager
    jwt.init_app(app)

    # Initialize bcrypt
    bcrypt.init_app(app)

    # Register JWT callbacks
    @jwt.token_in_blocklist_loader
    def check_if_token_revoked(jwt_header, jwt_payload):
        # A token is rejected when its jti has been recorded as revoked.
        jti = jwt_payload["jti"]
        return RevokedToken.is_token_revoked(jti)

    @jwt.expired_token_loader
    def expired_token_callback(jwt_header, jwt_payload):
        return jsonify({
            'success': False,
            'error': 'Token has expired',
            'code': 'token_expired'
        }), 401

    @jwt.invalid_token_loader
    def invalid_token_callback(error):
        return jsonify({
            'success': False,
            'error': 'Invalid token',
            'code': 'invalid_token'
        }), 401

    @jwt.unauthorized_loader
    def missing_token_callback(error):
        return jsonify({
            'success': False,
            'error': 'Authorization required',
            'code': 'authorization_required'
        }), 401

    @jwt.revoked_token_loader
    def revoked_token_callback(jwt_header, jwt_payload):
        return jsonify({
            'success': False,
            'error': 'Token has been revoked',
            'code': 'token_revoked'
        }), 401


def create_user(email: str, username: str, password: str,
                full_name: Optional[str] = None, role: str = 'user',
                is_verified: bool = False) -> tuple[Optional[User], Optional[str]]:
    """Create and persist a new user account.

    Args:
        email: Email address; must not already be in use.
        username: Username; must not already be in use.
        password: Plain-text password, stored via ``user.set_password``.
        full_name: Optional display name.
        role: Role string stored on the user (default ``'user'``).
        is_verified: Initial email-verification flag.

    Returns:
        ``(user, None)`` on success, or ``(None, error_message)`` when the
        email/username is taken or the insert fails. The DB session is rolled
        back on any failure.
    """
    try:
        # Check if user already exists (fast path; the IntegrityError handler
        # below covers a concurrent insert racing past this check)
        if User.query.filter((User.email == email) | (User.username == username)).first():
            return None, "User with this email or username already exists"

        # Create user
        user = User(
            email=email,
            username=username,
            full_name=full_name,
            role=role,
            is_verified=is_verified
        )
        user.set_password(password)

        db.session.add(user)
        db.session.commit()

        return user, None
    except IntegrityError:
        db.session.rollback()
        return None, "User with this email or username already exists"
    except Exception as e:
        db.session.rollback()
        log_exception(e, "Failed to create user")
        return None, "Failed to create user account"


def authenticate_user(username_or_email: str, password: str) -> tuple[Optional[User], Optional[str]]:
    """Authenticate a user by username-or-email and password.

    Every attempt on an existing account is recorded through
    ``user.record_login_attempt`` and committed.

    Returns:
        ``(user, None)`` on success, otherwise ``(None, reason)``, where
        ``reason`` is ``"Invalid credentials"`` or the message returned by
        ``user.can_login()``.
    """
    # Find user by username or email
    user = User.query.filter(
        (User.username == username_or_email) | (User.email == username_or_email)
    ).first()

    if not user:
        return None, "Invalid credentials"

    # Check if user can login (e.g. account state gates enforced by the model)
    can_login, reason = user.can_login()
    if not can_login:
        user.record_login_attempt(False)
        db.session.commit()
        return None, reason

    # Verify password
    if not user.check_password(password):
        user.record_login_attempt(False)
        db.session.commit()
        return None, "Invalid credentials"

    # Success
    user.record_login_attempt(True)
    db.session.commit()

    return user, None


def authenticate_api_key(api_key: str) -> tuple[Optional[User], Optional[str]]:
    """Authenticate a user by API key.

    On success the user's ``last_active_at`` is set to the current naive-UTC
    time and committed.

    Returns:
        ``(user, None)`` on success, otherwise ``(None, reason)``.
    """
    user = User.query.filter_by(api_key=api_key).first()
    if not user:
        return None, "Invalid API key"

    # Check if user can login
    can_login, reason = user.can_login()
    if not can_login:
        return None, reason

    # Update last active
    user.last_active_at = datetime.utcnow()
    db.session.commit()

    return user, None


def create_tokens(user: User, session_id: Optional[str] = None) -> Dict[str, Any]:
    """Create an access/refresh JWT pair for ``user``.

    The access token is marked fresh and carries ``username``, ``role``,
    ``permissions`` and ``session_id`` claims; the refresh token carries
    ``session_id``. Both use ``str(user.id)`` as identity.

    Returns:
        Dict with ``access_token``, ``refresh_token``, ``token_type``
        (``'Bearer'``), ``expires_in`` (access-token lifetime in seconds, a
        float), plus ``access_token_jti`` / ``refresh_token_jti``: the JTIs
        embedded in each token, so callers can record them (e.g. on a
        UserSession) for later revocation without decoding the tokens.
    """
    # Generate JTIs up front so they can be returned to the caller.
    access_jti = str(uuid.uuid4())
    refresh_jti = str(uuid.uuid4())

    # Create tokens with custom claims
    identity = str(user.id)
    additional_claims = {
        'username': user.username,
        'role': user.role,
        'permissions': user.permissions or [],
        'session_id': session_id,
        # additional_claims take precedence over the library's default claims,
        # so this replaces the auto-generated jti with the one we track.
        'jti': access_jti,
    }

    access_token = create_access_token(
        identity=identity,
        additional_claims=additional_claims,
        fresh=True
    )

    refresh_token = create_refresh_token(
        identity=identity,
        additional_claims={'session_id': session_id, 'jti': refresh_jti}
    )

    return {
        'access_token': access_token,
        'refresh_token': refresh_token,
        'token_type': 'Bearer',
        'expires_in': current_app.config['JWT_ACCESS_TOKEN_EXPIRES'].total_seconds(),
        'access_token_jti': access_jti,
        'refresh_token_jti': refresh_jti,
    }


def create_user_session(user: User, request_info: Dict[str, Any]) -> UserSession:
    """Create and persist a 30-day UserSession for ``user``.

    Args:
        user: The authenticated user.
        request_info: Mapping read for optional ``'ip_address'`` and
            ``'user_agent'`` keys.

    Returns:
        The committed UserSession with a fresh random ``session_id``.
    """
    session = UserSession(
        session_id=str(uuid.uuid4()),
        user_id=user.id,
        ip_address=request_info.get('ip_address'),
        user_agent=request_info.get('user_agent'),
        expires_at=datetime.utcnow() + timedelta(days=30)
    )

    db.session.add(session)
    db.session.commit()

    return session


def log_login_attempt(user_id: Optional[uuid.UUID], success: bool, method: str,
                      failure_reason: Optional[str] = None,
                      session_id: Optional[str] = None,
                      jwt_jti: Optional[str] = None) -> LoginHistory:
    """Persist a LoginHistory record for a login attempt.

    Must run inside a request context: the client IP and User-Agent are read
    from the current Flask ``request``.

    Returns:
        The committed LoginHistory row.
    """
    login_record = LoginHistory(
        user_id=user_id,
        login_method=method,
        success=success,
        failure_reason=failure_reason,
        session_id=session_id,
        jwt_jti=jwt_jti,
        ip_address=request.remote_addr,
        user_agent=request.headers.get('User-Agent')
    )

    db.session.add(login_record)
    db.session.commit()

    return login_record


def revoke_token(jti: str, token_type: str, user_id: Optional[uuid.UUID] = None,
                 reason: Optional[str] = None, expires_at: Optional[datetime] = None):
    """Record a JWT as revoked so the blocklist callback rejects it.

    Args:
        jti: The token's unique identifier.
        token_type: ``'access'`` or anything else (treated as refresh).
        user_id: Optional owner of the token.
        reason: Optional human-readable reason.
        expires_at: When the revocation record may be cleaned up. Defaults to
            now (naive UTC) plus the configured lifetime for the token type.
    """
    if not expires_at:
        # Default expiration based on token type
        if token_type == 'access':
            expires_at = datetime.utcnow() + current_app.config['JWT_ACCESS_TOKEN_EXPIRES']
        else:
            expires_at = datetime.utcnow() + current_app.config['JWT_REFRESH_TOKEN_EXPIRES']

    revoked = RevokedToken(
        jti=jti,
        token_type=token_type,
        user_id=user_id,
        reason=reason,
        expires_at=expires_at
    )

    db.session.add(revoked)
    db.session.commit()


def get_current_user() -> Optional[User]:
    """Resolve the authenticated user for the current request.

    Tries, in order: a JWT (optional), the ``X-API-Key`` header, the
    ``api_key`` query parameter, and finally the Flask session. A failing
    method falls through to the next one. Active, non-suspended users found
    via JWT or a regular session get ``last_active_at`` updated and committed.

    Returns:
        The User, or ``None`` if no method authenticates the request.
    """
    # Try JWT first. An invalid, expired or revoked token is not fatal here:
    # we simply fall through to the other authentication methods.
    try:
        verify_jwt_in_request(optional=True)
        user_id = get_jwt_identity()
    except Exception:
        user_id = None

    if user_id:
        try:
            user = User.query.get(user_id)
            if user and user.is_active and not user.is_suspended_now:
                # Update last active
                user.last_active_at = datetime.utcnow()
                db.session.commit()
                return user
        except Exception:
            # Lookup/commit failed: reset the session so later queries in this
            # request don't hit a failed transaction, then try other methods.
            db.session.rollback()

    # Try API key from header
    api_key = request.headers.get('X-API-Key')
    if api_key:
        user, _ = authenticate_api_key(api_key)
        if user:
            return user

    # Try API key from query parameter
    # NOTE(review): keys in query strings tend to end up in access logs.
    api_key = request.args.get('api_key')
    if api_key:
        user, _ = authenticate_api_key(api_key)
        if user:
            return user

    # Try session-based authentication (for admin panel)
    from flask import session
    if session.get('logged_in') and session.get('user_id'):
        # Check if it's the admin token user
        # NOTE(review): any session carrying these two values is treated as the
        # first admin account in the DB; its safety depends entirely on the
        # code that sets them.
        if session.get('user_id') == 'admin-token-user' and session.get('user_role') == 'admin':
            # Create a pseudo-admin user for session-based admin access
            admin_user = User.query.filter_by(role='admin').first()
            if admin_user:
                return admin_user
            else:
                # Create a temporary admin user object (not saved to DB)
                admin_user = User(
                    id=uuid.uuid4(),
                    username='admin',
                    email='admin@talk2me.local',
                    role='admin',
                    is_active=True,
                    is_verified=True,
                    is_suspended=False,
                    total_requests=0,
                    total_translations=0,
                    total_transcriptions=0,
                    total_tts_requests=0
                )
                # Don't add to session, just return for authorization
                return admin_user
        else:
            # Regular user session
            user = User.query.get(session.get('user_id'))
            if user and user.is_active and not user.is_suspended_now:
                # Update last active
                user.last_active_at = datetime.utcnow()
                db.session.commit()
                return user

    return None


def require_auth(f: Callable) -> Callable:
    """Decorator to require authentication (JWT, API key, or session).

    Responds with a JSON 401 when no user is resolved. Otherwise stores the
    user in ``g.current_user`` and, for users that exist in the database,
    increments ``total_requests`` on a best-effort basis.
    """
    @functools.wraps(f)
    def decorated_function(*args, **kwargs):
        user = get_current_user()
        if not user:
            return jsonify({
                'success': False,
                'error': 'Authentication required',
                'code': 'auth_required'
            }), 401

        # Store user in g for access in route
        g.current_user = user

        # Track usage only for database-backed users (the temporary pseudo-admin
        # from get_current_user has no DB row and is skipped).
        try:
            if hasattr(user, 'id') and db.session.query(User).filter_by(id=user.id).first():
                user.total_requests += 1
                db.session.commit()
        except Exception:
            # Tracking is best-effort, but a failed commit must not leave the
            # session unusable for the route that runs next.
            db.session.rollback()

        return f(*args, **kwargs)

    return decorated_function


def require_admin(f: Callable) -> Callable:
    """Decorator to require an authenticated admin (JSON 403 otherwise)."""
    @functools.wraps(f)
    @require_auth
    def decorated_function(*args, **kwargs):
        if not g.current_user.is_admin:
            return jsonify({
                'success': False,
                'error': 'Admin access required',
                'code': 'admin_required'
            }), 403
        return f(*args, **kwargs)

    return decorated_function


def require_permission(permission: str) -> Callable:
    """Decorator factory requiring ``permission`` (JSON 403 otherwise)."""
    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        @require_auth
        def decorated_function(*args, **kwargs):
            if not g.current_user.has_permission(permission):
                return jsonify({
                    'success': False,
                    'error': f'Permission required: {permission}',
                    'code': 'permission_denied'
                }), 403
            return f(*args, **kwargs)

        return decorated_function
    return decorator


def require_verified(f: Callable) -> Callable:
    """Decorator to require a verified email (JSON 403 otherwise)."""
    @functools.wraps(f)
    @require_auth
    def decorated_function(*args, **kwargs):
        if not g.current_user.is_verified:
            return jsonify({
                'success': False,
                'error': 'Email verification required',
                'code': 'verification_required'
            }), 403
        return f(*args, **kwargs)

    return decorated_function


def get_user_rate_limits(user: User) -> Dict[str, int]:
    """Return the user's per-minute/hour/day rate limits as a dict."""
    return {
        'per_minute': user.rate_limit_per_minute,
        'per_hour': user.rate_limit_per_hour,
        'per_day': user.rate_limit_per_day
    }


def check_user_rate_limit(user: User, endpoint: str) -> tuple[bool, Optional[str]]:
    """Check if user has exceeded rate limits.

    Currently a stub that always allows the request: returns ``(True, None)``.
    """
    # This would integrate with the existing rate limiter
    # For now, return True to allow requests
    return True, None


def update_user_usage_stats(user: User, operation: str) -> None:
    """Increment usage counters for ``operation`` and commit.

    ``total_requests`` always increments; ``'translation'``,
    ``'transcription'`` and ``'tts'`` also bump their specific counter.
    """
    user.total_requests += 1

    if operation == 'translation':
        user.total_translations += 1
    elif operation == 'transcription':
        user.total_transcriptions += 1
    elif operation == 'tts':
        user.total_tts_requests += 1

    user.last_active_at = datetime.utcnow()
    db.session.commit()


def cleanup_expired_sessions() -> int:
    """Delete UserSessions whose ``expires_at`` has passed; return the count."""
    deleted = UserSession.query.filter(
        UserSession.expires_at < datetime.utcnow()
    ).delete()
    db.session.commit()
    return deleted


def cleanup_expired_tokens() -> int:
    """Clean up expired revoked tokens (delegates to RevokedToken)."""
    return RevokedToken.cleanup_expired()


def get_user_sessions(user_id: Union[str, uuid.UUID]) -> List[UserSession]:
    """Return the user's unexpired sessions, most recently active first."""
    return UserSession.query.filter_by(
        user_id=user_id
    ).filter(
        UserSession.expires_at > datetime.utcnow()
    ).order_by(UserSession.last_active_at.desc()).all()


def revoke_user_sessions(user_id: Union[str, uuid.UUID], except_session: Optional[str] = None) -> int:
    """Revoke and delete all of a user's sessions.

    Tokens whose JTIs are recorded on a session are added to the revocation
    list before the sessions are deleted.

    Args:
        user_id: Owner of the sessions.
        except_session: Optional ``session_id`` to keep (e.g. the caller's).

    Returns:
        Number of sessions revoked.
    """
    sessions = UserSession.query.filter_by(user_id=user_id)

    if except_session:
        sessions = sessions.filter(UserSession.session_id != except_session)

    count = 0
    for user_session in sessions:
        # Revoke associated tokens
        if user_session.access_token_jti:
            revoke_token(user_session.access_token_jti, 'access', user_id, 'Session revoked')
        if user_session.refresh_token_jti:
            revoke_token(user_session.refresh_token_jti, 'refresh', user_id, 'Session revoked')
        count += 1

    # Delete sessions
    sessions.delete()
    db.session.commit()

    return count