Files
konstruct/tests/unit/test_web_channel.py
Adolfo Delorenzo c72beb916b feat(06-01): add web channel type, Redis key, ORM models, migration, and tests
- Add ChannelType.WEB = 'web' to shared/models/message.py
- Add webchat_response_key() to shared/redis_keys.py
- Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0)
- Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update
- Pop conversation_id/portal_user_id extras in handle_message before model_validate
- Add web case to _build_response_extras and _send_response (Redis pub-sub publish)
- Import webchat_response_key in orchestrator/tasks.py
- Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
2026-03-25 10:26:34 -06:00

313 lines
11 KiB
Python

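"""
Unit tests for the web channel adapter.

Tests:
- test_webchat_response_key_*: webchat_response_key returns a tenant-namespaced key
  ("<tenant_id>:webchat:response:<conversation_id>") that differs across tenants
- test_channel_type_web_exists: ChannelType.WEB exists with value "web"
- test_normalize_web_event_*: normalize_web_event returns a KonstructMessage with
  channel=WEB, thread_id equal to conversation_id, sender.user_id equal to the portal
  user UUID, the event's tenant_id, and channel_metadata carrying portal_user_id,
  tenant_id and conversation_id
- test_send_response_web_*: _send_response("web", ...) publishes JSON to Redis, closes
  the Redis connection, and logs a warning (without publishing) when conversation_id
  is missing
- test_typing_indicator_sent_before_dispatch: the WebSocket handler sends
  {"type": "typing"} before Celery dispatch
"""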
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.models.message import ChannelType, KonstructMessage
from shared.redis_keys import webchat_response_key
# ---------------------------------------------------------------------------
# Test webchat_response_key
# ---------------------------------------------------------------------------
def test_webchat_response_key_format() -> None:
    """The response key has the shape '<tenant_id>:webchat:response:<conversation_id>'."""
    assert (
        webchat_response_key("tenant-abc", "conv-xyz")
        == "tenant-abc:webchat:response:conv-xyz"
    )
def test_webchat_response_key_isolation() -> None:
    """The same conversation_id under two different tenants must yield two distinct keys."""
    shared_conversation = "conv-1"
    keys = {
        webchat_response_key(tenant, shared_conversation)
        for tenant in ("tenant-a", "tenant-b")
    }
    assert len(keys) == 2
# ---------------------------------------------------------------------------
# Test ChannelType.WEB
# ---------------------------------------------------------------------------
def test_channel_type_web_exists() -> None:
    """ChannelType exposes a WEB member that compares equal to, and has the value, 'web'."""
    web_member = ChannelType.WEB
    assert web_member == "web"
    assert web_member.value == "web"
# ---------------------------------------------------------------------------
# Test normalize_web_event
# ---------------------------------------------------------------------------
def _make_web_event(**overrides: str) -> dict[str, str]:
    """Build a minimal web-channel event dict with a fresh UUID for every id field.

    Keyword arguments replace individual fields (key order is preserved), so each
    test only spells out the values it actually asserts on.
    """
    event = {
        "text": "test",
        "tenant_id": str(uuid.uuid4()),
        "agent_id": str(uuid.uuid4()),
        "user_id": str(uuid.uuid4()),
        "display_name": "User",
        "conversation_id": str(uuid.uuid4()),
    }
    event.update(overrides)
    return event


def test_normalize_web_event_returns_konstruct_message() -> None:
    """normalize_web_event returns a KonstructMessage."""
    from gateway.channels.web import normalize_web_event

    event = _make_web_event(text="Hello from the portal", display_name="Portal User")
    msg = normalize_web_event(event)
    assert isinstance(msg, KonstructMessage)


def test_normalize_web_event_channel_is_web() -> None:
    """normalize_web_event sets channel = ChannelType.WEB."""
    from gateway.channels.web import normalize_web_event

    msg = normalize_web_event(_make_web_event())
    assert msg.channel == ChannelType.WEB


def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
    """normalize_web_event sets thread_id = conversation_id for memory scoping."""
    from gateway.channels.web import normalize_web_event

    conversation_id = str(uuid.uuid4())
    msg = normalize_web_event(_make_web_event(conversation_id=conversation_id))
    assert msg.thread_id == conversation_id


def test_normalize_web_event_sender_user_id() -> None:
    """normalize_web_event sets sender.user_id to the portal user UUID."""
    from gateway.channels.web import normalize_web_event

    user_id = str(uuid.uuid4())
    msg = normalize_web_event(
        _make_web_event(user_id=user_id, display_name="Portal User")
    )
    assert msg.sender.user_id == user_id


def test_normalize_web_event_channel_metadata() -> None:
    """normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
    from gateway.channels.web import normalize_web_event

    tenant_id = str(uuid.uuid4())
    user_id = str(uuid.uuid4())
    conversation_id = str(uuid.uuid4())
    msg = normalize_web_event(
        _make_web_event(
            tenant_id=tenant_id,
            user_id=user_id,
            conversation_id=conversation_id,
        )
    )
    assert msg.channel_metadata["portal_user_id"] == user_id
    assert msg.channel_metadata["tenant_id"] == tenant_id
    assert msg.channel_metadata["conversation_id"] == conversation_id


def test_normalize_web_event_tenant_id() -> None:
    """normalize_web_event sets tenant_id on the message."""
    from gateway.channels.web import normalize_web_event

    tenant_id = str(uuid.uuid4())
    msg = normalize_web_event(_make_web_event(text="hello", tenant_id=tenant_id))
    assert msg.tenant_id == tenant_id
# ---------------------------------------------------------------------------
# Test _send_response web case publishes to Redis
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_response_web_publishes_to_redis() -> None:
    """The web branch of _send_response publishes a JSON 'response' frame on the
    tenant/conversation-scoped webchat channel."""
    tenant_id = str(uuid.uuid4())
    conversation_id = str(uuid.uuid4())
    reply_text = "Hello, this is the agent response!"

    redis_client = AsyncMock()
    redis_client.publish = AsyncMock()
    redis_client.aclose = AsyncMock()

    with patch("orchestrator.tasks.aioredis.from_url", return_value=redis_client):
        from orchestrator.tasks import _send_response

        await _send_response(
            "web",
            reply_text,
            {"conversation_id": conversation_id, "tenant_id": tenant_id},
        )

    redis_client.publish.assert_called_once()
    channel, raw_payload = redis_client.publish.call_args.args[:2]
    assert channel == webchat_response_key(tenant_id, conversation_id)

    frame = json.loads(raw_payload)
    assert frame["type"] == "response"
    assert frame["text"] == reply_text
    assert frame["conversation_id"] == conversation_id
@pytest.mark.asyncio
async def test_send_response_web_connection_cleanup() -> None:
    """_send_response web case always closes Redis connection (try/finally).

    Success path: after a normal publish the connection is closed exactly once.
    """
    mock_redis = AsyncMock()
    mock_redis.publish = AsyncMock()
    mock_redis.aclose = AsyncMock()
    with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
        from orchestrator.tasks import _send_response

        await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
    mock_redis.aclose.assert_called_once()


@pytest.mark.asyncio
async def test_send_response_web_connection_cleanup_on_publish_error() -> None:
    """_send_response web case closes the Redis connection even when publish raises.

    A plain trailing aclose() would also pass the success-path test, so a
    publish failure is injected here to exercise the cleanup path.
    """
    mock_redis = AsyncMock()
    mock_redis.publish = AsyncMock(side_effect=RuntimeError("redis unavailable"))
    mock_redis.aclose = AsyncMock()
    with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
        from orchestrator.tasks import _send_response

        try:
            await _send_response(
                "web", "test", {"conversation_id": "c1", "tenant_id": "t1"}
            )
        except RuntimeError:
            # Whether the injected error propagates or is logged and swallowed is
            # not what this test checks -- only that the connection is released.
            pass
    mock_redis.aclose.assert_called_once()
@pytest.mark.asyncio
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
    """Without a conversation_id the web branch emits a warning and publishes nothing."""
    redis_client = AsyncMock()
    extras_without_conversation = {"tenant_id": "t1"}
    with (
        patch("orchestrator.tasks.aioredis.from_url", return_value=redis_client),
        patch("orchestrator.tasks.logger") as logger_mock,
    ):
        from orchestrator.tasks import _send_response

        await _send_response("web", "test", extras_without_conversation)
    logger_mock.warning.assert_called()
    redis_client.publish.assert_not_called()
# ---------------------------------------------------------------------------
# Test typing indicator
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_typing_indicator_sent_before_dispatch() -> None:
    """WebSocket handler sends {'type': 'typing'} immediately after receiving user message.

    The handler is called directly with a mocked WebSocket, DB session, Redis
    client and Celery task. Outbound ``send_json`` frames and
    ``handle_message.delay`` calls are recorded in one ordered log. The test
    checks that a typing indicator is sent and that, if a dispatch happened, the
    indicator came first.
    """
    from unittest.mock import DEFAULT

    # Ordered log of observable side effects: ("send", <frame type>) for each
    # send_json call and ("dispatch", None) for each handle_message.delay call.
    events: list[tuple[str, object]] = []

    def _record_send(*args: object, **kwargs: object) -> object:
        payload = args[0] if args else kwargs.get("data")
        frame_type = payload.get("type") if isinstance(payload, dict) else None
        events.append(("send", frame_type))
        return DEFAULT  # fall through to the mock's normal return value

    def _record_dispatch(*args: object, **kwargs: object) -> object:
        events.append(("dispatch", None))
        return DEFAULT  # keep returning a MagicMock in case the handler uses it

    mock_ws = AsyncMock()
    # First receive_json returns the auth frame, second the user message. A third
    # call raises StopAsyncIteration (side_effect exhausted); presumably that is
    # what ends the handler's receive loop -- confirm against the handler.
    mock_ws.receive_json = AsyncMock(
        side_effect=[
            {"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
            {"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
        ]
    )
    mock_ws.accept = AsyncMock()
    mock_ws.send_json = AsyncMock(side_effect=_record_send)

    # DB session usable as an async context manager; execute() results yield the
    # mocked conversation through scalar_one_or_none().
    mock_session = AsyncMock()
    mock_conv = MagicMock()
    mock_conv.id = uuid.uuid4()
    mock_conv.tenant_id = uuid.uuid4()
    mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session.__aexit__ = AsyncMock(return_value=False)

    # Redis pub-sub: get_message returns the same agent "response" frame on
    # every call, so any listener receives a reply immediately.
    mock_pubsub = AsyncMock()
    mock_pubsub.subscribe = AsyncMock()
    mock_pubsub.get_message = AsyncMock(return_value={
        "type": "message",
        "data": json.dumps({
            "type": "response",
            "text": "Agent reply",
            "conversation_id": str(uuid.uuid4()),
        }),
    })
    mock_pubsub.unsubscribe = AsyncMock()
    mock_redis = AsyncMock()
    mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
    mock_redis.aclose = AsyncMock()

    handler_error: Exception | None = None
    with (
        patch("gateway.channels.web.async_session_factory", return_value=mock_session),
        patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
        patch("gateway.channels.web.handle_message") as mock_handle_msg,
        patch("gateway.channels.web.configure_rls_hook"),
        patch("gateway.channels.web.current_tenant_id"),
    ):
        mock_handle_msg.delay = MagicMock(side_effect=_record_dispatch)
        from gateway.channels.web import _handle_websocket_connection

        try:
            # wait_for bounds the run in case the handler blocks on a real await.
            await asyncio.wait_for(
                _handle_websocket_connection(mock_ws, str(uuid.uuid4())),
                timeout=2.0,
            )
        except Exception as exc:
            # The handler is expected to stop abnormally here (exhausted
            # receive_json, timeout, or an error from a mocked collaborator).
            # Keep the error for the failure message instead of discarding it;
            # the assertions below only inspect what happened before that point.
            handler_error = exc

    sent_types = [detail for kind, detail in events if kind == "send"]
    assert "typing" in sent_types, (
        f"Expected typing indicator, got send_json calls: "
        f"{mock_ws.send_json.call_args_list} (handler ended with {handler_error!r})"
    )

    # Ordering is only checkable if the handler reached the Celery dispatch
    # under these mocks; the original test did not require that it does.
    if ("dispatch", None) in events:
        assert events.index(("send", "typing")) < events.index(("dispatch", None)), (
            f"Typing indicator must be sent before Celery dispatch; observed order: {events}"
        )