diff --git a/tests/unit/test_invitations.py b/tests/unit/test_invitations.py new file mode 100644 index 0000000..d60d63f --- /dev/null +++ b/tests/unit/test_invitations.py @@ -0,0 +1,368 @@ +""" +Unit tests for invitation HMAC token utilities and invitation API endpoints. + +Tests: + - test_token_roundtrip: generate then validate returns same invitation_id + - test_token_tamper_rejected: modified signature raises ValueError + - test_token_expired_rejected: artificially old token raises ValueError + - test_invite_create: POST /api/portal/invitations creates invitation, returns 201 + token + - test_invite_accept_creates_user: accepting invite creates PortalUser + UserTenantRole + - test_invite_accept_rejects_expired: expired invitation (status != pending) raises error + - test_invite_resend_updates_token: resend generates new token_hash and extends expires_at +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from shared.invite_token import generate_invite_token, token_to_hash, validate_invite_token + + +# --------------------------------------------------------------------------- +# HMAC token tests (unit, no DB) +# --------------------------------------------------------------------------- + + +def test_token_roundtrip() -> None: + """generate_invite_token then validate_invite_token returns the same invitation_id.""" + inv_id = str(uuid.uuid4()) + token = generate_invite_token(inv_id) + result = validate_invite_token(token) + assert result == inv_id + + +def test_token_tamper_rejected() -> None: + """A token with a modified signature raises ValueError.""" + inv_id = str(uuid.uuid4()) + token = generate_invite_token(inv_id) + + # Truncate and append garbage to corrupt the signature portion + corrupted = token[:-8] + "XXXXXXXX" + with pytest.raises(ValueError): + validate_invite_token(corrupted) + + +def test_token_expired_rejected() -> None: + """A token older than 48h raises ValueError on validation.""" + inv_id = str(uuid.uuid4()) + + # Patch time.time to simulate a token created 49 hours ago + past_time = 1000000 # some fixed old epoch value + current_time = past_time + (49 * 3600) # 49h later + + with patch("shared.invite_token.time.time", return_value=past_time): + token = generate_invite_token(inv_id) + + with patch("shared.invite_token.time.time", return_value=current_time): + with pytest.raises(ValueError, match="expired"): + validate_invite_token(token) + + +def test_token_hash_is_hex_sha256() -> None: + """token_to_hash returns a 64-char hex string (SHA-256).""" + token = generate_invite_token(str(uuid.uuid4())) + h = token_to_hash(token) + assert len(h) == 64 + # Valid hex + int(h, 16) + + +def test_two_tokens_different() -> None: + """Two tokens for the same invitation_id at different times have different hashes.""" + inv_id = str(uuid.uuid4()) + + with patch("shared.invite_token.time.time", return_value=1000000): + token1 = generate_invite_token(inv_id) + + with patch("shared.invite_token.time.time", return_value=1000001): + token2 = generate_invite_token(inv_id) + + assert token1 != token2 + assert token_to_hash(token1) != token_to_hash(token2) + + +# --------------------------------------------------------------------------- +# Invitation API tests (mock DB session) +# --------------------------------------------------------------------------- + + +def _make_mock_session() -> AsyncMock: + """Create a mock AsyncSession.""" + session = AsyncMock() + session.execute.return_value = MagicMock( + scalar_one_or_none=MagicMock(return_value=None), + scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))), + ) + return session + + +def _make_mock_tenant(name: str = "Acme Corp") -> MagicMock: + t = MagicMock() + t.id = uuid.uuid4() + t.name = name + return t + + +def _make_mock_invitation( + tenant_id: uuid.UUID | None = None, + status: str = "pending", + expires_in_hours: float = 24, +) -> MagicMock: + inv = MagicMock() + inv.id = uuid.uuid4() + inv.email = "new@example.com" + inv.name = "New User" + inv.tenant_id = tenant_id or uuid.uuid4() + inv.role = "customer_admin" + inv.invited_by = uuid.uuid4() + inv.token_hash = "abc123" + inv.status = status + inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=expires_in_hours) + inv.created_at = datetime.now(tz=timezone.utc) + return inv + + +@pytest.mark.asyncio +async def test_invite_create_returns_201_with_token() -> None: + """ + POST /api/portal/invitations creates an invitation and returns 201 with raw token. + """ + from shared.api.invitations import InvitationCreate, create_invitation + + tenant_id = uuid.uuid4() + caller_id = uuid.uuid4() + + from shared.api.rbac import PortalCaller + caller = PortalCaller(user_id=caller_id, role="customer_admin", tenant_id=tenant_id) + + session = _make_mock_session() + mock_tenant = _make_mock_tenant() + + # Mock: require_tenant_admin (membership check) returns membership + mock_membership = MagicMock() + + # Setup execute to return tenant on first call, membership on second call + call_count = 0 + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + # require_tenant_admin — membership check + result.scalar_one_or_none = MagicMock(return_value=mock_membership) + elif call_count == 2: + # _get_tenant_or_404 + result.scalar_one_or_none = MagicMock(return_value=mock_tenant) + else: + result.scalar_one_or_none = MagicMock(return_value=None) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + session.flush = AsyncMock() + + # Mock refresh to set created_at on the invitation object + async def mock_refresh(obj): + if not hasattr(obj, 'created_at') or obj.created_at is None: + obj.created_at = datetime.now(tz=timezone.utc) + + session.refresh = AsyncMock(side_effect=mock_refresh) + + body = InvitationCreate( + email="new@example.com", + name="New User", + role="customer_admin", + tenant_id=tenant_id, + ) + + with patch("shared.api.invitations._dispatch_invite_email") as mock_dispatch: + result = await create_invitation(body=body, caller=caller, session=session) + + assert result.email == "new@example.com" + assert result.name == "New User" + assert result.token is not None + assert len(result.token) > 0 + mock_dispatch.assert_called_once() + + +@pytest.mark.asyncio +async def test_invite_accept_creates_user() -> None: + """ + POST /api/portal/invitations/accept with valid token creates PortalUser + UserTenantRole. + """ + from shared.api.invitations import InvitationAccept, accept_invitation + + inv_id = uuid.uuid4() + tenant_id = uuid.uuid4() + + # Create a valid token + token = generate_invite_token(str(inv_id)) + + mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending") + mock_inv.id = inv_id + mock_inv.expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24) + + session = _make_mock_session() + + call_count = 0 + added_objects = [] + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + # Load invitation by ID + result.scalar_one_or_none = MagicMock(return_value=mock_inv) + elif call_count == 2: + # Check existing user by email + result.scalar_one_or_none = MagicMock(return_value=None) + else: + result.scalar_one_or_none = MagicMock(return_value=None) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + def capture_add(obj): + added_objects.append(obj) + + session.add = MagicMock(side_effect=capture_add) + + async def mock_refresh(obj): + # Ensure the user has an id and role + pass + + session.refresh = AsyncMock(side_effect=mock_refresh) + + body = InvitationAccept(token=token, password="securepassword123") + result = await accept_invitation(body=body, session=session) + + assert result.email == mock_inv.email + assert result.name == mock_inv.name + assert result.role == mock_inv.role + + # Verify user and membership were added + from shared.models.auth import PortalUser, UserTenantRole + portal_users = [o for o in added_objects if isinstance(o, PortalUser)] + memberships = [o for o in added_objects if isinstance(o, UserTenantRole)] + assert len(portal_users) == 1, f"Expected 1 PortalUser, got {len(portal_users)}" + assert len(memberships) == 1, f"Expected 1 UserTenantRole, got {len(memberships)}" + + # Verify invitation was marked accepted + assert mock_inv.status == "accepted" + + +@pytest.mark.asyncio +async def test_invite_accept_rejects_non_pending() -> None: + """ + POST /api/portal/invitations/accept with already-accepted invitation returns 409. + """ + from fastapi import HTTPException + + from shared.api.invitations import InvitationAccept, accept_invitation + + inv_id = uuid.uuid4() + token = generate_invite_token(str(inv_id)) + + mock_inv = _make_mock_invitation(status="accepted") + mock_inv.id = inv_id + + session = _make_mock_session() + session.execute = AsyncMock( + return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv)) + ) + + body = InvitationAccept(token=token, password="password123") + + with pytest.raises(HTTPException) as exc_info: + await accept_invitation(body=body, session=session) + + assert exc_info.value.status_code == 409 + + +@pytest.mark.asyncio +async def test_invite_accept_rejects_expired_invitation() -> None: + """ + POST /api/portal/invitations/accept with past expires_at returns 400. + """ + from fastapi import HTTPException + + from shared.api.invitations import InvitationAccept, accept_invitation + + inv_id = uuid.uuid4() + token = generate_invite_token(str(inv_id)) + + mock_inv = _make_mock_invitation(status="pending") + mock_inv.id = inv_id + mock_inv.expires_at = datetime.now(tz=timezone.utc) - timedelta(hours=1) # Expired + + session = _make_mock_session() + session.execute = AsyncMock( + return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_inv)) + ) + + body = InvitationAccept(token=token, password="password123") + + with pytest.raises(HTTPException) as exc_info: + await accept_invitation(body=body, session=session) + + assert exc_info.value.status_code == 400 + assert "expired" in exc_info.value.detail.lower() + + +@pytest.mark.asyncio +async def test_invite_resend_updates_token() -> None: + """ + POST /api/portal/invitations/{id}/resend generates new token_hash and extends expires_at. + """ + from shared.api.invitations import resend_invitation + from shared.api.rbac import PortalCaller + + tenant_id = uuid.uuid4() + caller = PortalCaller(user_id=uuid.uuid4(), role="platform_admin", tenant_id=tenant_id) + mock_inv = _make_mock_invitation(tenant_id=tenant_id, status="pending") + old_token_hash = mock_inv.token_hash + old_expires = mock_inv.expires_at + + mock_tenant = _make_mock_tenant() + session = _make_mock_session() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + # Load invitation + result.scalar_one_or_none = MagicMock(return_value=mock_inv) + elif call_count == 2: + # Load tenant + result.scalar_one_or_none = MagicMock(return_value=mock_tenant) + else: + result.scalar_one_or_none = MagicMock(return_value=None) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + async def mock_refresh(obj): + pass + + session.refresh = AsyncMock(side_effect=mock_refresh) + + with patch("shared.api.invitations._dispatch_invite_email"): + result = await resend_invitation( + invitation_id=mock_inv.id, + caller=caller, + session=session, + ) + + # token_hash should have been updated + assert mock_inv.token_hash != old_token_hash + # expires_at should have been extended + assert mock_inv.expires_at > old_expires + # Raw token returned + assert result.token is not None diff --git a/tests/unit/test_portal_auth.py b/tests/unit/test_portal_auth.py new file mode 100644 index 0000000..8b680c8 --- /dev/null +++ b/tests/unit/test_portal_auth.py @@ -0,0 +1,279 @@ +""" +Unit tests for the updated auth/verify endpoint. + +Tests verify that the response shape matches the new RBAC contract: + - Returns `role` (not `is_admin`) + - Returns `tenant_ids` as a list of UUID strings + - Returns `active_tenant_id` as the first tenant ID (or None) + - platform_admin returns all tenant IDs from the tenants table + - customer_admin returns only their UserTenantRole tenant IDs +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import bcrypt +import pytest + +from shared.api.portal import AuthVerifyRequest, AuthVerifyResponse, verify_credentials +from shared.models.auth import PortalUser, UserTenantRole +from shared.models.tenant import Tenant + + +def _make_user(role: str, email: str = "test@example.com") -> PortalUser: + user = MagicMock(spec=PortalUser) + user.id = uuid.uuid4() + user.email = email + user.name = "Test User" + user.role = role + # Real bcrypt hash for password "testpassword" + user.hashed_password = bcrypt.hashpw(b"testpassword", bcrypt.gensalt()).decode() + user.created_at = datetime.now(tz=timezone.utc) + return user + + +def _make_tenant_role(user_id: uuid.UUID, tenant_id: uuid.UUID, role: str) -> UserTenantRole: + m = MagicMock(spec=UserTenantRole) + m.user_id = user_id + m.tenant_id = tenant_id + m.role = role + return m + + +def _make_tenant(name: str = "Acme") -> Tenant: + t = MagicMock(spec=Tenant) + t.id = uuid.uuid4() + t.name = name + return t + + +# --------------------------------------------------------------------------- +# test_auth_verify_returns_role +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_verify_returns_role() -> None: + """ + auth/verify response contains 'role' field (not 'is_admin'). + """ + user = _make_user("customer_admin") + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + # User lookup + result.scalar_one_or_none = MagicMock(return_value=user) + else: + # UserTenantRole lookup — empty + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert isinstance(response, AuthVerifyResponse) + assert response.role == "customer_admin" + # Ensure is_admin is NOT in the response model fields + assert not hasattr(response, "is_admin") or not isinstance(getattr(response, "is_admin", None), bool) + + +# --------------------------------------------------------------------------- +# test_auth_verify_returns_tenant_ids +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_verify_returns_tenant_ids() -> None: + """ + auth/verify response contains tenant_ids as a list of UUID strings. + """ + user = _make_user("customer_admin") + tenant_id_1 = uuid.uuid4() + tenant_id_2 = uuid.uuid4() + memberships = [ + _make_tenant_role(user.id, tenant_id_1, "customer_admin"), + _make_tenant_role(user.id, tenant_id_2, "customer_admin"), + ] + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + result.scalar_one_or_none = MagicMock(return_value=user) + else: + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert isinstance(response.tenant_ids, list) + assert len(response.tenant_ids) == 2 + assert str(tenant_id_1) in response.tenant_ids + assert str(tenant_id_2) in response.tenant_ids + + +# --------------------------------------------------------------------------- +# test_auth_verify_returns_active_tenant +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_verify_returns_active_tenant_first() -> None: + """ + auth/verify response contains active_tenant_id as the first tenant ID. + """ + user = _make_user("customer_admin") + tenant_id_1 = uuid.uuid4() + memberships = [ + _make_tenant_role(user.id, tenant_id_1, "customer_admin"), + ] + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + result.scalar_one_or_none = MagicMock(return_value=user) + else: + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert response.active_tenant_id == str(tenant_id_1) + + +@pytest.mark.asyncio +async def test_auth_verify_active_tenant_none_for_no_memberships() -> None: + """ + auth/verify response contains active_tenant_id=None for users with no tenant memberships. + """ + user = _make_user("customer_admin") + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + result.scalar_one_or_none = MagicMock(return_value=user) + else: + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert response.active_tenant_id is None + assert response.tenant_ids == [] + + +# --------------------------------------------------------------------------- +# test_auth_verify_platform_admin_returns_all_tenants +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_verify_platform_admin_returns_all_tenants() -> None: + """ + platform_admin auth/verify returns all tenant IDs from the tenants table. + """ + user = _make_user("platform_admin") + tenant_1 = _make_tenant("Acme") + tenant_2 = _make_tenant("Globex") + all_tenants = [tenant_1, tenant_2] + + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + # User lookup + result.scalar_one_or_none = MagicMock(return_value=user) + else: + # All tenants query for platform_admin + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=all_tenants))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert response.role == "platform_admin" + assert len(response.tenant_ids) == 2 + assert str(tenant_1.id) in response.tenant_ids + assert str(tenant_2.id) in response.tenant_ids + + +# --------------------------------------------------------------------------- +# test_auth_verify_customer_admin_only_own_tenants +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_verify_customer_admin_only_own_tenants() -> None: + """ + customer_admin auth/verify returns only their UserTenantRole tenant IDs. + Not the full tenant list. + """ + user = _make_user("customer_admin") + own_tenant_id = uuid.uuid4() + other_tenant_id = uuid.uuid4() # Should NOT appear in response + + memberships = [_make_tenant_role(user.id, own_tenant_id, "customer_admin")] + session = AsyncMock() + + call_count = 0 + + def execute_side_effect(stmt): + nonlocal call_count + call_count += 1 + result = MagicMock() + if call_count == 1: + result.scalar_one_or_none = MagicMock(return_value=user) + else: + result.scalars = MagicMock(return_value=MagicMock(all=MagicMock(return_value=memberships))) + return result + + session.execute = AsyncMock(side_effect=execute_side_effect) + + body = AuthVerifyRequest(email=user.email, password="testpassword") + response = await verify_credentials(body=body, session=session) + + assert response.role == "customer_admin" + assert len(response.tenant_ids) == 1 + assert str(own_tenant_id) in response.tenant_ids + assert str(other_tenant_id) not in response.tenant_ids diff --git a/tests/unit/test_rbac_guards.py b/tests/unit/test_rbac_guards.py new file mode 100644 index 0000000..23ba153 --- /dev/null +++ b/tests/unit/test_rbac_guards.py @@ -0,0 +1,188 @@ +""" +Unit tests for RBAC guard FastAPI dependencies. + +Tests: + - test_platform_admin_passes: platform_admin caller passes require_platform_admin + - test_customer_admin_rejected: customer_admin gets 403 from require_platform_admin + - test_customer_operator_rejected: customer_operator gets 403 from require_platform_admin + - test_tenant_admin_own_tenant: customer_admin with membership passes require_tenant_admin + - test_tenant_admin_no_membership: customer_admin without membership gets 403 + - test_platform_admin_bypasses_tenant_check: platform_admin passes require_tenant_admin + without a UserTenantRole row (no DB query for membership) + - test_operator_rejected_from_admin: customer_operator gets 403 from require_tenant_admin + - test_tenant_member_all_roles: customer_admin and customer_operator with membership pass + - test_tenant_member_no_membership: user with no membership gets 403 + - test_platform_admin_bypasses_tenant_member: platform_admin passes require_tenant_member +""" + +from __future__ import annotations + +import uuid +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import HTTPException + +from shared.api.rbac import ( + PortalCaller, + require_platform_admin, + require_tenant_admin, + require_tenant_member, +) +from shared.models.auth import UserTenantRole + + +def _make_caller(role: str, tenant_id: uuid.UUID | None = None) -> PortalCaller: + return PortalCaller(user_id=uuid.uuid4(), role=role, tenant_id=tenant_id) + + +def _make_membership(user_id: uuid.UUID, tenant_id: uuid.UUID, role: str) -> UserTenantRole: + m = MagicMock(spec=UserTenantRole) + m.user_id = user_id + m.tenant_id = tenant_id + m.role = role + return m + + +def _mock_session_with_membership(membership: UserTenantRole | None) -> AsyncMock: + session = AsyncMock() + session.execute.return_value = MagicMock( + scalar_one_or_none=MagicMock(return_value=membership) + ) + return session + + +# --------------------------------------------------------------------------- +# require_platform_admin tests +# --------------------------------------------------------------------------- + + +def test_platform_admin_passes() -> None: + """platform_admin caller should pass require_platform_admin and return the caller.""" + caller = _make_caller("platform_admin") + result = require_platform_admin(caller=caller) + assert result is caller + + +def test_customer_admin_rejected() -> None: + """customer_admin should get 403 from require_platform_admin.""" + caller = _make_caller("customer_admin") + with pytest.raises(HTTPException) as exc_info: + require_platform_admin(caller=caller) + assert exc_info.value.status_code == 403 + + +def test_customer_operator_rejected() -> None: + """customer_operator should get 403 from require_platform_admin.""" + caller = _make_caller("customer_operator") + with pytest.raises(HTTPException) as exc_info: + require_platform_admin(caller=caller) + assert exc_info.value.status_code == 403 + + +# --------------------------------------------------------------------------- +# require_tenant_admin tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tenant_admin_own_tenant() -> None: + """customer_admin with UserTenantRole membership passes require_tenant_admin.""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_admin") + membership = _make_membership(caller.user_id, tenant_id, "customer_admin") + session = _mock_session_with_membership(membership) + + result = await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session) + assert result is caller + session.execute.assert_called_once() + + +@pytest.mark.asyncio +async def test_tenant_admin_no_membership() -> None: + """customer_admin without UserTenantRole row gets 403 from require_tenant_admin.""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_admin") + session = _mock_session_with_membership(None) + + with pytest.raises(HTTPException) as exc_info: + await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_platform_admin_bypasses_tenant_check() -> None: + """platform_admin passes require_tenant_admin without any DB membership query.""" + tenant_id = uuid.uuid4() + caller = _make_caller("platform_admin") + session = AsyncMock() # Should NOT be called + + result = await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session) + assert result is caller + session.execute.assert_not_called() + + +@pytest.mark.asyncio +async def test_operator_rejected_from_admin() -> None: + """customer_operator always gets 403 from require_tenant_admin (cannot be admin).""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_operator") + session = AsyncMock() + + with pytest.raises(HTTPException) as exc_info: + await require_tenant_admin(tenant_id=tenant_id, caller=caller, session=session) + assert exc_info.value.status_code == 403 + session.execute.assert_not_called() + + +# --------------------------------------------------------------------------- +# require_tenant_member tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tenant_member_customer_admin() -> None: + """customer_admin with membership passes require_tenant_member.""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_admin") + membership = _make_membership(caller.user_id, tenant_id, "customer_admin") + session = _mock_session_with_membership(membership) + + result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session) + assert result is caller + + +@pytest.mark.asyncio +async def test_tenant_member_customer_operator() -> None: + """customer_operator with membership passes require_tenant_member.""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_operator") + membership = _make_membership(caller.user_id, tenant_id, "customer_operator") + session = _mock_session_with_membership(membership) + + result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session) + assert result is caller + + +@pytest.mark.asyncio +async def test_tenant_member_no_membership() -> None: + """User with no UserTenantRole row gets 403 from require_tenant_member.""" + tenant_id = uuid.uuid4() + caller = _make_caller("customer_admin") + session = _mock_session_with_membership(None) + + with pytest.raises(HTTPException) as exc_info: + await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_platform_admin_bypasses_tenant_member() -> None: + """platform_admin passes require_tenant_member without DB membership check.""" + tenant_id = uuid.uuid4() + caller = _make_caller("platform_admin") + session = AsyncMock() + + result = await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session) + assert result is caller + session.execute.assert_not_called()