# Commit message shown alongside this file in the repository viewer
# (it describes the wider web-chat change-set, not only this test module):
#
#   - Add ChannelType.WEB = 'web' to shared/models/message.py
#   - Add webchat_response_key() to shared/redis_keys.py
#   - Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0)
#   - Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update
#   - Pop conversation_id/portal_user_id extras in handle_message before model_validate
#   - Add web case to _build_response_extras and _send_response (Redis pub-sub publish)
#   - Import webchat_response_key in orchestrator/tasks.py
#   - Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
#
# Viewer metadata: 284 lines, 9.9 KiB, Python
"""
Unit tests for the chat REST API with RBAC enforcement.

All tests run the chat router in a throwaway FastAPI app whose database
session dependency is replaced by a mock, so no database is required.

Tests:
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
  when the caller is not a member of tenant X
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
  when the caller is platform_admin (bypasses the membership check)
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
  the conversation's messages ordered by created_at
- test_create_conversation: POST /api/portal/chat/conversations creates a conversation
  for a user+agent pair (the endpoint is get-or-create; only the create path is exercised)
- test_create_conversation_rbac_forbidden: POST returns 403 for a non-member caller
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
  the messages but keeps the conversation row
"""
from __future__ import annotations

import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

from shared.api.rbac import PortalCaller


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _admin_headers(user_id: str | None = None) -> dict[str, str]:
|
|
return {
|
|
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
|
"X-Portal-User-Role": "platform_admin",
|
|
}
|
|
|
|
|
|
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
|
|
return {
|
|
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
|
"X-Portal-User-Role": "customer_operator",
|
|
}
|
|
|
|
|
|
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
    """Build a throwaway FastAPI app exposing ``chat_router`` backed by *mock_session*.

    The ``get_session`` dependency is overridden with an async generator that
    yields *mock_session*, so requests served by the returned app receive the
    mock instead of a real database session. The router and the dependency
    are imported at call time, not at module import time.
    """
    from shared.api.chat import chat_router
    from shared.db import get_session

    async def _yield_mock_session():  # type: ignore[return]
        yield mock_session

    test_app = FastAPI()
    test_app.include_router(chat_router)
    test_app.dependency_overrides[get_session] = _yield_mock_session
    return test_app
# ---------------------------------------------------------------------------
# RBAC enforcement when listing conversations
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_chat_rbac_enforcement() -> None:
    """A caller who is not a member of the tenant gets 403 when listing its conversations."""
    target_tenant = uuid.uuid4()
    caller_id = uuid.uuid4()

    # The membership lookup (require_tenant_member checks UserTenantRole) finds no row.
    no_membership_row = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
    session = AsyncMock()
    session.execute.return_value = no_membership_row

    app = _make_app_with_session_override(session)
    transport = ASGITransport(app=app)

    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get(
            "/api/portal/chat/conversations",
            params={"tenant_id": str(target_tenant)},
            headers=_stranger_headers(str(caller_id)),
        )

    assert response.status_code == 403
@pytest.mark.asyncio
async def test_platform_admin_cross_tenant() -> None:
    """Platform admin can list conversations for any tenant (bypasses membership check).

    The mocked conversation query yields no rows. The admin must therefore get
    a 200 whose body is exactly an empty JSON list, not a 403.
    """
    tenant_id = uuid.uuid4()
    user_id = uuid.uuid4()

    # Mock session — the conversation query returns no rows.
    mock_session = AsyncMock()
    mock_result = MagicMock()
    mock_result.all.return_value = []
    mock_session.execute.return_value = mock_result

    app = _make_app_with_session_override(mock_session)

    # Patch the chat module's RLS helpers so the request runs against the mock
    # session without any RLS setup.
    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.get(
                "/api/portal/chat/conversations",
                params={"tenant_id": str(tenant_id)},
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code == 200
    # The query produced no rows, so the listing must be exactly empty; a bare
    # isinstance(list) check would accept any payload shape of list.
    assert response.json() == []
# ---------------------------------------------------------------------------
# List conversation history (paginated messages)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_list_conversation_history() -> None:
    """GET /api/portal/chat/conversations/{id}/messages returns messages ordered by created_at.

    The first ``execute`` yields the caller-owned conversation. The second
    yields its messages, oldest first. The response must contain both
    messages in that order, with their roles and content intact.
    """
    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()

    # Mock conversation owned by the caller
    mock_conv = MagicMock()
    mock_conv.id = conv_id
    mock_conv.user_id = user_id
    mock_conv.tenant_id = uuid.uuid4()

    # Mock messages. The reply is one second after the user message, so the
    # fixture has a real created_at order. Identical timestamps would make the
    # "ordered by created_at" claim untestable.
    now = datetime.now(timezone.utc)
    mock_msg1 = MagicMock()
    mock_msg1.id = uuid.uuid4()
    mock_msg1.role = "user"
    mock_msg1.content = "Hello"
    mock_msg1.created_at = now

    mock_msg2 = MagicMock()
    mock_msg2.id = uuid.uuid4()
    mock_msg2.role = "assistant"
    mock_msg2.content = "Hi there!"
    mock_msg2.created_at = now + timedelta(seconds=1)

    mock_session = AsyncMock()
    # First call: fetch conversation; second call: fetch messages
    mock_session.execute.side_effect = [
        MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
        MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
    ]

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.get(
                f"/api/portal/chat/conversations/{conv_id}/messages",
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code == 200
    data = response.json()
    assert isinstance(data, list)
    assert len(data) == 2
    # Oldest first, and each message's payload must survive serialization.
    assert [m["role"] for m in data] == ["user", "assistant"]
    assert [m["content"] for m in data] == ["Hello", "Hi there!"]
# ---------------------------------------------------------------------------
# Create conversation (get-or-create)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_create_conversation() -> None:
    """POST /api/portal/chat/conversations creates a new conversation for a user+agent pair.

    No existing conversation is found, so the endpoint must add exactly one
    new row. It must return the id populated by the (mocked) refresh.
    """
    tenant_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()

    now = datetime.now(timezone.utc)

    # Platform admin bypasses membership check; no existing conversation found
    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
    mock_session.flush = AsyncMock()
    mock_session.commit = AsyncMock()
    # session.add is synchronous, so it must not be an AsyncMock.
    mock_session.add = MagicMock()

    # refresh populates server-default fields on the passed ORM object
    async def _mock_refresh(obj: object) -> None:
        obj.id = conv_id  # type: ignore[attr-defined]
        obj.created_at = now  # type: ignore[attr-defined]
        obj.updated_at = now  # type: ignore[attr-defined]

    mock_session.refresh = _mock_refresh

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/api/portal/chat/conversations",
                json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code in (200, 201)
    data = response.json()
    # The returned id must be the one assigned to the newly created row.
    assert data["id"] == str(conv_id)
    # Nothing pre-existing was found, so exactly one new conversation is added.
    mock_session.add.assert_called_once()
@pytest.mark.asyncio
async def test_create_conversation_rbac_forbidden() -> None:
    """Non-member gets 403 when creating a conversation in a tenant they don't belong to.

    A rejected request must also leave no trace: nothing is added to the
    session and nothing is committed.
    """
    tenant_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    user_id = uuid.uuid4()

    # Membership check returns None (not a member)
    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))

    app = _make_app_with_session_override(mock_session)

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        response = await client.post(
            "/api/portal/chat/conversations",
            json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
            headers=_stranger_headers(str(user_id)),
        )

    assert response.status_code == 403
    # Authorization failures must not persist anything.
    mock_session.add.assert_not_called()
    mock_session.commit.assert_not_awaited()
# ---------------------------------------------------------------------------
# Delete conversation (reset messages)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_delete_conversation_resets_messages() -> None:
    """DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row.

    The assertions check both halves of that contract:
    - a statement beyond the conversation lookup ran (the message deletion);
    - the conversation object itself is never handed to ``session.delete``;
    - the change is committed.
    """
    user_id = uuid.uuid4()
    conv_id = uuid.uuid4()

    mock_conv = MagicMock()
    mock_conv.id = conv_id
    mock_conv.user_id = user_id
    mock_conv.tenant_id = uuid.uuid4()

    mock_session = AsyncMock()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
    mock_session.commit = AsyncMock()

    app = _make_app_with_session_override(mock_session)

    with (
        patch("shared.api.chat.configure_rls_hook"),
        patch("shared.api.chat.current_tenant_id"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.delete(
                f"/api/portal/chat/conversations/{conv_id}",
                headers=_admin_headers(str(user_id)),
            )

    assert response.status_code == 200
    # One execute() is the ownership lookup; at least one more must follow for
    # the message deletion. ">= 1" was already satisfied by the lookup alone.
    assert mock_session.execute.call_count >= 2
    # The conversation row must survive the reset.
    deleted_objects = [c.args[0] for c in mock_session.delete.call_args_list if c.args]
    assert not any(obj is mock_conv for obj in deleted_objects)
    mock_session.commit.assert_awaited()