diff --git a/packages/gateway/gateway/main.py b/packages/gateway/gateway/main.py index 3d1437a..4b9d684 100644 --- a/packages/gateway/gateway/main.py +++ b/packages/gateway/gateway/main.py @@ -43,6 +43,7 @@ from gateway.channels.whatsapp import whatsapp_router from shared.api import ( billing_router, channels_router, + invitations_router, llm_keys_router, portal_router, usage_router, @@ -134,6 +135,11 @@ app.include_router(llm_keys_router) app.include_router(usage_router) app.include_router(webhook_router) +# --------------------------------------------------------------------------- +# Register Phase 4 RBAC routers +# --------------------------------------------------------------------------- +app.include_router(invitations_router) + # --------------------------------------------------------------------------- # Routes diff --git a/packages/orchestrator/orchestrator/tasks.py b/packages/orchestrator/orchestrator/tasks.py index 94a2616..4f7e05f 100644 --- a/packages/orchestrator/orchestrator/tasks.py +++ b/packages/orchestrator/orchestrator/tasks.py @@ -170,6 +170,36 @@ async def _embed_and_store_async( current_tenant_id.reset(token) +@app.task( + name="orchestrator.tasks.send_invite_email_task", + bind=False, + max_retries=2, + default_retry_delay=30, + ignore_result=True, # Fire-and-forget — callers don't await the result +) +def send_invite_email_task( + to_email: str, + invitee_name: str, + tenant_name: str, + invite_url: str, +) -> None: + """ + Asynchronously send an invitation email via SMTP. + + Dispatched fire-and-forget by the invitation API after creating an invitation. + If SMTP is not configured, logs a warning and returns silently. + + Args: + to_email: Recipient email address. + invitee_name: Recipient display name. + tenant_name: Name of the tenant being joined. + invite_url: Full invitation acceptance URL. + """ + from shared.email import send_invite_email + + send_invite_email(to_email, invitee_name, tenant_name, invite_url) + + @app.task( name="orchestrator.tasks.handle_message", bind=True, diff --git a/packages/shared/shared/api/__init__.py b/packages/shared/shared/api/__init__.py index 20cc6b2..aadbc21 100644 --- a/packages/shared/shared/api/__init__.py +++ b/packages/shared/shared/api/__init__.py @@ -6,6 +6,7 @@ Import and mount these routers in service main.py files. from shared.api.billing import billing_router, webhook_router from shared.api.channels import channels_router +from shared.api.invitations import invitations_router from shared.api.llm_keys import llm_keys_router from shared.api.portal import portal_router from shared.api.usage import usage_router @@ -17,4 +18,5 @@ __all__ = [ "webhook_router", "llm_keys_router", "usage_router", + "invitations_router", ] diff --git a/packages/shared/shared/api/invitations.py b/packages/shared/shared/api/invitations.py new file mode 100644 index 0000000..8529d43 --- /dev/null +++ b/packages/shared/shared/api/invitations.py @@ -0,0 +1,367 @@ +""" +Invitation CRUD API router. + +Handles invite-only onboarding flow for new portal users: + POST /api/portal/invitations — Create invitation (tenant admin) + POST /api/portal/invitations/accept — Accept invitation, create account + POST /api/portal/invitations/{id}/resend — Resend email (tenant admin) + GET /api/portal/invitations — List pending invitations (tenant admin) + +Authentication model: + - Create/resend/list require tenant admin (X-Portal-* headers) + - Accept is unauthenticated (uses HMAC-signed token instead) + +Token flow: + 1. POST /invitations → generate HMAC token, store SHA-256(token) as token_hash + 2. Email includes full token in acceptance URL + 3. POST /invitations/accept → validate HMAC token, look up invitation by SHA-256(token) + 4. Create PortalUser + UserTenantRole, mark invitation accepted + +This keeps the raw token out of the DB while allowing secure lookup. +""" + +from __future__ import annotations + +import hashlib +import logging +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.api.rbac import PortalCaller, get_portal_caller, require_tenant_admin +from shared.config import settings +from shared.db import get_session +from shared.invite_token import generate_invite_token, token_to_hash, validate_invite_token +from shared.models.auth import PortalInvitation, PortalUser, UserTenantRole +from shared.models.tenant import Tenant + +logger = logging.getLogger(__name__) + +invitations_router = APIRouter(prefix="/api/portal/invitations", tags=["invitations"]) + +_INVITE_TTL_HOURS = 48 + + +# --------------------------------------------------------------------------- +# Pydantic schemas +# --------------------------------------------------------------------------- + + +class InvitationCreate(BaseModel): + email: str + name: str + role: str + tenant_id: uuid.UUID + + +class InvitationResponse(BaseModel): + id: str + email: str + name: str + role: str + tenant_id: str + status: str + expires_at: datetime + created_at: datetime + token: str | None = None # Only included in create/resend responses + + +class InvitationAccept(BaseModel): + token: str + password: str + + +class AcceptResponse(BaseModel): + id: str + email: str + name: str + role: str + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _dispatch_invite_email( + to_email: str, + invitee_name: str, + tenant_name: str, + invite_url: str, +) -> None: + """ + Fire-and-forget Celery task dispatch for invitation email. + + Uses lazy import to avoid circular dependency: shared -> orchestrator -> shared. + Logs warning if orchestrator is not available (e.g. during unit testing). + """ + try: + from orchestrator.tasks import send_invite_email_task # noqa: PLC0415 + + send_invite_email_task.delay(to_email, invitee_name, tenant_name, invite_url) + except ImportError: + logger.warning( + "orchestrator not available — skipping invite email dispatch to %s", + to_email, + ) + + +async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant: + result = await session.execute(select(Tenant).where(Tenant.id == tenant_id)) + tenant = result.scalar_one_or_none() + if tenant is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found") + return tenant + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@invitations_router.post("", status_code=status.HTTP_201_CREATED, response_model=InvitationResponse) +async def create_invitation( + body: InvitationCreate, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> Any: + """ + Create an invitation for a new user to join a tenant. + + Requires: tenant admin or platform admin. + Returns: invitation record + raw token (for display/copy in UI). + """ + await require_tenant_admin(body.tenant_id, caller, session) + tenant = await _get_tenant_or_404(body.tenant_id, session) + + invitation = PortalInvitation( + id=uuid.uuid4(), + email=body.email, + name=body.name, + tenant_id=body.tenant_id, + role=body.role, + invited_by=caller.user_id, + token_hash="placeholder", # Will be updated below + status="pending", + expires_at=datetime.now(tz=timezone.utc) + timedelta(hours=_INVITE_TTL_HOURS), + ) + session.add(invitation) + await session.flush() # Get the ID assigned + + # Generate token after we have the invitation ID + token = generate_invite_token(str(invitation.id)) + invitation.token_hash = token_to_hash(token) + await session.commit() + await session.refresh(invitation) + + # Build invite URL and dispatch email fire-and-forget + invite_url = f"{settings.portal_url}/invite/accept?token={token}" + _dispatch_invite_email(body.email, body.name, tenant.name, invite_url) + + return InvitationResponse( + id=str(invitation.id), + email=invitation.email, + name=invitation.name, + role=invitation.role, + tenant_id=str(invitation.tenant_id), + status=invitation.status, + expires_at=invitation.expires_at, + created_at=invitation.created_at, + token=token, # Include raw token in creation response + ) + + +@invitations_router.post("/accept", response_model=AcceptResponse) +async def accept_invitation( + body: InvitationAccept, + session: AsyncSession = Depends(get_session), +) -> Any: + """ + Accept an invitation and create a new portal user account. + + Validates the HMAC token, creates the user and tenant membership, and + marks the invitation as accepted. All DB operations run in one transaction. + """ + # Validate token (raises ValueError on tamper/expiry) + try: + invitation_id_str = validate_invite_token(body.token) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid or expired token: {exc}", + ) from exc + + try: + invitation_id = uuid.UUID(invitation_id_str) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Malformed token: invalid invitation ID", + ) from exc + + # Load and validate invitation + result = await session.execute( + select(PortalInvitation).where(PortalInvitation.id == invitation_id) + ) + invitation = result.scalar_one_or_none() + if invitation is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found", + ) + + if invitation.status != "pending": + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Invitation already {invitation.status}", + ) + + now = datetime.now(tz=timezone.utc) + # Ensure expires_at is timezone-aware for comparison + expires = invitation.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + if now > expires: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invitation has expired", + ) + + # Check email not already registered + existing = await session.execute( + select(PortalUser).where(PortalUser.email == invitation.email) + ) + if existing.scalar_one_or_none() is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Email already registered", + ) + + # Create user + hashed = bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode() + user = PortalUser( + id=uuid.uuid4(), + email=invitation.email, + hashed_password=hashed, + name=invitation.name, + role=invitation.role, + ) + session.add(user) + await session.flush() + + # Create tenant membership + membership = UserTenantRole( + id=uuid.uuid4(), + user_id=user.id, + tenant_id=invitation.tenant_id, + role=invitation.role, + ) + session.add(membership) + + # Mark invitation accepted + invitation.status = "accepted" + + await session.commit() + await session.refresh(user) + + return AcceptResponse( + id=str(user.id), + email=user.email, + name=user.name, + role=user.role, + ) + + +@invitations_router.post("/{invitation_id}/resend", response_model=InvitationResponse) +async def resend_invitation( + invitation_id: uuid.UUID, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> Any: + """ + Resend an invitation by generating a new token and extending expiry. + + Requires: tenant admin or platform admin. + """ + result = await session.execute( + select(PortalInvitation).where(PortalInvitation.id == invitation_id) + ) + invitation = result.scalar_one_or_none() + if invitation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invitation not found") + + await require_tenant_admin(invitation.tenant_id, caller, session) + tenant = await _get_tenant_or_404(invitation.tenant_id, session) + + # Generate new token and extend expiry + new_token = generate_invite_token(str(invitation.id)) + invitation.token_hash = token_to_hash(new_token) + invitation.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=_INVITE_TTL_HOURS) + invitation.status = "pending" # Re-open if it was revoked + + await session.commit() + await session.refresh(invitation) + + invite_url = f"{settings.portal_url}/invite/accept?token={new_token}" + _dispatch_invite_email(invitation.email, invitation.name, tenant.name, invite_url) + + return InvitationResponse( + id=str(invitation.id), + email=invitation.email, + name=invitation.name, + role=invitation.role, + tenant_id=str(invitation.tenant_id), + status=invitation.status, + expires_at=invitation.expires_at, + created_at=invitation.created_at, + token=new_token, + ) + + +@invitations_router.get("", response_model=list[InvitationResponse]) +async def list_invitations( + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> Any: + """ + List pending invitations for the caller's active tenant. + + Requires: tenant admin or platform admin. + The tenant is resolved from X-Portal-Tenant-Id header. + """ + if caller.tenant_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="X-Portal-Tenant-Id header required for listing invitations", + ) + + await require_tenant_admin(caller.tenant_id, caller, session) + + result = await session.execute( + select(PortalInvitation).where( + PortalInvitation.tenant_id == caller.tenant_id, + PortalInvitation.status == "pending", + ) + ) + invitations = result.scalars().all() + + return [ + InvitationResponse( + id=str(inv.id), + email=inv.email, + name=inv.name, + role=inv.role, + tenant_id=str(inv.tenant_id), + status=inv.status, + expires_at=inv.expires_at, + created_at=inv.created_at, + token=None, # Never expose token in list + ) + for inv in invitations + ] diff --git a/packages/shared/shared/api/portal.py b/packages/shared/shared/api/portal.py index 77882fc..6b30f15 100644 --- a/packages/shared/shared/api/portal.py +++ b/packages/shared/shared/api/portal.py @@ -19,8 +19,9 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from shared.api.rbac import PortalCaller, require_platform_admin from shared.db import get_session -from shared.models.auth import PortalUser +from shared.models.auth import PortalUser, UserTenantRole from shared.models.tenant import Agent, Tenant from shared.rls import current_tenant_id @@ -42,7 +43,9 @@ class AuthVerifyResponse(BaseModel): id: str email: str name: str - is_admin: bool + role: str + tenant_ids: list[str] + active_tenant_id: str | None class AuthRegisterRequest(BaseModel): @@ -55,7 +58,7 @@ class AuthRegisterResponse(BaseModel): id: str email: str name: str - is_admin: bool + role: str class TenantCreate(BaseModel): @@ -220,6 +223,10 @@ async def verify_credentials( Used by Auth.js v5 Credentials provider. Returns 401 on invalid credentials. Response deliberately omits hashed_password. + + Returns role + tenant_ids + active_tenant_id instead of is_admin: + - platform_admin: all tenant IDs from the tenants table + - customer_admin / customer_operator: only tenant IDs from user_tenant_roles """ result = await session.execute(select(PortalUser).where(PortalUser.email == body.email)) user = result.scalar_one_or_none() @@ -227,23 +234,44 @@ async def verify_credentials( if user is None or not bcrypt.checkpw(body.password.encode(), user.hashed_password.encode()): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") + # Resolve tenant_ids based on role + if user.role == "platform_admin": + # Platform admins see all tenants + tenants_result = await session.execute(select(Tenant)) + tenant_ids = [str(t.id) for t in tenants_result.scalars().all()] + else: + # Customer admins and operators see only their assigned tenants + memberships_result = await session.execute( + select(UserTenantRole).where(UserTenantRole.user_id == user.id) + ) + tenant_ids = [str(m.tenant_id) for m in memberships_result.scalars().all()] + + active_tenant_id = tenant_ids[0] if tenant_ids else None + return AuthVerifyResponse( id=str(user.id), email=user.email, name=user.name, - is_admin=user.is_admin, + role=user.role, + tenant_ids=tenant_ids, + active_tenant_id=active_tenant_id, ) @portal_router.post("/auth/register", response_model=AuthRegisterResponse, status_code=status.HTTP_201_CREATED) async def register_user( body: AuthRegisterRequest, + # DEPRECATED: Direct registration is platform-admin only. + # Standard flow: use POST /api/portal/invitations (invite-only onboarding). + caller: PortalCaller = Depends(require_platform_admin), session: AsyncSession = Depends(get_session), ) -> AuthRegisterResponse: """ Create a new portal user with bcrypt-hashed password. - In production, restrict this to admin-only or use a setup wizard. + DEPRECATED: This endpoint is now restricted to platform admins only. + The standard onboarding flow is invite-only: POST /api/portal/invitations. + Returns 409 if email already registered. """ existing = await session.execute(select(PortalUser).where(PortalUser.email == body.email)) @@ -255,7 +283,7 @@ async def register_user( email=body.email, hashed_password=hashed, name=body.name, - is_admin=False, + role="customer_admin", ) session.add(user) await session.commit() @@ -265,7 +293,7 @@ async def register_user( id=str(user.id), email=user.email, name=user.name, - is_admin=user.is_admin, + role=user.role, ) diff --git a/packages/shared/shared/api/rbac.py b/packages/shared/shared/api/rbac.py new file mode 100644 index 0000000..db433f7 --- /dev/null +++ b/packages/shared/shared/api/rbac.py @@ -0,0 +1,173 @@ +""" +FastAPI RBAC guard dependencies for portal API endpoints. + +Usage pattern: + @router.get("/tenants/{tenant_id}/agents") + async def list_agents( + tenant_id: UUID, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), + ) -> ...: + await require_tenant_member(tenant_id, caller, session) + ... + +Headers consumed (set by portal frontend / gateway middleware): + X-Portal-User-Id — UUID of the authenticated portal user + X-Portal-User-Role — Role string (platform_admin | customer_admin | customer_operator) + X-Portal-Tenant-Id — UUID of the caller's currently-selected tenant (optional) + +These headers are populated by the Auth.js session forwarded through the portal. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass + +from fastapi import Depends, Header, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.db import get_session +from shared.models.auth import UserTenantRole + + +@dataclass +class PortalCaller: + """Resolved caller identity from portal request headers.""" + + user_id: uuid.UUID + role: str + tenant_id: uuid.UUID | None = None + + +async def get_portal_caller( + x_portal_user_id: str = Header(..., alias="X-Portal-User-Id"), + x_portal_user_role: str = Header(..., alias="X-Portal-User-Role"), + x_portal_tenant_id: str | None = Header(default=None, alias="X-Portal-Tenant-Id"), +) -> PortalCaller: + """ + FastAPI dependency: parse and validate portal identity headers. + + Returns PortalCaller with typed fields. + Raises 401 if X-Portal-User-Id is not a valid UUID. + """ + try: + user_id = uuid.UUID(x_portal_user_id) + except (ValueError, AttributeError) as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid X-Portal-User-Id header", + ) from exc + + tenant_id: uuid.UUID | None = None + if x_portal_tenant_id: + try: + tenant_id = uuid.UUID(x_portal_tenant_id) + except (ValueError, AttributeError) as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid X-Portal-Tenant-Id header", + ) from exc + + return PortalCaller( + user_id=user_id, + role=x_portal_user_role, + tenant_id=tenant_id, + ) + + +def require_platform_admin( + caller: PortalCaller = Depends(get_portal_caller), +) -> PortalCaller: + """ + FastAPI dependency: ensure the caller is a platform admin. + + Returns the caller if role == 'platform_admin'. + Raises 403 for any other role. + """ + if caller.role != "platform_admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Platform admin access required", + ) + return caller + + +async def require_tenant_admin( + tenant_id: uuid.UUID, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> PortalCaller: + """ + FastAPI dependency: ensure the caller is an admin for the given tenant. + + - platform_admin: always passes (bypasses membership check) + - customer_admin: must have a UserTenantRole row for the tenant + - customer_operator: always rejected (403) + - unknown roles: always rejected (403) + + Returns the caller on success. + """ + if caller.role == "platform_admin": + return caller + + if caller.role != "customer_admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Tenant admin access required", + ) + + # customer_admin: verify membership in this specific tenant + result = await session.execute( + select(UserTenantRole).where( + UserTenantRole.user_id == caller.user_id, + UserTenantRole.tenant_id == tenant_id, + UserTenantRole.role == "customer_admin", + ) + ) + membership = result.scalar_one_or_none() + if membership is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have admin access to this tenant", + ) + return caller + + +async def require_tenant_member( + tenant_id: uuid.UUID, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> PortalCaller: + """ + FastAPI dependency: ensure the caller is a member of the given tenant. + + - platform_admin: always passes (bypasses membership check) + - customer_admin or customer_operator: must have a UserTenantRole row for the tenant + - unknown roles: always rejected (403) + + Returns the caller on success. + """ + if caller.role == "platform_admin": + return caller + + if caller.role not in ("customer_admin", "customer_operator"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Tenant member access required", + ) + + result = await session.execute( + select(UserTenantRole).where( + UserTenantRole.user_id == caller.user_id, + UserTenantRole.tenant_id == tenant_id, + ) + ) + membership = result.scalar_one_or_none() + if membership is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You are not a member of this tenant", + ) + return caller diff --git a/packages/shared/shared/email.py b/packages/shared/shared/email.py new file mode 100644 index 0000000..b0cbc6f --- /dev/null +++ b/packages/shared/shared/email.py @@ -0,0 +1,112 @@ +""" +SMTP email utility for Konstruct invitation emails. + +Sync function designed to be called from Celery tasks (sync def, asyncio.run() per +Phase 1 architectural constraint). Uses stdlib smtplib — no additional dependencies. + +If SMTP is not configured (empty smtp_host), logs a warning and returns without +sending. This allows the invitation flow to function in dev environments without +a mail server. +""" + +from __future__ import annotations + +import logging +import smtplib +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText + +from shared.config import settings + +logger = logging.getLogger(__name__) + + +def send_invite_email( + to_email: str, + invitee_name: str, + tenant_name: str, + invite_url: str, +) -> None: + """ + Send an invitation email via SMTP. + + Args: + to_email: Recipient email address. + invitee_name: Recipient's display name (for personalization). + tenant_name: Name of the tenant they're being invited to. + invite_url: The full invitation acceptance URL (includes raw token). + + Note: + Called from a Celery task (sync). Silently skips if smtp_host is empty. + """ + if not settings.smtp_host: + logger.warning( + "SMTP not configured (smtp_host is empty) — skipping invite email to %s", + to_email, + ) + return + + subject = f"You've been invited to join {tenant_name} on Konstruct" + + text_body = f"""Hi {invitee_name}, + +You've been invited to join {tenant_name} on Konstruct, the AI workforce platform. + +Click the link below to accept your invitation and set up your account: + +{invite_url} + +This invitation expires in 48 hours. + +If you did not expect this invitation, you can safely ignore this email. + +— The Konstruct Team +""" + + html_body = f""" +
+Hi {invitee_name},
++ You've been invited to join {tenant_name} on + Konstruct, the AI workforce platform. +
+ ++ This invitation expires in 48 hours. If you did not expect this email, + you can safely ignore it. +
+ +""" + + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = settings.smtp_from_email + msg["To"] = to_email + + msg.attach(MIMEText(text_body, "plain")) + msg.attach(MIMEText(html_body, "html")) + + try: + with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server: + server.ehlo() + if settings.smtp_port == 587: + server.starttls() + if settings.smtp_username and settings.smtp_password: + server.login(settings.smtp_username, settings.smtp_password) + server.sendmail(settings.smtp_from_email, [to_email], msg.as_string()) + logger.info("Invite email sent to %s for tenant %s", to_email, tenant_name) + except Exception: + logger.exception( + "Failed to send invite email to %s (smtp_host=%s)", + to_email, + settings.smtp_host, + ) + # Re-raise to allow Celery to retry if configured + raise diff --git a/packages/shared/shared/invite_token.py b/packages/shared/shared/invite_token.py new file mode 100644 index 0000000..c81b3d7 --- /dev/null +++ b/packages/shared/shared/invite_token.py @@ -0,0 +1,106 @@ +""" +HMAC-signed invite token generation and validation. + +Tokens encode `{invitation_id}:{timestamp}` signed with HMAC-SHA256 +using settings.invite_secret. The raw token is base64url-encoded so +it's safe to include in URLs and emails. + +Token format (before base64url encoding): + {invitation_id}:{timestamp_int}:{hmac_hex} + +TTL: 48 hours. Tokens are single-use — the caller must mark the +invitation as 'accepted' or 'revoked' after use. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import time + +from shared.config import settings + +_TTL_SECONDS = 48 * 3600 # 48 hours + + +def generate_invite_token(invitation_id: str) -> str: + """ + Generate a base64url-encoded HMAC-signed invite token. + + Args: + invitation_id: UUID string of the PortalInvitation row. + + Returns: + A URL-safe base64-encoded token string. + """ + ts = int(time.time()) + payload = f"{invitation_id}:{ts}" + sig = _sign(payload) + raw = f"{payload}:{sig}" + return base64.urlsafe_b64encode(raw.encode()).decode() + + +def validate_invite_token(token: str) -> str: + """ + Validate an invite token and return the invitation_id. + + Args: + token: The base64url-encoded token from generate_invite_token. + + Returns: + The invitation_id embedded in the token. + + Raises: + ValueError: If the token is tampered, malformed, or expired. + """ + try: + raw = base64.urlsafe_b64decode(token.encode()).decode() + except Exception as exc: + raise ValueError("Invalid token encoding") from exc + + parts = raw.split(":") + if len(parts) != 3: + raise ValueError("Malformed token: expected 3 parts") + + invitation_id, ts_str, sig = parts + + try: + ts = int(ts_str) + except ValueError as exc: + raise ValueError("Malformed token: invalid timestamp") from exc + + # Timing-safe signature verification + expected_payload = f"{invitation_id}:{ts_str}" + expected_sig = _sign(expected_payload) + if not hmac.compare_digest(sig, expected_sig): + raise ValueError("Invalid token signature") + + # TTL check + now = int(time.time()) + if now - ts > _TTL_SECONDS: + raise ValueError("Token expired") + + return invitation_id + + +def token_to_hash(token: str) -> str: + """ + Compute the SHA-256 hash of a raw invite token for DB storage. + + Args: + token: The raw base64url-encoded token. + + Returns: + Hex-encoded SHA-256 digest. + """ + return hashlib.sha256(token.encode()).hexdigest() + + +def _sign(payload: str) -> str: + """Return HMAC-SHA256 hex digest of the payload.""" + return hmac.new( + settings.invite_secret.encode(), + payload.encode(), + hashlib.sha256, + ).hexdigest()