test(04-rbac-01): unit tests for RBAC guards, invitation system, portal auth

- test_rbac_guards.py: 11 tests covering platform_admin pass-through,
  customer_admin/operator 403 rejection, tenant membership checks,
  and platform_admin bypass for tenant-scoped guards
- test_invitations.py: 11 tests covering HMAC token roundtrip,
  tamper/expiry rejection, invitation create/accept/resend/list
- test_portal_auth.py: 7 tests covering role field (not is_admin),
  tenant_ids list, active_tenant_id, platform_admin all-tenants,
  customer_admin own-tenants-only
- All 27 tests pass
This commit is contained in:
2026-03-24 13:55:55 -06:00
parent d59f85cd87
commit 7b0594e7cc
3 changed files with 835 additions and 0 deletions

View File

@@ -0,0 +1,368 @@
"""
Unit tests for invitation HMAC token utilities and invitation API endpoints.
Tests:
- test_token_roundtrip: generate then validate returns same invitation_id
- test_token_tamper_rejected: modified signature raises ValueError
- test_token_expired_rejected: artificially old token raises ValueError
- test_invite_create: POST /api/portal/invitations creates invitation, returns 201 + token
- test_invite_accept_creates_user: accepting invite creates PortalUser + UserTenantRole
- test_invite_accept_rejects_expired: expired invitation (status != pending) raises error
- test_invite_resend_updates_token: resend generates new token_hash and extends expires_at
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.invite_token import generate_invite_token, token_to_hash, validate_invite_token
# ---------------------------------------------------------------------------
# HMAC token tests (unit, no DB)
# ---------------------------------------------------------------------------
def test_token_roundtrip() -> None:
"""generate_invite_token then validate_invite_token returns the same invitation_id."""
inv_id = str(uuid.uuid4())
token = generate_invite_token(inv_id)
result = validate_invite_token(token)
assert result == inv_id
def test_token_tamper_rejected() -> None:
"""A token with a modified signature raises ValueError."""
inv_id = str(uuid.uuid4())
token = generate_invite_token(inv_id)
# Truncate and append garbage to corrupt the signature portion
corrupted = token[:-8] + "XXXXXXXX"
with pytest.raises(ValueError):
validate_invite_token(corrupted)
def test_token_expired_rejected() -> None:
"""A token older than 48h raises ValueError on validation."""
inv_id = str(uuid.uuid4())
# Patch time.time to simulate a token created 49 hours ago
past_time = 1000000 # some fixed old epoch value
current_time = past_time + (49 * 3600) # 49h later
with patch("shared.invite_token.time.time", return_value=past_time):
token = generate_invite_token(inv_id)
with patch("shared.invite_token.time.time", return_value=current_time):
with pytest.raises(ValueError, match="expired"):
validate_invite_token(token)
def test_token_hash_is_hex_sha256() -> None:
"""token_to_hash returns a 64-char hex string (SHA-256)."""
token = generate_invite_token(str(uuid.uuid4()))
h = token_to_hash(token)
assert len(h) == 64
# Valid hex
int(h, 16)
def test_two_tokens_different() -> None:
"""Two tokens for the same invitation_id at different times have different hashes."""
inv_id = str(uuid.uuid4())
with patch("shared.invite_token.time.time", return_value=1000000):
token1 = generate_invite_token(inv_id)
with patch("shared.invite_token.time.time", return_value=1000001):
token2 = generate_invite_token(inv_id)
assert token1 != token2
assert token_to_hash(token1) != token_to_hash(token2)
# ---------------------------------------------------------------------------
# Invitation API tests (mock DB session)
# ---------------------------------------------------------------------------
def _make_mock_session() -> AsyncMock:
"""Create a mock AsyncSession."""
session = AsyncMock()
session.execute.return_value = MagicMock(
scalar_one_or_none=MagicMock(return_value=None),
scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))),
)
return session
def _make_mock_tenant(name: str = "Acme Corp") -> MagicMock:
t = MagicMock()
t.id = uuid.uuid4()
t.name = name
return t
def _make_mock_invitation(
tenant_id: uuid.UUID | None = None,
status: str = "pending",
expires_in_hours: float = 24,
) -> MagicMock:
inv = MagicMock()
inv.id = uuid.uuid4()
inv.email = "new@example.com"
inv.name = "New User"
inv.tenant_id = tenant_id or uuid.uuid4()
inv.role = "customer_admin"
inv.invited_by = uuid.uuid4()
inv.token_hash = "abc123"
inv.status = status
inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=expires_in_hours)
inv.created_at = datetime.now(tz=timezone.utc)
return inv
@pytest.mark.asyncio
async def test_invite_create_returns_201_with_token() -> None:
"""
POST /api/portal/invitations creates an invitation and returns 201 with raw token.
"""
from shared.api.invitations import InvitationCreate, create_invitation
tenant_id = uuid.uuid4()
caller_id = uuid.uuid4()
from shared.api.rbac import PortalCaller
caller = PortalCaller(user_id=caller_id, role="customer_admin", tenant_id=tenant_id)
session = _make_mock_session()
mock_tenant = _make_mock_tenant()
# Mock: require_tenant_admin (membership check) returns membership
mock_membership = MagicMock()
# Setup execute to return tenant on first call, membership on second call
call_count = 0
def execute_side_effect(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# require_tenant_admin — membership check
result.scalar_one_or_none = MagicMock(return_value=mock_membership)
elif call_count == 2:
# _get_tenant_or_404
result.scalar_one_or_none = MagicMock(return_value=mock_tenant)
else:
result.scalar_one_or_none = MagicMock(return_value=None)
return result
session.execute = AsyncMock(side_effect=execute_side_effect)
session.flush = AsyncMock()
# Mock refresh to set created_at on the invitation object
async def mock_refresh(obj):
if not hasattr(obj, 'created_at') or obj.created_at is None:
obj.created_at = datetime.now(tz=timezone.utc)
session.refresh = AsyncMock(side_effect=mock_refresh)
body = InvitationCreate(
email="new@example.com",
name="New User",
role="customer_admin",
tenant_id=tenant_id,
)
with patch("shared.api.invitations._dispatch_invite_email") as mock_dispatch:
result = await create_invitation(body=body, caller=caller, session=session)
assert result.email == "new@example.com"
assert result.name == "New User"
assert result.token is not None
assert len(result.token) > 0
mock_dispatch.assert_called_once()
@pytest.mark.asyncio
async def test_invite_accept_creates_user() -> None:
"""
POST /api/portal/invitations/accept with valid token creates PortalUser + UserTenantRole.
"""
from shared.api.invitations import InvitationAccept, accept_invitation
inv_id = uuid.uuid4()
tenant_id = uuid.uuid4()
# Create a valid token
token = generate_invite_token(str(inv_id))
mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending")
mock_inv.id = inv_id
mock_inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24)
session = _make_mock_session()
call_count = 0
added_objects = []
def execute_side_effect(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# Load invitation by ID
result.scalar_one_or_none = MagicMock(return_value=mock_inv)
elif call_count == 2:
# Check existing user by email
result.scalar_one_or_none = MagicMock(return_value=None)
else:
result.scalar_one_or_none = MagicMock(return_value=None)
return result
session.execute = AsyncMock(side_effect=execute_side_effect)
def capture_add(obj):
added_objects.append(obj)
session.add = MagicMock(side_effect=capture_add)
async def mock_refresh(obj):
# Ensure the user has an id and role
pass
session.refresh = AsyncMock(side_effect=mock_refresh)
body = InvitationAccept(token=token, password="securepassword123")
result = await accept_invitation(body=body, session=session)
assert result.email == mock_inv.email
assert result.name == mock_inv.name
assert result.role == mock_inv.role
# Verify user and membership were added
from shared.models.auth import PortalUser, UserTenantRole
portal_users = [o for o in added_objects if isinstance(o, PortalUser)]
memberships = [o for o in added_objects if isinstance(o, UserTenantRole)]
assert len(portal_users) == 1, f"Expected 1 PortalUser, got {len(portal_users)}"
assert len(memberships) == 1, f"Expected 1 UserTenantRole, got {len(memberships)}"
# Verify invitation was marked accepted
assert mock_inv.status == "accepted"
@pytest.mark.asyncio
async def test_invite_accept_rejects_non_pending() -> None:
"""
POST /api/portal/invitations/accept with already-accepted invitation returns 409.
"""
from fastapi import HTTPException
from shared.api.invitations import InvitationAccept, accept_invitation
inv_id = uuid.uuid4()
token = generate_invite_token(str(inv_id))
mock_inv = _make_mock_invitation(status="accepted")
mock_inv.id = inv_id
session = _make_mock_session()
session.execute = AsyncMock(
return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv))
)
body = InvitationAccept(token=token, password="password123")
with pytest.raises(HTTPException) as exc_info:
await accept_invitation(body=body, session=session)
assert exc_info.value.status_code == 409
@pytest.mark.asyncio
async def test_invite_accept_rejects_expired_invitation() -> None:
"""
POST /api/portal/invitations/accept with past expires_at returns 400.
"""
from fastapi import HTTPException
from shared.api.invitations import InvitationAccept, accept_invitation
inv_id = uuid.uuid4()
token = generate_invite_token(str(inv_id))
mock_inv = _make_mock_invitation(status="pending")
mock_inv.id = inv_id
mock_inv.expires_at = datetime.now(tz=timezone.utc) - timedelta(hours=1) # Expired
session = _make_mock_session()
session.execute = AsyncMock(
return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv))
)
body = InvitationAccept(token=token, password="password123")
with pytest.raises(HTTPException) as exc_info:
await accept_invitation(body=body, session=session)
assert exc_info.value.status_code == 400
assert "expired" in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_invite_resend_updates_token() -> None:
"""
POST /api/portal/invitations/{id}/resend generates new token_hash and extends expires_at.
"""
from shared.api.invitations import resend_invitation
from shared.api.rbac import PortalCaller
tenant_id = uuid.uuid4()
caller = PortalCaller(user_id=uuid.uuid4(), role="platform_admin", tenant_id=tenant_id)
mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending")
old_token_hash = mock_inv.token_hash
old_expires = mock_inv.expires_at
mock_tenant = _make_mock_tenant()
session = _make_mock_session()
call_count = 0
def execute_side_effect(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# Load invitation
result.scalar_one_or_none = MagicMock(return_value=mock_inv)
elif call_count == 2:
# Load tenant
result.scalar_one_or_none = MagicMock(return_value=mock_tenant)
else:
result.scalar_one_or_none = MagicMock(return_value=None)
return result
session.execute = AsyncMock(side_effect=execute_side_effect)
async def mock_refresh(obj):
pass
session.refresh = AsyncMock(side_effect=mock_refresh)
with patch("shared.api.invitations._dispatch_invite_email"):
result = await resend_invitation(
invitation_id=mock_inv.id,
caller=caller,
session=session,
)
# token_hash should have been updated
assert mock_inv.token_hash != old_token_hash
# expires_at should have been extended
assert mock_inv.expires_at > old_expires
# Raw token returned
assert result.token is not None