"""Authentication models for Talk2Me application.

Defines the SQLAlchemy models used for authentication and session management:
``User``, ``LoginHistory``, ``UserSession`` and ``RevokedToken``.
Timestamps written by this module are naive UTC values from ``datetime.utcnow()``.
"""
import uuid
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List

from flask_sqlalchemy import SQLAlchemy
from flask_bcrypt import Bcrypt
from sqlalchemy import Index, text, func, and_, or_
from sqlalchemy.dialects.postgresql import UUID, JSONB, ENUM
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship

from database import db

bcrypt = Bcrypt()

# Default lockout policy applied by User.record_login_attempt.
MAX_FAILED_LOGIN_ATTEMPTS = 5
LOCKOUT_DURATION = timedelta(minutes=30)


def _isoformat(value: Optional[datetime]) -> Optional[str]:
    """Return ``value.isoformat()``, or ``None`` when *value* is ``None``.

    Column defaults are only applied on INSERT, so timestamps such as
    ``created_at`` are still ``None`` on an object that has not been flushed.
    """
    return value.isoformat() if value is not None else None


class User(db.Model):
    """User account model with authentication and authorization."""

    __tablename__ = 'users'

    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    email = db.Column(db.String(255), unique=True, nullable=False, index=True)
    username = db.Column(db.String(100), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)

    # User profile
    full_name = db.Column(db.String(255), nullable=True)
    avatar_url = db.Column(db.String(500), nullable=True)

    # API Key - unique per user
    api_key = db.Column(db.String(64), unique=True, nullable=False, index=True)
    api_key_created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)

    # Account status
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    is_verified = db.Column(db.Boolean, default=False, nullable=False)
    is_suspended = db.Column(db.Boolean, default=False, nullable=False)
    suspension_reason = db.Column(db.Text, nullable=True)
    suspended_at = db.Column(db.DateTime, nullable=True)
    suspended_until = db.Column(db.DateTime, nullable=True)

    # Role and permissions
    role = db.Column(db.String(20), nullable=False, default='user')  # admin, user
    # Additional granular permissions. A factory (``list``) is used so each row
    # gets its own list instead of every INSERT sharing one mutable literal.
    permissions = db.Column(JSONB, default=list, nullable=False)

    # Usage limits (per user)
    rate_limit_per_minute = db.Column(db.Integer, default=30, nullable=False)
    rate_limit_per_hour = db.Column(db.Integer, default=500, nullable=False)
    rate_limit_per_day = db.Column(db.Integer, default=5000, nullable=False)

    # Usage tracking
    total_requests = db.Column(db.Integer, default=0, nullable=False)
    total_translations = db.Column(db.Integer, default=0, nullable=False)
    total_transcriptions = db.Column(db.Integer, default=0, nullable=False)
    total_tts_requests = db.Column(db.Integer, default=0, nullable=False)

    # Timestamps
    created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow,
                           onupdate=datetime.utcnow)
    last_login_at = db.Column(db.DateTime, nullable=True)
    last_active_at = db.Column(db.DateTime, nullable=True)

    # Security
    password_changed_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    failed_login_attempts = db.Column(db.Integer, default=0, nullable=False)
    locked_until = db.Column(db.DateTime, nullable=True)

    # Settings (free-form JSON; factory default for the same reason as permissions)
    settings = db.Column(JSONB, default=dict, nullable=False)

    # Relationships
    login_history = relationship('LoginHistory', back_populates='user',
                                 cascade='all, delete-orphan')
    sessions = relationship('UserSession', back_populates='user',
                            cascade='all, delete-orphan')

    __table_args__ = (
        Index('idx_users_email_active', 'email', 'is_active'),
        Index('idx_users_role_active', 'role', 'is_active'),
        Index('idx_users_created_at', 'created_at'),
    )

    def __init__(self, **kwargs):
        """Create a user; an API key is generated if none was supplied."""
        super().__init__(**kwargs)
        # Generated eagerly so the key exists before the first flush.
        if not self.api_key:
            self.api_key = self.generate_api_key()

    @staticmethod
    def generate_api_key() -> str:
        """Generate a secure API key of the form ``tk_<urlsafe token>``."""
        return f"tk_{secrets.token_urlsafe(32)}"

    def regenerate_api_key(self) -> str:
        """Replace the user's API key, stamp the creation time and return the new key."""
        self.api_key = self.generate_api_key()
        self.api_key_created_at = datetime.utcnow()
        return self.api_key

    def set_password(self, password: str) -> None:
        """Hash *password* with bcrypt, store it and record the change time."""
        self.password_hash = bcrypt.generate_password_hash(password).decode('utf-8')
        self.password_changed_at = datetime.utcnow()

    def check_password(self, password: str) -> bool:
        """Return True if *password* matches the stored hash.

        Fails closed (returns False) when no hash has been set yet, instead of
        letting bcrypt raise on a missing hash.
        """
        if not self.password_hash:
            return False
        return bcrypt.check_password_hash(self.password_hash, password)

    @hybrid_property
    def is_admin(self) -> bool:
        """True if the user has the ``admin`` role (usable in queries too)."""
        return self.role == 'admin'

    @hybrid_property
    def is_locked(self) -> bool:
        """True while the account is locked due to failed login attempts."""
        if self.locked_until is None:
            return False
        return datetime.utcnow() < self.locked_until

    @is_locked.expression
    def is_locked(cls):
        """SQL form of ``is_locked``; the current time is bound when the query is built.

        The explicit NOT NULL test keeps the result FALSE (not NULL) for users
        that were never locked, so ``~User.is_locked`` also matches them.
        """
        return and_(cls.locked_until.isnot(None),
                    cls.locked_until > datetime.utcnow())

    @hybrid_property
    def is_suspended_now(self) -> bool:
        """True if the account is currently suspended."""
        if not self.is_suspended:
            return False
        if self.suspended_until is None:
            return True  # Indefinite suspension
        return datetime.utcnow() < self.suspended_until

    @is_suspended_now.expression
    def is_suspended_now(cls):
        """SQL form of ``is_suspended_now``.

        The Python body cannot be reused at class level because
        ``not cls.is_suspended`` raises on a SQL clause.
        """
        return and_(
            cls.is_suspended,
            or_(cls.suspended_until.is_(None),
                cls.suspended_until > datetime.utcnow()),
        )

    def can_login(self) -> tuple[bool, Optional[str]]:
        """Return ``(allowed, reason)``; *reason* is None when login is allowed."""
        if not self.is_active:
            return False, "Account is deactivated"
        if self.is_locked:
            return False, "Account is locked due to failed login attempts"
        if self.is_suspended_now:
            return False, f"Account is suspended: {self.suspension_reason or 'Policy violation'}"
        return True, None

    def record_login_attempt(
        self,
        success: bool,
        *,
        max_attempts: int = MAX_FAILED_LOGIN_ATTEMPTS,
        lockout_duration: timedelta = LOCKOUT_DURATION,
    ) -> None:
        """Record a login attempt and apply the lockout policy.

        Args:
            success: Whether the attempt succeeded. Success clears the failure
                counter and any lock and stamps ``last_login_at``.
            max_attempts: Consecutive failures after which the account is locked.
            lockout_duration: How long the lock lasts once triggered.
        """
        now = datetime.utcnow()
        if success:
            self.failed_login_attempts = 0
            self.locked_until = None
            self.last_login_at = now
            return

        # ``or 0``: the column default is only applied on INSERT, so the counter
        # is None on a user that has not been flushed yet.
        self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
        # NOTE(review): the counter is not reset when a lock expires, so once the
        # threshold is reached every further failure re-locks the account.
        if self.failed_login_attempts >= max_attempts:
            self.locked_until = now + lockout_duration

    def has_permission(self, permission: str) -> bool:
        """Return True if the user holds *permission*; admins hold all permissions."""
        if self.is_admin:
            return True
        return permission in (self.permissions or [])

    def add_permission(self, permission: str) -> None:
        """Grant *permission* if not already present."""
        if self.permissions is None:
            self.permissions = []
        if permission not in self.permissions:
            # Reassign a new list: in-place mutation of a JSONB value is not
            # detected by SQLAlchemy's change tracking.
            self.permissions = self.permissions + [permission]

    def remove_permission(self, permission: str) -> None:
        """Revoke *permission* if present (reassigns a new list for change tracking)."""
        if self.permissions and permission in self.permissions:
            self.permissions = [p for p in self.permissions if p != permission]

    def suspend(self, reason: str, until: Optional[datetime] = None) -> None:
        """Suspend the account; ``until=None`` means an indefinite suspension."""
        self.is_suspended = True
        self.suspension_reason = reason
        self.suspended_at = datetime.utcnow()
        self.suspended_until = until

    def unsuspend(self) -> None:
        """Lift a suspension and clear all suspension metadata."""
        self.is_suspended = False
        self.suspension_reason = None
        self.suspended_at = None
        self.suspended_until = None

    def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
        """Serialize the user to a JSON-friendly dict.

        Args:
            include_sensitive: Also include the API key, permissions, rate
                limits and security/suspension details.
        """
        data = {
            'id': str(self.id),
            'email': self.email,
            'username': self.username,
            'full_name': self.full_name,
            'avatar_url': self.avatar_url,
            'role': self.role,
            'is_active': self.is_active,
            'is_verified': self.is_verified,
            'is_suspended': self.is_suspended_now,
            'created_at': _isoformat(self.created_at),
            'last_login_at': _isoformat(self.last_login_at),
            'last_active_at': _isoformat(self.last_active_at),
            'total_requests': self.total_requests,
            'total_translations': self.total_translations,
            'total_transcriptions': self.total_transcriptions,
            'total_tts_requests': self.total_tts_requests,
            'settings': self.settings or {},
        }
        if include_sensitive:
            data.update({
                'api_key': self.api_key,
                'api_key_created_at': _isoformat(self.api_key_created_at),
                'permissions': self.permissions or [],
                'rate_limit_per_minute': self.rate_limit_per_minute,
                'rate_limit_per_hour': self.rate_limit_per_hour,
                'rate_limit_per_day': self.rate_limit_per_day,
                'suspension_reason': self.suspension_reason,
                'suspended_until': _isoformat(self.suspended_until),
                'failed_login_attempts': self.failed_login_attempts,
                'is_locked': self.is_locked,
            })
        return data


class LoginHistory(db.Model):
    """Track user login history for security auditing."""

    __tablename__ = 'login_history'

    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = db.Column(UUID(as_uuid=True), db.ForeignKey('users.id'),
                        nullable=False, index=True)

    # Login details
    login_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    logout_at = db.Column(db.DateTime, nullable=True)
    login_method = db.Column(db.String(20), nullable=False)  # password, api_key, jwt
    success = db.Column(db.Boolean, nullable=False)
    failure_reason = db.Column(db.String(255), nullable=True)

    # Session info
    session_id = db.Column(db.String(255), nullable=True, index=True)
    jwt_jti = db.Column(db.String(255), nullable=True, index=True)  # JWT ID for revocation

    # Client info
    ip_address = db.Column(db.String(45), nullable=False)  # 45 chars fits IPv6 text form
    user_agent = db.Column(db.String(500), nullable=True)
    device_info = db.Column(JSONB, nullable=True)  # Parsed user agent info

    # Location info (if available)
    country = db.Column(db.String(2), nullable=True)
    city = db.Column(db.String(100), nullable=True)

    # Security flags
    is_suspicious = db.Column(db.Boolean, default=False, nullable=False)
    security_notes = db.Column(db.Text, nullable=True)

    # Relationship
    user = relationship('User', back_populates='login_history')

    __table_args__ = (
        Index('idx_login_history_user_time', 'user_id', 'login_at'),
        # NOTE(review): ``session_id`` already has ``index=True``, so this creates
        # a second index on the same column; confirm against migrations before
        # dropping either one.
        Index('idx_login_history_session', 'session_id'),
        Index('idx_login_history_ip', 'ip_address'),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the login record to a JSON-friendly dict."""
        return {
            'id': str(self.id),
            'user_id': str(self.user_id),
            'login_at': _isoformat(self.login_at),
            'logout_at': _isoformat(self.logout_at),
            'login_method': self.login_method,
            'success': self.success,
            'failure_reason': self.failure_reason,
            'session_id': self.session_id,
            'ip_address': self.ip_address,
            'user_agent': self.user_agent,
            'device_info': self.device_info,
            'country': self.country,
            'city': self.city,
            'is_suspicious': self.is_suspicious,
            'security_notes': self.security_notes,
        }


class UserSession(db.Model):
    """Active user sessions for session management."""

    __tablename__ = 'user_sessions'

    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    session_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
    user_id = db.Column(UUID(as_uuid=True), db.ForeignKey('users.id'),
                        nullable=False, index=True)

    # JWT tokens
    access_token_jti = db.Column(db.String(255), nullable=True, index=True)
    refresh_token_jti = db.Column(db.String(255), nullable=True, index=True)

    # Session info
    created_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    last_active_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    expires_at = db.Column(db.DateTime, nullable=False)

    # Client info
    ip_address = db.Column(db.String(45), nullable=False)
    user_agent = db.Column(db.String(500), nullable=True)

    # Session data (factory default so rows do not share one dict object)
    data = db.Column(JSONB, default=dict, nullable=False)

    # Relationship
    user = relationship('User', back_populates='sessions')

    __table_args__ = (
        Index('idx_user_sessions_user_active', 'user_id', 'expires_at'),
        # NOTE(review): ``access_token_jti`` already has ``index=True``; this is a
        # second index on the same column — confirm before dropping either.
        Index('idx_user_sessions_token', 'access_token_jti'),
    )

    @hybrid_property
    def is_expired(self) -> bool:
        """True once the session's expiry time has passed."""
        return datetime.utcnow() > self.expires_at

    @is_expired.expression
    def is_expired(cls):
        """SQL form of ``is_expired``; the current time is bound when the query is built."""
        return cls.expires_at < datetime.utcnow()

    def refresh(self, duration_hours: int = 24) -> None:
        """Mark the session active now and extend it *duration_hours* from now."""
        now = datetime.utcnow()
        self.last_active_at = now
        self.expires_at = now + timedelta(hours=duration_hours)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the session to a JSON-friendly dict."""
        return {
            'id': str(self.id),
            'session_id': self.session_id,
            'user_id': str(self.user_id),
            'created_at': _isoformat(self.created_at),
            'last_active_at': _isoformat(self.last_active_at),
            'expires_at': _isoformat(self.expires_at),
            'is_expired': self.is_expired,
            'ip_address': self.ip_address,
            'user_agent': self.user_agent,
            'data': self.data or {},
        }


class RevokedToken(db.Model):
    """Store revoked JWT tokens for blacklisting."""

    __tablename__ = 'revoked_tokens'

    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    jti = db.Column(db.String(255), unique=True, nullable=False, index=True)
    token_type = db.Column(db.String(20), nullable=False)  # access, refresh
    # No foreign key to users: presumably so revocations survive user deletion — confirm.
    user_id = db.Column(UUID(as_uuid=True), nullable=True, index=True)
    revoked_at = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
    expires_at = db.Column(db.DateTime, nullable=False)  # When token would have expired
    reason = db.Column(db.String(255), nullable=True)

    __table_args__ = (
        Index('idx_revoked_tokens_expires', 'expires_at'),
    )

    @classmethod
    def is_token_revoked(cls, jti: str) -> bool:
        """Return True if a revocation record exists for *jti*."""
        return cls.query.filter_by(jti=jti).first() is not None

    @classmethod
    def cleanup_expired(cls) -> int:
        """Delete revocations whose token has expired anyway; commit and return the count."""
        deleted = cls.query.filter(cls.expires_at < datetime.utcnow()).delete()
        db.session.commit()
        return deleted