""" Unit tests for invitation HMAC token utilities and invitation API endpoints. Tests: - test_token_roundtrip: generate then validate returns same invitation_id - test_token_tamper_rejected: modified signature raises ValueError - test_token_expired_rejected: artificially old token raises ValueError - test_invite_create: POST /api/portal/invitations creates invitation, returns 201 + token - test_invite_accept_creates_user: accepting invite creates PortalUser + UserTenantRole - test_invite_accept_rejects_expired: expired invitation (status != pending) raises error - test_invite_resend_updates_token: resend generates new token_hash and extends expires_at """ from __future__ import annotations import uuid from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from shared.invite_token import generate_invite_token, token_to_hash, validate_invite_token # --------------------------------------------------------------------------- # HMAC token tests (unit, no DB) # --------------------------------------------------------------------------- def test_token_roundtrip() -> None: """generate_invite_token then validate_invite_token returns the same invitation_id.""" inv_id = str(uuid.uuid4()) token = generate_invite_token(inv_id) result = validate_invite_token(token) assert result == inv_id def test_token_tamper_rejected() -> None: """A token with a modified signature raises ValueError.""" inv_id = str(uuid.uuid4()) token = generate_invite_token(inv_id) # Truncate and append garbage to corrupt the signature portion corrupted = token[:-8] + "XXXXXXXX" with pytest.raises(ValueError): validate_invite_token(corrupted) def test_token_expired_rejected() -> None: """A token older than 48h raises ValueError on validation.""" inv_id = str(uuid.uuid4()) # Patch time.time to simulate a token created 49 hours ago past_time = 1000000 # some fixed old epoch value current_time = past_time + (49 * 3600) # 49h later with patch("shared.invite_token.time.time", return_value=past_time): token = generate_invite_token(inv_id) with patch("shared.invite_token.time.time", return_value=current_time): with pytest.raises(ValueError, match="expired"): validate_invite_token(token) def test_token_hash_is_hex_sha256() -> None: """token_to_hash returns a 64-char hex string (SHA-256).""" token = generate_invite_token(str(uuid.uuid4())) h = token_to_hash(token) assert len(h) == 64 # Valid hex int(h, 16) def test_two_tokens_different() -> None: """Two tokens for the same invitation_id at different times have different hashes.""" inv_id = str(uuid.uuid4()) with patch("shared.invite_token.time.time", return_value=1000000): token1 = generate_invite_token(inv_id) with patch("shared.invite_token.time.time", return_value=1000001): token2 = generate_invite_token(inv_id) assert token1 != token2 assert token_to_hash(token1) != token_to_hash(token2) # --------------------------------------------------------------------------- # Invitation API tests (mock DB session) # --------------------------------------------------------------------------- def _make_mock_session() -> AsyncMock: """Create a mock AsyncSession.""" session = AsyncMock() session.execute.return_value = MagicMock( scalar_one_or_none=MagicMock(return_value=None), scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))), ) return session def _make_mock_tenant(name: str = "Acme Corp") -> MagicMock: t = MagicMock() t.id = uuid.uuid4() t.name = name return t def _make_mock_invitation( tenant_id: uuid.UUID | None = None, status: str = "pending", expires_in_hours: float = 24, ) -> MagicMock: inv = MagicMock() inv.id = uuid.uuid4() inv.email = "new@example.com" inv.name = "New User" inv.tenant_id = tenant_id or uuid.uuid4() inv.role = "customer_admin" inv.invited_by = uuid.uuid4() inv.token_hash = "abc123" inv.status = status inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=expires_in_hours) inv.created_at = datetime.now(tz=timezone.utc) return inv @pytest.mark.asyncio async def test_invite_create_returns_201_with_token() -> None: """ POST /api/portal/invitations creates an invitation and returns 201 with raw token. """ from shared.api.invitations import InvitationCreate, create_invitation tenant_id = uuid.uuid4() caller_id = uuid.uuid4() from shared.api.rbac import PortalCaller caller = PortalCaller(user_id=caller_id, role="customer_admin", tenant_id=tenant_id) session = _make_mock_session() mock_tenant = _make_mock_tenant() # Mock: require_tenant_admin (membership check) returns membership mock_membership = MagicMock() # Setup execute to return tenant on first call, membership on second call call_count = 0 def execute_side_effect(stmt): nonlocal call_count call_count += 1 result = MagicMock() if call_count == 1: # require_tenant_admin — membership check result.scalar_one_or_none = MagicMock(return_value=mock_membership) elif call_count == 2: # _get_tenant_or_404 result.scalar_one_or_none = MagicMock(return_value=mock_tenant) else: result.scalar_one_or_none = MagicMock(return_value=None) return result session.execute = AsyncMock(side_effect=execute_side_effect) session.flush = AsyncMock() # Mock refresh to set created_at on the invitation object async def mock_refresh(obj): if not hasattr(obj, 'created_at') or obj.created_at is None: obj.created_at = datetime.now(tz=timezone.utc) session.refresh = AsyncMock(side_effect=mock_refresh) body = InvitationCreate( email="new@example.com", name="New User", role="customer_admin", tenant_id=tenant_id, ) with patch("shared.api.invitations._dispatch_invite_email") as mock_dispatch: result = await create_invitation(body=body, caller=caller, session=session) assert result.email == "new@example.com" assert result.name == "New User" assert result.token is not None assert len(result.token) > 0 mock_dispatch.assert_called_once() @pytest.mark.asyncio async def test_invite_accept_creates_user() -> None: """ POST /api/portal/invitations/accept with valid token creates PortalUser + UserTenantRole. """ from shared.api.invitations import InvitationAccept, accept_invitation inv_id = uuid.uuid4() tenant_id = uuid.uuid4() # Create a valid token token = generate_invite_token(str(inv_id)) mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending") mock_inv.id = inv_id mock_inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24) session = _make_mock_session() call_count = 0 added_objects = [] def execute_side_effect(stmt): nonlocal call_count call_count += 1 result = MagicMock() if call_count == 1: # Load invitation by ID result.scalar_one_or_none = MagicMock(return_value=mock_inv) elif call_count == 2: # Check existing user by email result.scalar_one_or_none = MagicMock(return_value=None) else: result.scalar_one_or_none = MagicMock(return_value=None) return result session.execute = AsyncMock(side_effect=execute_side_effect) def capture_add(obj): added_objects.append(obj) session.add = MagicMock(side_effect=capture_add) async def mock_refresh(obj): # Ensure the user has an id and role pass session.refresh = AsyncMock(side_effect=mock_refresh) body = InvitationAccept(token=token, password="securepassword123") result = await accept_invitation(body=body, session=session) assert result.email == mock_inv.email assert result.name == mock_inv.name assert result.role == mock_inv.role # Verify user and membership were added from shared.models.auth import PortalUser, UserTenantRole portal_users = [o for o in added_objects if isinstance(o, PortalUser)] memberships = [o for o in added_objects if isinstance(o, UserTenantRole)] assert len(portal_users) == 1, f"Expected 1 PortalUser, got {len(portal_users)}" assert len(memberships) == 1, f"Expected 1 UserTenantRole, got {len(memberships)}" # Verify invitation was marked accepted assert mock_inv.status == "accepted" @pytest.mark.asyncio async def test_invite_accept_rejects_non_pending() -> None: """ POST /api/portal/invitations/accept with already-accepted invitation returns 409. """ from fastapi import HTTPException from shared.api.invitations import InvitationAccept, accept_invitation inv_id = uuid.uuid4() token = generate_invite_token(str(inv_id)) mock_inv = _make_mock_invitation(status="accepted") mock_inv.id = inv_id session = _make_mock_session() session.execute = AsyncMock( return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv)) ) body = InvitationAccept(token=token, password="password123") with pytest.raises(HTTPException) as exc_info: await accept_invitation(body=body, session=session) assert exc_info.value.status_code == 409 @pytest.mark.asyncio async def test_invite_accept_rejects_expired_invitation() -> None: """ POST /api/portal/invitations/accept with past expires_at returns 400. """ from fastapi import HTTPException from shared.api.invitations import InvitationAccept, accept_invitation inv_id = uuid.uuid4() token = generate_invite_token(str(inv_id)) mock_inv = _make_mock_invitation(status="pending") mock_inv.id = inv_id mock_inv.expires_at = datetime.now(tz=timezone.utc) - timedelta(hours=1) # Expired session = _make_mock_session() session.execute = AsyncMock( return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv)) ) body = InvitationAccept(token=token, password="password123") with pytest.raises(HTTPException) as exc_info: await accept_invitation(body=body, session=session) assert exc_info.value.status_code == 400 assert "expired" in exc_info.value.detail.lower() @pytest.mark.asyncio async def test_invite_resend_updates_token() -> None: """ POST /api/portal/invitations/{id}/resend generates new token_hash and extends expires_at. """ from shared.api.invitations import resend_invitation from shared.api.rbac import PortalCaller tenant_id = uuid.uuid4() caller = PortalCaller(user_id=uuid.uuid4(), role="platform_admin", tenant_id=tenant_id) mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending") old_token_hash = mock_inv.token_hash old_expires = mock_inv.expires_at mock_tenant = _make_mock_tenant() session = _make_mock_session() call_count = 0 def execute_side_effect(stmt): nonlocal call_count call_count += 1 result = MagicMock() if call_count == 1: # Load invitation result.scalar_one_or_none = MagicMock(return_value=mock_inv) elif call_count == 2: # Load tenant result.scalar_one_or_none = MagicMock(return_value=mock_tenant) else: result.scalar_one_or_none = MagicMock(return_value=None) return result session.execute = AsyncMock(side_effect=execute_side_effect) async def mock_refresh(obj): pass session.refresh = AsyncMock(side_effect=mock_refresh) with patch("shared.api.invitations._dispatch_invite_email"): result = await resend_invitation( invitation_id=mock_inv.id, caller=caller, session=session, ) # token_hash should have been updated assert mock_inv.token_hash != old_token_hash # expires_at should have been extended assert mock_inv.expires_at > old_expires # Raw token returned assert result.token is not None