test(04-rbac-01): unit tests for RBAC guards, invitation system, portal auth
- test_rbac_guards.py: 11 tests covering platform_admin pass-through, customer_admin/operator 403 rejection, tenant membership checks, and platform_admin bypass for tenant-scoped guards - test_invitations.py: 11 tests covering HMAC token roundtrip, tamper/expiry rejection, invitation create/accept/resend/list - test_portal_auth.py: 7 tests covering role field (not is_admin), tenant_ids list, active_tenant_id, platform_admin all-tenants, customer_admin own-tenants-only - All 27 tests pass
This commit is contained in:
368
tests/unit/test_invitations.py
Normal file
368
tests/unit/test_invitations.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for invitation HMAC token utilities and invitation API endpoints.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- test_token_roundtrip: generate then validate returns same invitation_id
|
||||||
|
- test_token_tamper_rejected: modified signature raises ValueError
|
||||||
|
- test_token_expired_rejected: artificially old token raises ValueError
|
||||||
|
- test_invite_create: POST /api/portal/invitations creates invitation, returns 201 + token
|
||||||
|
- test_invite_accept_creates_user: accepting invite creates PortalUser + UserTenantRole
|
||||||
|
- test_invite_accept_rejects_expired: expired invitation (status != pending) raises error
|
||||||
|
- test_invite_resend_updates_token: resend generates new token_hash and extends expires_at
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.invite_token import generate_invite_token, token_to_hash, validate_invite_token
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HMAC token tests (unit, no DB)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_roundtrip() -> None:
|
||||||
|
"""generate_invite_token then validate_invite_token returns the same invitation_id."""
|
||||||
|
inv_id = str(uuid.uuid4())
|
||||||
|
token = generate_invite_token(inv_id)
|
||||||
|
result = validate_invite_token(token)
|
||||||
|
assert result == inv_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_tamper_rejected() -> None:
|
||||||
|
"""A token with a modified signature raises ValueError."""
|
||||||
|
inv_id = str(uuid.uuid4())
|
||||||
|
token = generate_invite_token(inv_id)
|
||||||
|
|
||||||
|
# Truncate and append garbage to corrupt the signature portion
|
||||||
|
corrupted = token[:-8] + "XXXXXXXX"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validate_invite_token(corrupted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_expired_rejected() -> None:
|
||||||
|
"""A token older than 48h raises ValueError on validation."""
|
||||||
|
inv_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Patch time.time to simulate a token created 49 hours ago
|
||||||
|
past_time = 1000000 # some fixed old epoch value
|
||||||
|
current_time = past_time + (49 * 3600) # 49h later
|
||||||
|
|
||||||
|
with patch("shared.invite_token.time.time", return_value=past_time):
|
||||||
|
token = generate_invite_token(inv_id)
|
||||||
|
|
||||||
|
with patch("shared.invite_token.time.time", return_value=current_time):
|
||||||
|
with pytest.raises(ValueError, match="expired"):
|
||||||
|
validate_invite_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_hash_is_hex_sha256() -> None:
|
||||||
|
"""token_to_hash returns a 64-char hex string (SHA-256)."""
|
||||||
|
token = generate_invite_token(str(uuid.uuid4()))
|
||||||
|
h = token_to_hash(token)
|
||||||
|
assert len(h) == 64
|
||||||
|
# Valid hex
|
||||||
|
int(h, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_tokens_different() -> None:
|
||||||
|
"""Two tokens for the same invitation_id at different times have different hashes."""
|
||||||
|
inv_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with patch("shared.invite_token.time.time", return_value=1000000):
|
||||||
|
token1 = generate_invite_token(inv_id)
|
||||||
|
|
||||||
|
with patch("shared.invite_token.time.time", return_value=1000001):
|
||||||
|
token2 = generate_invite_token(inv_id)
|
||||||
|
|
||||||
|
assert token1 != token2
|
||||||
|
assert token_to_hash(token1) != token_to_hash(token2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Invitation API tests (mock DB session)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_session() -> AsyncMock:
|
||||||
|
"""Create a mock AsyncSession."""
|
||||||
|
session = AsyncMock()
|
||||||
|
session.execute.return_value = MagicMock(
|
||||||
|
scalar_one_or_none=MagicMock(return_value=None),
|
||||||
|
scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_tenant(name: str = "Acme Corp") -> MagicMock:
|
||||||
|
t = MagicMock()
|
||||||
|
t.id = uuid.uuid4()
|
||||||
|
t.name = name
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_invitation(
|
||||||
|
tenant_id: uuid.UUID | None = None,
|
||||||
|
status: str = "pending",
|
||||||
|
expires_in_hours: float = 24,
|
||||||
|
) -> MagicMock:
|
||||||
|
inv = MagicMock()
|
||||||
|
inv.id = uuid.uuid4()
|
||||||
|
inv.email = "new@example.com"
|
||||||
|
inv.name = "New User"
|
||||||
|
inv.tenant_id = tenant_id or uuid.uuid4()
|
||||||
|
inv.role = "customer_admin"
|
||||||
|
inv.invited_by = uuid.uuid4()
|
||||||
|
inv.token_hash = "abc123"
|
||||||
|
inv.status = status
|
||||||
|
inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=expires_in_hours)
|
||||||
|
inv.created_at = datetime.now(tz=timezone.utc)
|
||||||
|
return inv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invite_create_returns_201_with_token() -> None:
|
||||||
|
"""
|
||||||
|
POST /api/portal/invitations creates an invitation and returns 201 with raw token.
|
||||||
|
"""
|
||||||
|
from shared.api.invitations import InvitationCreate, create_invitation
|
||||||
|
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller_id = uuid.uuid4()
|
||||||
|
|
||||||
|
from shared.api.rbac import PortalCaller
|
||||||
|
caller = PortalCaller(user_id=caller_id, role="customer_admin", tenant_id=tenant_id)
|
||||||
|
|
||||||
|
session = _make_mock_session()
|
||||||
|
mock_tenant = _make_mock_tenant()
|
||||||
|
|
||||||
|
# Mock: require_tenant_admin (membership check) returns membership
|
||||||
|
mock_membership = MagicMock()
|
||||||
|
|
||||||
|
# Setup execute to return tenant on first call, membership on second call
|
||||||
|
call_count = 0
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
# require_tenant_admin — membership check
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=mock_membership)
|
||||||
|
elif call_count == 2:
|
||||||
|
# _get_tenant_or_404
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=mock_tenant)
|
||||||
|
else:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
session.flush = AsyncMock()
|
||||||
|
|
||||||
|
# Mock refresh to set created_at on the invitation object
|
||||||
|
async def mock_refresh(obj):
|
||||||
|
if not hasattr(obj, 'created_at') or obj.created_at is None:
|
||||||
|
obj.created_at = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||||
|
|
||||||
|
body = InvitationCreate(
|
||||||
|
email="new@example.com",
|
||||||
|
name="New User",
|
||||||
|
role="customer_admin",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("shared.api.invitations._dispatch_invite_email") as mock_dispatch:
|
||||||
|
result = await create_invitation(body=body, caller=caller, session=session)
|
||||||
|
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
assert result.name == "New User"
|
||||||
|
assert result.token is not None
|
||||||
|
assert len(result.token) > 0
|
||||||
|
mock_dispatch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invite_accept_creates_user() -> None:
|
||||||
|
"""
|
||||||
|
POST /api/portal/invitations/accept with valid token creates PortalUser + UserTenantRole.
|
||||||
|
"""
|
||||||
|
from shared.api.invitations import InvitationAccept, accept_invitation
|
||||||
|
|
||||||
|
inv_id = uuid.uuid4()
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Create a valid token
|
||||||
|
token = generate_invite_token(str(inv_id))
|
||||||
|
|
||||||
|
mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending")
|
||||||
|
mock_inv.id = inv_id
|
||||||
|
mock_inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24)
|
||||||
|
|
||||||
|
session = _make_mock_session()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
added_objects = []
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
# Load invitation by ID
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=mock_inv)
|
||||||
|
elif call_count == 2:
|
||||||
|
# Check existing user by email
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
else:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
def capture_add(obj):
|
||||||
|
added_objects.append(obj)
|
||||||
|
|
||||||
|
session.add = MagicMock(side_effect=capture_add)
|
||||||
|
|
||||||
|
async def mock_refresh(obj):
|
||||||
|
# Ensure the user has an id and role
|
||||||
|
pass
|
||||||
|
|
||||||
|
session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||||
|
|
||||||
|
body = InvitationAccept(token=token, password="securepassword123")
|
||||||
|
result = await accept_invitation(body=body, session=session)
|
||||||
|
|
||||||
|
assert result.email == mock_inv.email
|
||||||
|
assert result.name == mock_inv.name
|
||||||
|
assert result.role == mock_inv.role
|
||||||
|
|
||||||
|
# Verify user and membership were added
|
||||||
|
from shared.models.auth import PortalUser, UserTenantRole
|
||||||
|
portal_users = [o for o in added_objects if isinstance(o, PortalUser)]
|
||||||
|
memberships = [o for o in added_objects if isinstance(o, UserTenantRole)]
|
||||||
|
assert len(portal_users) == 1, f"Expected 1 PortalUser, got {len(portal_users)}"
|
||||||
|
assert len(memberships) == 1, f"Expected 1 UserTenantRole, got {len(memberships)}"
|
||||||
|
|
||||||
|
# Verify invitation was marked accepted
|
||||||
|
assert mock_inv.status == "accepted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invite_accept_rejects_non_pending() -> None:
|
||||||
|
"""
|
||||||
|
POST /api/portal/invitations/accept with already-accepted invitation returns 409.
|
||||||
|
"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from shared.api.invitations import InvitationAccept, accept_invitation
|
||||||
|
|
||||||
|
inv_id = uuid.uuid4()
|
||||||
|
token = generate_invite_token(str(inv_id))
|
||||||
|
|
||||||
|
mock_inv = _make_mock_invitation(status="accepted")
|
||||||
|
mock_inv.id = inv_id
|
||||||
|
|
||||||
|
session = _make_mock_session()
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv))
|
||||||
|
)
|
||||||
|
|
||||||
|
body = InvitationAccept(token=token, password="password123")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await accept_invitation(body=body, session=session)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invite_accept_rejects_expired_invitation() -> None:
|
||||||
|
"""
|
||||||
|
POST /api/portal/invitations/accept with past expires_at returns 400.
|
||||||
|
"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from shared.api.invitations import InvitationAccept, accept_invitation
|
||||||
|
|
||||||
|
inv_id = uuid.uuid4()
|
||||||
|
token = generate_invite_token(str(inv_id))
|
||||||
|
|
||||||
|
mock_inv = _make_mock_invitation(status="pending")
|
||||||
|
mock_inv.id = inv_id
|
||||||
|
mock_inv.expires_at = datetime.now(tz=timezone.utc) - timedelta(hours=1) # Expired
|
||||||
|
|
||||||
|
session = _make_mock_session()
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv))
|
||||||
|
)
|
||||||
|
|
||||||
|
body = InvitationAccept(token=token, password="password123")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await accept_invitation(body=body, session=session)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert "expired" in exc_info.value.detail.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invite_resend_updates_token() -> None:
|
||||||
|
"""
|
||||||
|
POST /api/portal/invitations/{id}/resend generates new token_hash and extends expires_at.
|
||||||
|
"""
|
||||||
|
from shared.api.invitations import resend_invitation
|
||||||
|
from shared.api.rbac import PortalCaller
|
||||||
|
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = PortalCaller(user_id=uuid.uuid4(), role="platform_admin", tenant_id=tenant_id)
|
||||||
|
mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending")
|
||||||
|
old_token_hash = mock_inv.token_hash
|
||||||
|
old_expires = mock_inv.expires_at
|
||||||
|
|
||||||
|
mock_tenant = _make_mock_tenant()
|
||||||
|
session = _make_mock_session()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
# Load invitation
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=mock_inv)
|
||||||
|
elif call_count == 2:
|
||||||
|
# Load tenant
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=mock_tenant)
|
||||||
|
else:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
async def mock_refresh(obj):
|
||||||
|
pass
|
||||||
|
|
||||||
|
session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||||
|
|
||||||
|
with patch("shared.api.invitations._dispatch_invite_email"):
|
||||||
|
result = await resend_invitation(
|
||||||
|
invitation_id=mock_inv.id,
|
||||||
|
caller=caller,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# token_hash should have been updated
|
||||||
|
assert mock_inv.token_hash != old_token_hash
|
||||||
|
# expires_at should have been extended
|
||||||
|
assert mock_inv.expires_at > old_expires
|
||||||
|
# Raw token returned
|
||||||
|
assert result.token is not None
|
||||||
279
tests/unit/test_portal_auth.py
Normal file
279
tests/unit/test_portal_auth.py
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the updated auth/verify endpoint.
|
||||||
|
|
||||||
|
Tests verify that the response shape matches the new RBAC contract:
|
||||||
|
- Returns `role` (not `is_admin`)
|
||||||
|
- Returns `tenant_ids` as a list of UUID strings
|
||||||
|
- Returns `active_tenant_id` as the first tenant ID (or None)
|
||||||
|
- platform_admin returns all tenant IDs from the tenants table
|
||||||
|
- customer_admin returns only their UserTenantRole tenant IDs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.api.portal import AuthVerifyRequest, AuthVerifyResponse, verify_credentials
|
||||||
|
from shared.models.auth import PortalUser, UserTenantRole
|
||||||
|
from shared.models.tenant import Tenant
|
||||||
|
|
||||||
|
|
||||||
|
def _make_user(role: str, email: str = "test@example.com") -> PortalUser:
|
||||||
|
user = MagicMock(spec=PortalUser)
|
||||||
|
user.id = uuid.uuid4()
|
||||||
|
user.email = email
|
||||||
|
user.name = "Test User"
|
||||||
|
user.role = role
|
||||||
|
# Real bcrypt hash for password "testpassword"
|
||||||
|
user.hashed_password = bcrypt.hashpw(b"testpassword", bcrypt.gensalt()).decode()
|
||||||
|
user.created_at = datetime.now(tz=timezone.utc)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tenant_role(user_id: uuid.UUID, tenant_id: uuid.UUID, role: str) -> UserTenantRole:
|
||||||
|
m = MagicMock(spec=UserTenantRole)
|
||||||
|
m.user_id = user_id
|
||||||
|
m.tenant_id = tenant_id
|
||||||
|
m.role = role
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tenant(name: str = "Acme") -> Tenant:
|
||||||
|
t = MagicMock(spec=Tenant)
|
||||||
|
t.id = uuid.uuid4()
|
||||||
|
t.name = name
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# test_auth_verify_returns_role
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_returns_role() -> None:
|
||||||
|
"""
|
||||||
|
auth/verify response contains 'role' field (not 'is_admin').
|
||||||
|
"""
|
||||||
|
user = _make_user("customer_admin")
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
# User lookup
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
# UserTenantRole lookup — empty
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=[])))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert isinstance(response, AuthVerifyResponse)
|
||||||
|
assert response.role == "customer_admin"
|
||||||
|
# Ensure is_admin is NOT in the response model fields
|
||||||
|
assert not hasattr(response, "is_admin") or not isinstance(getattr(response, "is_admin", None), bool)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# test_auth_verify_returns_tenant_ids
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_returns_tenant_ids() -> None:
|
||||||
|
"""
|
||||||
|
auth/verify response contains tenant_ids as a list of UUID strings.
|
||||||
|
"""
|
||||||
|
user = _make_user("customer_admin")
|
||||||
|
tenant_id_1 = uuid.uuid4()
|
||||||
|
tenant_id_2 = uuid.uuid4()
|
||||||
|
memberships = [
|
||||||
|
_make_tenant_role(user.id, tenant_id_1, "customer_admin"),
|
||||||
|
_make_tenant_role(user.id, tenant_id_2, "customer_admin"),
|
||||||
|
]
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert isinstance(response.tenant_ids, list)
|
||||||
|
assert len(response.tenant_ids) == 2
|
||||||
|
assert str(tenant_id_1) in response.tenant_ids
|
||||||
|
assert str(tenant_id_2) in response.tenant_ids
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# test_auth_verify_returns_active_tenant
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_returns_active_tenant_first() -> None:
|
||||||
|
"""
|
||||||
|
auth/verify response contains active_tenant_id as the first tenant ID.
|
||||||
|
"""
|
||||||
|
user = _make_user("customer_admin")
|
||||||
|
tenant_id_1 = uuid.uuid4()
|
||||||
|
memberships = [
|
||||||
|
_make_tenant_role(user.id, tenant_id_1, "customer_admin"),
|
||||||
|
]
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert response.active_tenant_id == str(tenant_id_1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_active_tenant_none_for_no_memberships() -> None:
|
||||||
|
"""
|
||||||
|
auth/verify response contains active_tenant_id=None for users with no tenant memberships.
|
||||||
|
"""
|
||||||
|
user = _make_user("customer_admin")
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=[])))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert response.active_tenant_id is None
|
||||||
|
assert response.tenant_ids == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# test_auth_verify_platform_admin_returns_all_tenants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_platform_admin_returns_all_tenants() -> None:
|
||||||
|
"""
|
||||||
|
platform_admin auth/verify returns all tenant IDs from the tenants table.
|
||||||
|
"""
|
||||||
|
user = _make_user("platform_admin")
|
||||||
|
tenant_1 = _make_tenant("Acme")
|
||||||
|
tenant_2 = _make_tenant("Globex")
|
||||||
|
all_tenants = [tenant_1, tenant_2]
|
||||||
|
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
# User lookup
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
# All tenants query for platform_admin
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=all_tenants)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert response.role == "platform_admin"
|
||||||
|
assert len(response.tenant_ids) == 2
|
||||||
|
assert str(tenant_1.id) in response.tenant_ids
|
||||||
|
assert str(tenant_2.id) in response.tenant_ids
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# test_auth_verify_customer_admin_only_own_tenants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_verify_customer_admin_only_own_tenants() -> None:
|
||||||
|
"""
|
||||||
|
customer_admin auth/verify returns only their UserTenantRole tenant IDs.
|
||||||
|
Not the full tenant list.
|
||||||
|
"""
|
||||||
|
user = _make_user("customer_admin")
|
||||||
|
own_tenant_id = uuid.uuid4()
|
||||||
|
other_tenant_id = uuid.uuid4() # Should NOT appear in response
|
||||||
|
|
||||||
|
memberships = [_make_tenant_role(user.id, own_tenant_id, "customer_admin")]
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def execute_side_effect(stmt):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
if call_count == 1:
|
||||||
|
result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
else:
|
||||||
|
result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||||
|
|
||||||
|
body = AuthVerifyRequest(email=user.email, password="testpassword")
|
||||||
|
response = await verify_credentials(body=body, session=session)
|
||||||
|
|
||||||
|
assert response.role == "customer_admin"
|
||||||
|
assert len(response.tenant_ids) == 1
|
||||||
|
assert str(own_tenant_id) in response.tenant_ids
|
||||||
|
assert str(other_tenant_id) not in response.tenant_ids
|
||||||
188
tests/unit/test_rbac_guards.py
Normal file
188
tests/unit/test_rbac_guards.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for RBAC guard FastAPI dependencies.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- test_platform_admin_passes: platform_admin caller passes require_platform_admin
|
||||||
|
- test_customer_admin_rejected: customer_admin gets 403 from require_platform_admin
|
||||||
|
- test_customer_operator_rejected: customer_operator gets 403 from require_platform_admin
|
||||||
|
- test_tenant_admin_own_tenant: customer_admin with membership passes require_tenant_admin
|
||||||
|
- test_tenant_admin_no_membership: customer_admin without membership gets 403
|
||||||
|
- test_platform_admin_bypasses_tenant_check: platform_admin passes require_tenant_admin
|
||||||
|
without a UserTenantRole row (no DB query for membership)
|
||||||
|
- test_operator_rejected_from_admin: customer_operator gets 403 from require_tenant_admin
|
||||||
|
- test_tenant_member_all_roles: customer_admin and customer_operator with membership pass
|
||||||
|
- test_tenant_member_no_membership: user with no membership gets 403
|
||||||
|
- test_platform_admin_bypasses_tenant_member: platform_admin passes require_tenant_member
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from shared.api.rbac import (
|
||||||
|
PortalCaller,
|
||||||
|
require_platform_admin,
|
||||||
|
require_tenant_admin,
|
||||||
|
require_tenant_member,
|
||||||
|
)
|
||||||
|
from shared.models.auth import UserTenantRole
|
||||||
|
|
||||||
|
|
||||||
|
def _make_caller(role: str, tenant_id: uuid.UUID | None = None) -> PortalCaller:
|
||||||
|
return PortalCaller(user_id=uuid.uuid4(), role=role, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_membership(user_id: uuid.UUID, tenant_id: uuid.UUID, role: str) -> UserTenantRole:
|
||||||
|
m = MagicMock(spec=UserTenantRole)
|
||||||
|
m.user_id = user_id
|
||||||
|
m.tenant_id = tenant_id
|
||||||
|
m.role = role
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_session_with_membership(membership: UserTenantRole | None) -> AsyncMock:
|
||||||
|
session = AsyncMock()
|
||||||
|
session.execute.return_value = MagicMock(
|
||||||
|
scalar_one_or_none=MagicMock(return_value=membership)
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# require_platform_admin tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_platform_admin_passes() -> None:
|
||||||
|
"""platform_admin caller should pass require_platform_admin and return the caller."""
|
||||||
|
caller = _make_caller("platform_admin")
|
||||||
|
result = require_platform_admin(caller=caller)
|
||||||
|
assert result is caller
|
||||||
|
|
||||||
|
|
||||||
|
def test_customer_admin_rejected() -> None:
|
||||||
|
"""customer_admin should get 403 from require_platform_admin."""
|
||||||
|
caller = _make_caller("customer_admin")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
require_platform_admin(caller=caller)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_customer_operator_rejected() -> None:
|
||||||
|
"""customer_operator should get 403 from require_platform_admin."""
|
||||||
|
caller = _make_caller("customer_operator")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
require_platform_admin(caller=caller)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# require_tenant_admin tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tenant_admin_own_tenant() -> None:
|
||||||
|
"""customer_admin with UserTenantRole membership passes require_tenant_admin."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_admin")
|
||||||
|
membership = _make_membership(caller.user_id, tenant_id, "customer_admin")
|
||||||
|
session = _mock_session_with_membership(membership)
|
||||||
|
|
||||||
|
result = await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert result is caller
|
||||||
|
session.execute.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tenant_admin_no_membership() -> None:
|
||||||
|
"""customer_admin without UserTenantRole row gets 403 from require_tenant_admin."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_admin")
|
||||||
|
session = _mock_session_with_membership(None)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_platform_admin_bypasses_tenant_check() -> None:
|
||||||
|
"""platform_admin passes require_tenant_admin without any DB membership query."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("platform_admin")
|
||||||
|
session = AsyncMock() # Should NOT be called
|
||||||
|
|
||||||
|
result = await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert result is caller
|
||||||
|
session.execute.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_operator_rejected_from_admin() -> None:
|
||||||
|
"""customer_operator always gets 403 from require_tenant_admin (cannot be admin)."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_operator")
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
session.execute.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# require_tenant_member tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tenant_member_customer_admin() -> None:
|
||||||
|
"""customer_admin with membership passes require_tenant_member."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_admin")
|
||||||
|
membership = _make_membership(caller.user_id, tenant_id, "customer_admin")
|
||||||
|
session = _mock_session_with_membership(membership)
|
||||||
|
|
||||||
|
result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert result is caller
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tenant_member_customer_operator() -> None:
|
||||||
|
"""customer_operator with membership passes require_tenant_member."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_operator")
|
||||||
|
membership = _make_membership(caller.user_id, tenant_id, "customer_operator")
|
||||||
|
session = _mock_session_with_membership(membership)
|
||||||
|
|
||||||
|
result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert result is caller
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tenant_member_no_membership() -> None:
|
||||||
|
"""User with no UserTenantRole row gets 403 from require_tenant_member."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("customer_admin")
|
||||||
|
session = _mock_session_with_membership(None)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_platform_admin_bypasses_tenant_member() -> None:
|
||||||
|
"""platform_admin passes require_tenant_member without DB membership check."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
caller = _make_caller("platform_admin")
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
|
||||||
|
assert result is caller
|
||||||
|
session.execute.assert_not_called()
|
||||||
Reference in New Issue
Block a user