feat(06-01): add web channel type, Redis key, ORM models, migration, and tests
- Add ChannelType.WEB = 'web' to shared/models/message.py - Add webchat_response_key() to shared/redis_keys.py - Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0) - Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update - Pop conversation_id/portal_user_id extras in handle_message before model_validate - Add web case to _build_response_extras and _send_response (Redis pub-sub publish) - Import webchat_response_key in orchestrator/tasks.py - Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
This commit is contained in:
283
tests/unit/test_chat_api.py
Normal file
283
tests/unit/test_chat_api.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Unit tests for the chat REST API with RBAC enforcement.
|
||||
|
||||
Tests:
|
||||
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
|
||||
when caller is not a member of tenant X
|
||||
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
|
||||
when caller is platform_admin (bypasses membership check)
|
||||
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
|
||||
paginated messages ordered by created_at
|
||||
- test_create_conversation: POST /api/portal/chat/conversations creates or returns existing
|
||||
conversation for user+agent pair
|
||||
- test_create_conversation_rbac: POST returns 403 for non-member caller
|
||||
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
|
||||
messages but keeps the conversation row
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from shared.api.rbac import PortalCaller
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _admin_headers(user_id: str | None = None) -> dict[str, str]:
|
||||
return {
|
||||
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||
"X-Portal-User-Role": "platform_admin",
|
||||
}
|
||||
|
||||
|
||||
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
|
||||
return {
|
||||
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||
"X-Portal-User-Role": "customer_operator",
|
||||
}
|
||||
|
||||
|
||||
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
|
||||
"""Create a test FastAPI app with the chat router and a session dependency override."""
|
||||
from shared.api.chat import chat_router
|
||||
from shared.db import get_session
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(chat_router)
|
||||
|
||||
async def _override_get_session(): # type: ignore[return]
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _override_get_session
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RBAC enforcement on list conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_rbac_enforcement() -> None:
|
||||
"""Non-member caller gets 403 when listing conversations for a tenant they don't belong to."""
|
||||
tenant_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Mock session — no membership row found (require_tenant_member checks UserTenantRole)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/portal/chat/conversations",
|
||||
params={"tenant_id": str(tenant_id)},
|
||||
headers=_stranger_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_admin_cross_tenant() -> None:
|
||||
"""Platform admin can list conversations for any tenant (bypasses membership check)."""
|
||||
tenant_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Mock session — returns empty rows for conversation query
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/portal/chat/conversations",
|
||||
params={"tenant_id": str(tenant_id)},
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List conversation history (paginated messages)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_history() -> None:
|
||||
"""GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at."""
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
# Mock conversation owned by the caller
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = conv_id
|
||||
mock_conv.user_id = user_id
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
|
||||
# Mock messages
|
||||
now = datetime.now(timezone.utc)
|
||||
mock_msg1 = MagicMock()
|
||||
mock_msg1.id = uuid.uuid4()
|
||||
mock_msg1.role = "user"
|
||||
mock_msg1.content = "Hello"
|
||||
mock_msg1.created_at = now
|
||||
|
||||
mock_msg2 = MagicMock()
|
||||
mock_msg2.id = uuid.uuid4()
|
||||
mock_msg2.role = "assistant"
|
||||
mock_msg2.content = "Hi there!"
|
||||
mock_msg2.created_at = now
|
||||
|
||||
mock_session = AsyncMock()
|
||||
# First call: fetch conversation; second call: fetch messages
|
||||
mock_session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
|
||||
MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
|
||||
]
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
f"/api/portal/chat/conversations/{conv_id}/messages",
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
assert data[0]["role"] == "user"
|
||||
assert data[1]["role"] == "assistant"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create conversation (get-or-create)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation() -> None:
|
||||
"""POST /api/portal/chat/conversations creates a new conversation for user+agent pair."""
|
||||
tenant_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Platform admin bypasses membership check; no existing conversation found
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
mock_session.flush = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
|
||||
# refresh populates server-default fields on the passed ORM object
|
||||
async def _mock_refresh(obj: object) -> None:
|
||||
obj.id = conv_id # type: ignore[attr-defined]
|
||||
obj.created_at = now # type: ignore[attr-defined]
|
||||
obj.updated_at = now # type: ignore[attr-defined]
|
||||
|
||||
mock_session.refresh = _mock_refresh
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.post(
|
||||
"/api/portal/chat/conversations",
|
||||
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code in (200, 201)
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_rbac_forbidden() -> None:
|
||||
"""Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
|
||||
tenant_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Membership check returns None (not a member)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.post(
|
||||
"/api/portal/chat/conversations",
|
||||
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||
headers=_stranger_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete conversation (reset messages)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation_resets_messages() -> None:
|
||||
"""DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row."""
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = conv_id
|
||||
mock_conv.user_id = user_id
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.delete(
|
||||
f"/api/portal/chat/conversations/{conv_id}",
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_session.execute.call_count >= 1
|
||||
312
tests/unit/test_web_channel.py
Normal file
312
tests/unit/test_web_channel.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Unit tests for the web channel adapter.
|
||||
|
||||
Tests:
|
||||
- test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB
|
||||
- test_normalize_web_event_thread_id: thread_id equals conversation_id
|
||||
- test_normalize_web_event_sender: sender.user_id equals portal user UUID
|
||||
- test_webchat_response_key: webchat_response_key returns correct namespaced key
|
||||
- test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis
|
||||
- test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.models.message import ChannelType, KonstructMessage
|
||||
from shared.redis_keys import webchat_response_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test webchat_response_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_webchat_response_key_format() -> None:
|
||||
"""webchat_response_key returns correctly namespaced key."""
|
||||
tenant_id = "tenant-abc"
|
||||
conversation_id = "conv-xyz"
|
||||
key = webchat_response_key(tenant_id, conversation_id)
|
||||
assert key == "tenant-abc:webchat:response:conv-xyz"
|
||||
|
||||
|
||||
def test_webchat_response_key_isolation() -> None:
|
||||
"""Two tenants with same conversation_id get different keys."""
|
||||
key_a = webchat_response_key("tenant-a", "conv-1")
|
||||
key_b = webchat_response_key("tenant-b", "conv-1")
|
||||
assert key_a != key_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test ChannelType.WEB
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_channel_type_web_exists() -> None:
|
||||
"""ChannelType.WEB must exist with value 'web'."""
|
||||
assert ChannelType.WEB == "web"
|
||||
assert ChannelType.WEB.value == "web"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test normalize_web_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_web_event_returns_konstruct_message() -> None:
|
||||
"""normalize_web_event returns a KonstructMessage."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
agent_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
event = {
|
||||
"text": "Hello from the portal",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"display_name": "Portal User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert isinstance(msg, KonstructMessage)
|
||||
|
||||
|
||||
def test_normalize_web_event_channel_is_web() -> None:
|
||||
"""normalize_web_event sets channel = ChannelType.WEB."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.channel == ChannelType.WEB
|
||||
|
||||
|
||||
def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
|
||||
"""normalize_web_event sets thread_id = conversation_id for memory scoping."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
conversation_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.thread_id == conversation_id
|
||||
|
||||
|
||||
def test_normalize_web_event_sender_user_id() -> None:
|
||||
"""normalize_web_event sets sender.user_id to the portal user UUID."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"display_name": "Portal User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.sender.user_id == user_id
|
||||
|
||||
|
||||
def test_normalize_web_event_channel_metadata() -> None:
|
||||
"""normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"display_name": "User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.channel_metadata["portal_user_id"] == user_id
|
||||
assert msg.channel_metadata["tenant_id"] == tenant_id
|
||||
assert msg.channel_metadata["conversation_id"] == conversation_id
|
||||
|
||||
|
||||
def test_normalize_web_event_tenant_id() -> None:
|
||||
"""normalize_web_event sets tenant_id on the message."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "hello",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.tenant_id == tenant_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test _send_response web case publishes to Redis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_publishes_to_redis() -> None:
|
||||
"""_send_response('web', ...) publishes JSON message to Redis webchat channel."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock()
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
extras = {
|
||||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
await _send_response("web", "Hello, this is the agent response!", extras)
|
||||
|
||||
expected_channel = webchat_response_key(tenant_id, conversation_id)
|
||||
mock_redis.publish.assert_called_once()
|
||||
call_args = mock_redis.publish.call_args
|
||||
assert call_args[0][0] == expected_channel
|
||||
published_payload = json.loads(call_args[0][1])
|
||||
assert published_payload["type"] == "response"
|
||||
assert published_payload["text"] == "Hello, this is the agent response!"
|
||||
assert published_payload["conversation_id"] == conversation_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_connection_cleanup() -> None:
|
||||
"""_send_response web case always closes Redis connection (try/finally)."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock()
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
|
||||
|
||||
mock_redis.aclose.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
|
||||
"""_send_response web case logs warning if conversation_id missing."""
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
with patch("orchestrator.tasks.logger") as mock_logger:
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
await _send_response("web", "test", {"tenant_id": "t1"})
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
mock_redis.publish.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test typing indicator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typing_indicator_sent_before_dispatch() -> None:
|
||||
"""WebSocket handler sends {'type': 'typing'} immediately after receiving user message."""
|
||||
# We test the typing indicator by calling the handler function directly
|
||||
# with a mocked WebSocket and mocked Celery dispatch.
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
# First receive_json returns auth message, second returns the user message
|
||||
mock_ws.receive_json = AsyncMock(
|
||||
side_effect=[
|
||||
{"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
|
||||
{"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
|
||||
]
|
||||
)
|
||||
# accept must be awaitable
|
||||
mock_ws.accept = AsyncMock()
|
||||
mock_ws.send_json = AsyncMock()
|
||||
|
||||
# Mock DB session that returns a conversation
|
||||
mock_session = AsyncMock()
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = uuid.uuid4()
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock Redis pub-sub: raise after publishing once so handler exits
|
||||
mock_redis = AsyncMock()
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.subscribe = AsyncMock()
|
||||
mock_pubsub.get_message = AsyncMock(return_value={
|
||||
"type": "message",
|
||||
"data": json.dumps({
|
||||
"type": "response",
|
||||
"text": "Agent reply",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}),
|
||||
})
|
||||
mock_pubsub.unsubscribe = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
mock_handle = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.channels.web.async_session_factory", return_value=mock_session),
|
||||
patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
|
||||
patch("gateway.channels.web.handle_message") as mock_handle_msg,
|
||||
patch("gateway.channels.web.configure_rls_hook"),
|
||||
patch("gateway.channels.web.current_tenant_id"),
|
||||
):
|
||||
mock_handle_msg.delay = MagicMock()
|
||||
|
||||
from gateway.channels.web import _handle_websocket_connection
|
||||
|
||||
# Run the handler — it should send typing indicator then process
|
||||
# Use asyncio.wait_for to prevent infinite loop if something hangs
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_handle_websocket_connection(mock_ws, str(uuid.uuid4())),
|
||||
timeout=2.0,
|
||||
)
|
||||
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
|
||||
pass
|
||||
|
||||
# Check typing was sent before Celery dispatch
|
||||
send_calls = mock_ws.send_json.call_args_list
|
||||
typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"]
|
||||
assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"
|
||||
Reference in New Issue
Block a user