feat(06-01): add web channel type, Redis key, ORM models, migration, and tests

- Add ChannelType.WEB = 'web' to shared/models/message.py
- Add webchat_response_key() to shared/redis_keys.py
- Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0)
- Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update
- Pop conversation_id/portal_user_id extras in handle_message before model_validate
- Add web case to _build_response_extras and _send_response (Redis pub-sub publish)
- Import webchat_response_key in orchestrator/tasks.py
- Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
This commit is contained in:
2026-03-25 10:26:34 -06:00
parent c0fa0cefee
commit c72beb916b
7 changed files with 957 additions and 1 deletions

283
tests/unit/test_chat_api.py Normal file
View File

@@ -0,0 +1,283 @@
"""
Unit tests for the chat REST API with RBAC enforcement.
Tests:
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
when caller is not a member of tenant X
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
when caller is platform_admin (bypasses membership check)
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
paginated messages ordered by created_at
- test_create_conversation: POST /api/portal/chat/conversations creates or returns existing
conversation for user+agent pair
- test_create_conversation_rbac: POST returns 403 for non-member caller
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
messages but keeps the conversation row
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from shared.api.rbac import PortalCaller
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _admin_headers(user_id: str | None = None) -> dict[str, str]:
return {
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
"X-Portal-User-Role": "platform_admin",
}
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
return {
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
"X-Portal-User-Role": "customer_operator",
}
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
"""Create a test FastAPI app with the chat router and a session dependency override."""
from shared.api.chat import chat_router
from shared.db import get_session
app = FastAPI()
app.include_router(chat_router)
async def _override_get_session(): # type: ignore[return]
yield mock_session
app.dependency_overrides[get_session] = _override_get_session
return app
# ---------------------------------------------------------------------------
# RBAC enforcement on list conversations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_rbac_enforcement() -> None:
"""Non-member caller gets 403 when listing conversations for a tenant they don't belong to."""
tenant_id = uuid.uuid4()
user_id = uuid.uuid4()
# Mock session — no membership row found (require_tenant_member checks UserTenantRole)
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
app = _make_app_with_session_override(mock_session)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
"/api/portal/chat/conversations",
params={"tenant_id": str(tenant_id)},
headers=_stranger_headers(str(user_id)),
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_platform_admin_cross_tenant() -> None:
"""Platform admin can list conversations for any tenant (bypasses membership check)."""
tenant_id = uuid.uuid4()
user_id = uuid.uuid4()
# Mock session — returns empty rows for conversation query
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.all.return_value = []
mock_session.execute.return_value = mock_result
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
"/api/portal/chat/conversations",
params={"tenant_id": str(tenant_id)},
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
assert isinstance(response.json(), list)
# ---------------------------------------------------------------------------
# List conversation history (paginated messages)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_conversation_history() -> None:
"""GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at."""
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
# Mock conversation owned by the caller
mock_conv = MagicMock()
mock_conv.id = conv_id
mock_conv.user_id = user_id
mock_conv.tenant_id = uuid.uuid4()
# Mock messages
now = datetime.now(timezone.utc)
mock_msg1 = MagicMock()
mock_msg1.id = uuid.uuid4()
mock_msg1.role = "user"
mock_msg1.content = "Hello"
mock_msg1.created_at = now
mock_msg2 = MagicMock()
mock_msg2.id = uuid.uuid4()
mock_msg2.role = "assistant"
mock_msg2.content = "Hi there!"
mock_msg2.created_at = now
mock_session = AsyncMock()
# First call: fetch conversation; second call: fetch messages
mock_session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
]
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
f"/api/portal/chat/conversations/{conv_id}/messages",
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
assert data[0]["role"] == "user"
assert data[1]["role"] == "assistant"
# ---------------------------------------------------------------------------
# Create conversation (get-or-create)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_create_conversation() -> None:
"""POST /api/portal/chat/conversations creates a new conversation for user+agent pair."""
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
now = datetime.now(timezone.utc)
# Platform admin bypasses membership check; no existing conversation found
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
mock_session.flush = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.add = MagicMock()
# refresh populates server-default fields on the passed ORM object
async def _mock_refresh(obj: object) -> None:
obj.id = conv_id # type: ignore[attr-defined]
obj.created_at = now # type: ignore[attr-defined]
obj.updated_at = now # type: ignore[attr-defined]
mock_session.refresh = _mock_refresh
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/api/portal/chat/conversations",
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
headers=_admin_headers(str(user_id)),
)
assert response.status_code in (200, 201)
data = response.json()
assert "id" in data
@pytest.mark.asyncio
async def test_create_conversation_rbac_forbidden() -> None:
"""Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
user_id = uuid.uuid4()
# Membership check returns None (not a member)
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
app = _make_app_with_session_override(mock_session)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/api/portal/chat/conversations",
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
headers=_stranger_headers(str(user_id)),
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Delete conversation (reset messages)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_delete_conversation_resets_messages() -> None:
"""DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row."""
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
mock_conv = MagicMock()
mock_conv.id = conv_id
mock_conv.user_id = user_id
mock_conv.tenant_id = uuid.uuid4()
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
mock_session.commit = AsyncMock()
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.delete(
f"/api/portal/chat/conversations/{conv_id}",
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
assert mock_session.execute.call_count >= 1

View File

@@ -0,0 +1,312 @@
"""
Unit tests for the web channel adapter.
Tests:
- test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB
- test_normalize_web_event_thread_id: thread_id equals conversation_id
- test_normalize_web_event_sender: sender.user_id equals portal user UUID
- test_webchat_response_key: webchat_response_key returns correct namespaced key
- test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis
- test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.models.message import ChannelType, KonstructMessage
from shared.redis_keys import webchat_response_key
# ---------------------------------------------------------------------------
# Test webchat_response_key
# ---------------------------------------------------------------------------
def test_webchat_response_key_format() -> None:
"""webchat_response_key returns correctly namespaced key."""
tenant_id = "tenant-abc"
conversation_id = "conv-xyz"
key = webchat_response_key(tenant_id, conversation_id)
assert key == "tenant-abc:webchat:response:conv-xyz"
def test_webchat_response_key_isolation() -> None:
"""Two tenants with same conversation_id get different keys."""
key_a = webchat_response_key("tenant-a", "conv-1")
key_b = webchat_response_key("tenant-b", "conv-1")
assert key_a != key_b
# ---------------------------------------------------------------------------
# Test ChannelType.WEB
# ---------------------------------------------------------------------------
def test_channel_type_web_exists() -> None:
"""ChannelType.WEB must exist with value 'web'."""
assert ChannelType.WEB == "web"
assert ChannelType.WEB.value == "web"
# ---------------------------------------------------------------------------
# Test normalize_web_event
# ---------------------------------------------------------------------------
def test_normalize_web_event_returns_konstruct_message() -> None:
"""normalize_web_event returns a KonstructMessage."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
agent_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
event = {
"text": "Hello from the portal",
"tenant_id": tenant_id,
"agent_id": agent_id,
"user_id": user_id,
"display_name": "Portal User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert isinstance(msg, KonstructMessage)
def test_normalize_web_event_channel_is_web() -> None:
"""normalize_web_event sets channel = ChannelType.WEB."""
from gateway.channels.web import normalize_web_event
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.channel == ChannelType.WEB
def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
"""normalize_web_event sets thread_id = conversation_id for memory scoping."""
from gateway.channels.web import normalize_web_event
conversation_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert msg.thread_id == conversation_id
def test_normalize_web_event_sender_user_id() -> None:
"""normalize_web_event sets sender.user_id to the portal user UUID."""
from gateway.channels.web import normalize_web_event
user_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": user_id,
"display_name": "Portal User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.sender.user_id == user_id
def test_normalize_web_event_channel_metadata() -> None:
"""normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": tenant_id,
"agent_id": str(uuid.uuid4()),
"user_id": user_id,
"display_name": "User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert msg.channel_metadata["portal_user_id"] == user_id
assert msg.channel_metadata["tenant_id"] == tenant_id
assert msg.channel_metadata["conversation_id"] == conversation_id
def test_normalize_web_event_tenant_id() -> None:
"""normalize_web_event sets tenant_id on the message."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
event = {
"text": "hello",
"tenant_id": tenant_id,
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.tenant_id == tenant_id
# ---------------------------------------------------------------------------
# Test _send_response web case publishes to Redis
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_response_web_publishes_to_redis() -> None:
"""_send_response('web', ...) publishes JSON message to Redis webchat channel."""
tenant_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
mock_redis.aclose = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
from orchestrator.tasks import _send_response
extras = {
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
await _send_response("web", "Hello, this is the agent response!", extras)
expected_channel = webchat_response_key(tenant_id, conversation_id)
mock_redis.publish.assert_called_once()
call_args = mock_redis.publish.call_args
assert call_args[0][0] == expected_channel
published_payload = json.loads(call_args[0][1])
assert published_payload["type"] == "response"
assert published_payload["text"] == "Hello, this is the agent response!"
assert published_payload["conversation_id"] == conversation_id
@pytest.mark.asyncio
async def test_send_response_web_connection_cleanup() -> None:
"""_send_response web case always closes Redis connection (try/finally)."""
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
mock_redis.aclose = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
from orchestrator.tasks import _send_response
await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
mock_redis.aclose.assert_called_once()
@pytest.mark.asyncio
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
"""_send_response web case logs warning if conversation_id missing."""
mock_redis = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
with patch("orchestrator.tasks.logger") as mock_logger:
from orchestrator.tasks import _send_response
await _send_response("web", "test", {"tenant_id": "t1"})
mock_logger.warning.assert_called()
mock_redis.publish.assert_not_called()
# ---------------------------------------------------------------------------
# Test typing indicator
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_typing_indicator_sent_before_dispatch() -> None:
"""WebSocket handler sends {'type': 'typing'} immediately after receiving user message."""
# We test the typing indicator by calling the handler function directly
# with a mocked WebSocket and mocked Celery dispatch.
from unittest.mock import AsyncMock, MagicMock, call, patch
mock_ws = AsyncMock()
# First receive_json returns auth message, second returns the user message
mock_ws.receive_json = AsyncMock(
side_effect=[
{"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
{"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
]
)
# accept must be awaitable
mock_ws.accept = AsyncMock()
mock_ws.send_json = AsyncMock()
# Mock DB session that returns a conversation
mock_session = AsyncMock()
mock_conv = MagicMock()
mock_conv.id = uuid.uuid4()
mock_conv.tenant_id = uuid.uuid4()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
# Mock Redis pub-sub: raise after publishing once so handler exits
mock_redis = AsyncMock()
mock_pubsub = AsyncMock()
mock_pubsub.subscribe = AsyncMock()
mock_pubsub.get_message = AsyncMock(return_value={
"type": "message",
"data": json.dumps({
"type": "response",
"text": "Agent reply",
"conversation_id": str(uuid.uuid4()),
}),
})
mock_pubsub.unsubscribe = AsyncMock()
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
mock_redis.aclose = AsyncMock()
mock_handle = MagicMock()
with (
patch("gateway.channels.web.async_session_factory", return_value=mock_session),
patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
patch("gateway.channels.web.handle_message") as mock_handle_msg,
patch("gateway.channels.web.configure_rls_hook"),
patch("gateway.channels.web.current_tenant_id"),
):
mock_handle_msg.delay = MagicMock()
from gateway.channels.web import _handle_websocket_connection
# Run the handler — it should send typing indicator then process
# Use asyncio.wait_for to prevent infinite loop if something hangs
try:
await asyncio.wait_for(
_handle_websocket_connection(mock_ws, str(uuid.uuid4())),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
# Check typing was sent before Celery dispatch
send_calls = mock_ws.send_json.call_args_list
typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"]
assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"