""" Unit tests for the web channel adapter. Tests: - test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB - test_normalize_web_event_thread_id: thread_id equals conversation_id - test_normalize_web_event_sender: sender.user_id equals portal user UUID - test_webchat_response_key: webchat_response_key returns correct namespaced key - test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis - test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch """ from __future__ import annotations import asyncio import json import uuid from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from shared.models.message import ChannelType, KonstructMessage from shared.redis_keys import webchat_response_key # --------------------------------------------------------------------------- # Test webchat_response_key # --------------------------------------------------------------------------- def test_webchat_response_key_format() -> None: """webchat_response_key returns correctly namespaced key.""" tenant_id = "tenant-abc" conversation_id = "conv-xyz" key = webchat_response_key(tenant_id, conversation_id) assert key == "tenant-abc:webchat:response:conv-xyz" def test_webchat_response_key_isolation() -> None: """Two tenants with same conversation_id get different keys.""" key_a = webchat_response_key("tenant-a", "conv-1") key_b = webchat_response_key("tenant-b", "conv-1") assert key_a != key_b # --------------------------------------------------------------------------- # Test ChannelType.WEB # --------------------------------------------------------------------------- def test_channel_type_web_exists() -> None: """ChannelType.WEB must exist with value 'web'.""" assert ChannelType.WEB == "web" assert ChannelType.WEB.value == "web" # --------------------------------------------------------------------------- # Test 
# ---------------------------------------------------------------------------
# Test normalize_web_event
# ---------------------------------------------------------------------------


def _web_event(**overrides: str) -> dict[str, str]:
    """Build a portal chat event dict for normalize_web_event.

    Every ID field gets a fresh UUID string and text/display_name get neutral
    defaults. Pass keyword arguments to pin the fields a test asserts on,
    e.g. ``_web_event(user_id=user_id)``.
    """
    event = {
        "text": "test",
        "tenant_id": str(uuid.uuid4()),
        "agent_id": str(uuid.uuid4()),
        "user_id": str(uuid.uuid4()),
        "display_name": "User",
        "conversation_id": str(uuid.uuid4()),
    }
    event.update(overrides)
    return event


def test_normalize_web_event_returns_konstruct_message() -> None:
    """normalize_web_event returns a KonstructMessage."""
    from gateway.channels.web import normalize_web_event

    msg = normalize_web_event(
        _web_event(text="Hello from the portal", display_name="Portal User")
    )
    assert isinstance(msg, KonstructMessage)


def test_normalize_web_event_channel_is_web() -> None:
    """normalize_web_event sets channel = ChannelType.WEB."""
    from gateway.channels.web import normalize_web_event

    msg = normalize_web_event(_web_event())
    assert msg.channel == ChannelType.WEB


def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
    """normalize_web_event sets thread_id = conversation_id for memory scoping."""
    from gateway.channels.web import normalize_web_event

    conversation_id = str(uuid.uuid4())
    msg = normalize_web_event(_web_event(conversation_id=conversation_id))
    assert msg.thread_id == conversation_id


def test_normalize_web_event_sender_user_id() -> None:
    """normalize_web_event sets sender.user_id to the portal user UUID."""
    from gateway.channels.web import normalize_web_event

    user_id = str(uuid.uuid4())
    msg = normalize_web_event(_web_event(user_id=user_id, display_name="Portal User"))
    assert msg.sender.user_id == user_id


def test_normalize_web_event_channel_metadata() -> None:
    """normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
    from gateway.channels.web import normalize_web_event

    tenant_id = str(uuid.uuid4())
    user_id = str(uuid.uuid4())
    conversation_id = str(uuid.uuid4())
    msg = normalize_web_event(
        _web_event(tenant_id=tenant_id, user_id=user_id, conversation_id=conversation_id)
    )
    assert msg.channel_metadata["portal_user_id"] == user_id
    assert msg.channel_metadata["tenant_id"] == tenant_id
    assert msg.channel_metadata["conversation_id"] == conversation_id


def test_normalize_web_event_tenant_id() -> None:
    """normalize_web_event sets tenant_id on the message."""
    from gateway.channels.web import normalize_web_event

    tenant_id = str(uuid.uuid4())
    msg = normalize_web_event(_web_event(text="hello", tenant_id=tenant_id))
    assert msg.tenant_id == tenant_id


# ---------------------------------------------------------------------------
# Test _send_response web case publishes to Redis
# ---------------------------------------------------------------------------
# These tests use assert_awaited_* rather than assert_called_*: calling an
# AsyncMock without awaiting it still counts as a call, so only the await
# checks catch a publish/aclose coroutine that is created but never awaited.


@pytest.mark.asyncio
async def test_send_response_web_publishes_to_redis() -> None:
    """_send_response('web', ...) publishes JSON message to Redis webchat channel."""
    tenant_id = str(uuid.uuid4())
    conversation_id = str(uuid.uuid4())
    mock_redis = AsyncMock()
    mock_redis.publish = AsyncMock()
    mock_redis.aclose = AsyncMock()

    with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
        from orchestrator.tasks import _send_response

        extras = {
            "conversation_id": conversation_id,
            "tenant_id": tenant_id,
        }
        await _send_response("web", "Hello, this is the agent response!", extras)

    expected_channel = webchat_response_key(tenant_id, conversation_id)
    mock_redis.publish.assert_awaited_once()
    call_args = mock_redis.publish.call_args
    assert call_args[0][0] == expected_channel

    published_payload = json.loads(call_args[0][1])
    assert published_payload["type"] == "response"
    assert published_payload["text"] == "Hello, this is the agent response!"
    assert published_payload["conversation_id"] == conversation_id


@pytest.mark.asyncio
async def test_send_response_web_connection_cleanup() -> None:
    """_send_response web case always closes Redis connection (try/finally)."""
    mock_redis = AsyncMock()
    mock_redis.publish = AsyncMock()
    mock_redis.aclose = AsyncMock()

    with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
        from orchestrator.tasks import _send_response

        await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})

    mock_redis.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
    """_send_response web case logs warning if conversation_id missing."""
    mock_redis = AsyncMock()

    with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
        with patch("orchestrator.tasks.logger") as mock_logger:
            from orchestrator.tasks import _send_response

            await _send_response("web", "test", {"tenant_id": "t1"})

    mock_logger.warning.assert_called()
    mock_redis.publish.assert_not_called()
# ---------------------------------------------------------------------------
# Test typing indicator
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_typing_indicator_sent_before_dispatch() -> None:
    """WebSocket handler sends {'type': 'typing'} immediately after receiving user message.

    The handler is called directly with a mocked WebSocket, DB session, Redis
    pub-sub and Celery task. Every send_json call records how many Celery
    dispatches (``handle_message.delay`` calls) had already happened. The test
    therefore verifies that the typing frame goes out *before* dispatch, not
    merely that it was sent at some point.
    """
    # Stands in for handle_message.delay (the Celery dispatch).
    mock_delay = MagicMock()
    # One (frame type, dispatches made so far) entry per send_json call, in call order.
    sent_frames: list[tuple[object, int]] = []

    def _record_send(*args: object, **kwargs: object) -> None:
        # Accept the payload positionally or as data=..., and tolerate non-dict
        # payloads, so recording never raises inside the handler under test.
        payload = args[0] if args else kwargs.get("data")
        frame_type = payload.get("type") if isinstance(payload, dict) else None
        sent_frames.append((frame_type, mock_delay.call_count))

    mock_ws = AsyncMock()
    # First receive_json returns the auth message, second the user message. If the
    # handler reads a third time, the exhausted side_effect raises StopAsyncIteration.
    mock_ws.receive_json = AsyncMock(
        side_effect=[
            {
                "type": "auth",
                "userId": str(uuid.uuid4()),
                "role": "customer_operator",
                "tenantId": str(uuid.uuid4()),
            },
            {
                "type": "message",
                "text": "Hello agent",
                "agentId": str(uuid.uuid4()),
                "conversationId": str(uuid.uuid4()),
            },
        ]
    )
    # accept must be awaitable
    mock_ws.accept = AsyncMock()
    mock_ws.send_json = AsyncMock(side_effect=_record_send)

    # Mock DB session that returns a conversation
    mock_session = AsyncMock()
    mock_conv = MagicMock()
    mock_conv.id = uuid.uuid4()
    mock_conv.tenant_id = uuid.uuid4()
    mock_session.execute.return_value = MagicMock(
        scalar_one_or_none=MagicMock(return_value=mock_conv)
    )
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session.__aexit__ = AsyncMock(return_value=False)

    # Mock Redis pub-sub. get_message returns the same agent response frame on
    # every call and never raises, so the wait_for timeout below is what bounds
    # a handler that keeps polling.
    mock_redis = AsyncMock()
    mock_pubsub = AsyncMock()
    mock_pubsub.subscribe = AsyncMock()
    mock_pubsub.get_message = AsyncMock(
        return_value={
            "type": "message",
            "data": json.dumps(
                {
                    "type": "response",
                    "text": "Agent reply",
                    "conversation_id": str(uuid.uuid4()),
                }
            ),
        }
    )
    mock_pubsub.unsubscribe = AsyncMock()
    mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
    mock_redis.aclose = AsyncMock()

    with (
        patch("gateway.channels.web.async_session_factory", return_value=mock_session),
        patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
        patch("gateway.channels.web.handle_message") as mock_handle_msg,
        patch("gateway.channels.web.configure_rls_hook"),
        patch("gateway.channels.web.current_tenant_id"),
    ):
        mock_handle_msg.delay = mock_delay
        from gateway.channels.web import _handle_websocket_connection

        try:
            # wait_for prevents an infinite loop if the handler keeps polling.
            await asyncio.wait_for(
                _handle_websocket_connection(mock_ws, str(uuid.uuid4())),
                timeout=2.0,
            )
        except Exception:
            # Deliberately best-effort: with everything mocked, the handler is not
            # expected to exit cleanly (timeout, StopAsyncIteration from
            # receive_json, ...). The assertions below inspect what it sent first.
            pass

    send_calls = mock_ws.send_json.call_args_list
    typing_dispatch_counts = [
        dispatched for frame_type, dispatched in sent_frames if frame_type == "typing"
    ]
    assert len(typing_dispatch_counts) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"
    # The first typing frame must be sent while no Celery dispatch has happened yet.
    assert typing_dispatch_counts[0] == 0, (
        f"Typing indicator sent after {typing_dispatch_counts[0]} Celery dispatch(es); "
        f"send_json calls: {send_calls}"
    )