""" Unit tests for the chat REST API with RBAC enforcement. Tests: - test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403 when caller is not a member of tenant X - test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200 when caller is platform_admin (bypasses membership check) - test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at - test_create_conversation: POST /api/portal/chat/conversations creates or returns existing conversation for user+agent pair - test_create_conversation_rbac: POST returns 403 for non-member caller - test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes messages but keeps the conversation row """ from __future__ import annotations import uuid from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from shared.api.rbac import PortalCaller # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _admin_headers(user_id: str | None = None) -> dict[str, str]: return { "X-Portal-User-Id": user_id or str(uuid.uuid4()), "X-Portal-User-Role": "platform_admin", } def _stranger_headers(user_id: str | None = None) -> dict[str, str]: return { "X-Portal-User-Id": user_id or str(uuid.uuid4()), "X-Portal-User-Role": "customer_operator", } def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI: """Create a test FastAPI app with the chat router and a session dependency override.""" from shared.api.chat import chat_router from shared.db import get_session app = FastAPI() app.include_router(chat_router) async def _override_get_session(): # type: ignore[return] yield mock_session app.dependency_overrides[get_session] = 
_override_get_session return app # --------------------------------------------------------------------------- # RBAC enforcement on list conversations # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_chat_rbac_enforcement() -> None: """Non-member caller gets 403 when listing conversations for a tenant they don't belong to.""" tenant_id = uuid.uuid4() user_id = uuid.uuid4() # Mock session — no membership row found (require_tenant_member checks UserTenantRole) mock_session = AsyncMock() mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None)) app = _make_app_with_session_override(mock_session) async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: response = await client.get( "/api/portal/chat/conversations", params={"tenant_id": str(tenant_id)}, headers=_stranger_headers(str(user_id)), ) assert response.status_code == 403 @pytest.mark.asyncio async def test_platform_admin_cross_tenant() -> None: """Platform admin can list conversations for any tenant (bypasses membership check).""" tenant_id = uuid.uuid4() user_id = uuid.uuid4() # Mock session — returns empty rows for conversation query mock_session = AsyncMock() mock_result = MagicMock() mock_result.all.return_value = [] mock_session.execute.return_value = mock_result app = _make_app_with_session_override(mock_session) with ( patch("shared.api.chat.configure_rls_hook"), patch("shared.api.chat.current_tenant_id"), ): async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: response = await client.get( "/api/portal/chat/conversations", params={"tenant_id": str(tenant_id)}, headers=_admin_headers(str(user_id)), ) assert response.status_code == 200 assert isinstance(response.json(), list) # --------------------------------------------------------------------------- # List conversation history (paginated messages) # 
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_list_conversation_history() -> None:
    """GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at.

    The mocked session hands back two messages (user first, assistant second) with
    strictly increasing ``created_at`` values, and the test checks that the response
    preserves that order.
    """
    # Function-scope import: only this test needs timedelta.
    from datetime import timedelta

    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()

    # Mock conversation owned by the caller
    mock_conv = MagicMock()
    mock_conv.id = conv_id
    mock_conv.user_id = user_id
    mock_conv.tenant_id = uuid.uuid4()

    # Mock messages. The timestamps differ so that "ordered by created_at" is
    # well-defined for the fixture (identical timestamps would leave any order valid).
    first_sent_at = datetime.now(timezone.utc)

    mock_msg1 = MagicMock()
    mock_msg1.id = uuid.uuid4()
    mock_msg1.role = "user"
    mock_msg1.content = "Hello"
    mock_msg1.created_at = first_sent_at

    mock_msg2 = MagicMock()
    mock_msg2.id = uuid.uuid4()
    mock_msg2.role = "assistant"
    mock_msg2.content = "Hi there!"
    mock_msg2.created_at = first_sent_at + timedelta(seconds=1)

    mock_session = AsyncMock()
    # First call: fetch conversation; second call: fetch messages
    mock_session.execute.side_effect = [
        MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
        MagicMock(
            scalars=MagicMock(
                return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2]))
            )
        ),
    ]

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.get(
                f"/api/portal/chat/conversations/{conv_id}/messages",
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code == 200
    data = response.json()
    assert isinstance(data, list)
    assert len(data) == 2
    assert data[0]["role"] == "user"
    assert data[1]["role"] == "assistant"


# ---------------------------------------------------------------------------
# Create conversation (get-or-create)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_create_conversation() -> None:
    """POST /api/portal/chat/conversations creates a new conversation for user+agent pair.

    Only the "create" path is exercised: every ``execute`` call on the mocked session
    reports that no row was found.
    """
    tenant_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()
    now = datetime.now(timezone.utc)

    # Platform admin bypasses membership check; no existing conversation found
    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
    mock_session.flush = AsyncMock()
    mock_session.commit = AsyncMock()
    # add() is replaced with a plain MagicMock so calling it does not produce a coroutine
    mock_session.add = MagicMock()

    # refresh populates server-default fields on the passed ORM object
    async def _mock_refresh(obj: object) -> None:
        obj.id = conv_id  # type: ignore[attr-defined]
        obj.created_at = now  # type: ignore[attr-defined]
        obj.updated_at = now  # type: ignore[attr-defined]

    mock_session.refresh = _mock_refresh

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/api/portal/chat/conversations",
                json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code in (200, 201)
    data = response.json()
    assert "id" in data


@pytest.mark.asyncio
async def test_create_conversation_rbac_forbidden() -> None:
    """Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
    tenant_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    user_id = uuid.uuid4()

    # Membership check returns None (not a member)
    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))

    app = _make_app_with_session_override(mock_session)

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        response = await client.post(
            "/api/portal/chat/conversations",
            json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
            headers=_stranger_headers(str(user_id)),
        )

    assert response.status_code == 403
# ---------------------------------------------------------------------------
# Delete conversation (reset messages)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_delete_conversation_resets_messages() -> None:
    """DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row.

    The mocked session returns a conversation owned by the caller for every
    ``execute`` call. Besides the 200 status, the test checks that the session was
    queried and that the ORM-level ``session.delete`` was never invoked.
    """
    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()

    mock_conv = MagicMock()
    mock_conv.id = conv_id
    mock_conv.user_id = user_id
    mock_conv.tenant_id = uuid.uuid4()

    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(
        scalar_one_or_none=MagicMock(return_value=mock_conv)
    )
    mock_session.commit = AsyncMock()

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.delete(
                f"/api/portal/chat/conversations/{conv_id}",
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code == 200
    assert mock_session.execute.call_count >= 1
    # "Keeps the conversation row": no object may be removed through session.delete().
    # NOTE(review): a bulk DELETE statement issued via execute() against the
    # conversation table would not be caught by this mock-level check.
    mock_session.delete.assert_not_called()