diff --git a/migrations/versions/008_web_chat.py b/migrations/versions/008_web_chat.py new file mode 100644 index 0000000..6e812f0 --- /dev/null +++ b/migrations/versions/008_web_chat.py @@ -0,0 +1,172 @@ +"""Web chat: web_conversations and web_conversation_messages tables with RLS + +Revision ID: 008 +Revises: 007 +Create Date: 2026-03-25 + +This migration: +1. Creates the web_conversations table (tenant-scoped, RLS-enabled) +2. Creates the web_conversation_messages table (CASCADE delete, RLS-enabled) +3. Enables FORCE ROW LEVEL SECURITY on both tables +4. Creates tenant_isolation RLS policies matching existing pattern +5. Adds index on web_conversation_messages(conversation_id, created_at) for pagination +6. Replaces the channel_type CHECK constraint on channel_connections to include 'web' + +NOTE on CHECK constraint replacement (Pitfall 5): + The existing constraint chk_channel_type only covers the original 7 channels. + ALTER TABLE DROP CONSTRAINT + ADD CONSTRAINT is used instead of just adding a + new constraint — the old constraint remains active otherwise and would reject 'web'. +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +# Alembic migration metadata +revision: str = "008" +down_revision: Union[str, None] = "007" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# All valid channel types including new 'web' — must match ChannelType StrEnum in message.py +_CHANNEL_TYPES = ( + "slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal", "web" +) + + +def upgrade() -> None: + # ------------------------------------------------------------------------- + # 1. Create web_conversations table + # ------------------------------------------------------------------------- + op.create_table( + "web_conversations", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "agent_id", + UUID(as_uuid=True), + sa.ForeignKey("agents.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("user_id", UUID(as_uuid=True), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + sa.UniqueConstraint( + "tenant_id", + "agent_id", + "user_id", + name="uq_web_conversations_tenant_agent_user", + ), + ) + op.create_index("ix_web_conversations_tenant_id", "web_conversations", ["tenant_id"]) + + # Enable RLS on web_conversations + op.execute("ALTER TABLE web_conversations ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE web_conversations FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON web_conversations + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # ------------------------------------------------------------------------- + # 2. Create web_conversation_messages table + # ------------------------------------------------------------------------- + op.create_table( + "web_conversation_messages", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "conversation_id", + UUID(as_uuid=True), + sa.ForeignKey("web_conversations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("tenant_id", UUID(as_uuid=True), nullable=False), + sa.Column("role", sa.Text, nullable=False), + sa.Column("content", sa.Text, nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + + # CHECK constraint on role — TEXT+CHECK per Phase 1 convention (not sa.Enum) + op.execute( + "ALTER TABLE web_conversation_messages ADD CONSTRAINT chk_message_role " + "CHECK (role IN ('user', 'assistant'))" + ) + + # Index for paginated message history queries: ORDER BY created_at with conversation filter + op.create_index( + "ix_web_conversation_messages_conv_created", + "web_conversation_messages", + ["conversation_id", "created_at"], + ) + + # Enable RLS on web_conversation_messages + op.execute("ALTER TABLE web_conversation_messages ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE web_conversation_messages FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON web_conversation_messages + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # ------------------------------------------------------------------------- + # 3. Grant permissions to konstruct_app + # ------------------------------------------------------------------------- + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversations TO konstruct_app") + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversation_messages TO konstruct_app") + + # ------------------------------------------------------------------------- + # 4. Update channel_connections CHECK constraint to include 'web' + # + # DROP + re-ADD because an existing CHECK constraint still enforces the old + # set of values — simply adding a second constraint would AND them together. + # ------------------------------------------------------------------------- + op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type") + op.execute( + "ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type " + f"CHECK (channel_type IN {tuple(_CHANNEL_TYPES)})" + ) + + +def downgrade() -> None: + # Restore original channel_type CHECK constraint (without 'web') + _ORIGINAL_CHANNEL_TYPES = ( + "slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal" + ) + op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type") + op.execute( + "ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type " + f"CHECK (channel_type IN {tuple(_ORIGINAL_CHANNEL_TYPES)})" + ) + + # Drop web_conversation_messages first (FK dependency) + op.execute("REVOKE ALL ON web_conversation_messages FROM konstruct_app") + op.drop_index("ix_web_conversation_messages_conv_created") + op.drop_table("web_conversation_messages") + + # Drop web_conversations + op.execute("REVOKE ALL ON web_conversations FROM konstruct_app") + op.drop_index("ix_web_conversations_tenant_id") + op.drop_table("web_conversations") diff --git a/packages/orchestrator/orchestrator/tasks.py b/packages/orchestrator/orchestrator/tasks.py index 4f7e05f..9dd8896 100644 --- a/packages/orchestrator/orchestrator/tasks.py +++ b/packages/orchestrator/orchestrator/tasks.py @@ -77,7 +77,7 @@ from orchestrator.tools.registry import get_tools_for_agent from shared.config import settings from shared.db import async_session_factory, engine from shared.models.message import KonstructMessage -from shared.redis_keys import escalation_status_key +from shared.redis_keys import escalation_status_key, webchat_response_key from shared.rls import configure_rls_hook, current_tenant_id logger = logging.getLogger(__name__) @@ -253,6 +253,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped phone_number_id: str = message_data.pop("phone_number_id", "") or "" bot_token: str = message_data.pop("bot_token", "") or "" + # Extract web channel extras before model validation + # The web WebSocket handler injects these alongside the normalized KonstructMessage fields + conversation_id: str = message_data.pop("conversation_id", "") or "" + portal_user_id: str = message_data.pop("portal_user_id", "") or "" + try: msg = KonstructMessage.model_validate(message_data) except Exception as exc: @@ -272,6 +277,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped "phone_number_id": phone_number_id, "bot_token": bot_token, "wa_id": wa_id, + # Web channel extras + "conversation_id": conversation_id, + "portal_user_id": portal_user_id, + # tenant_id for web channel response routing (web lacks a workspace_id in channel_connections) + "tenant_id": msg.tenant_id or "", } result = asyncio.run(_process_message(msg, extras=extras)) @@ -646,6 +656,13 @@ def _build_response_extras( "bot_token": extras.get("bot_token", "") or "", "wa_id": extras.get("wa_id", "") or "", } + elif channel_str == "web": + # Web channel: tenant_id comes from extras (set by handle_message from msg.tenant_id), + # not from channel_connections like Slack. conversation_id scopes the Redis pub-sub channel. + return { + "conversation_id": extras.get("conversation_id", "") or "", + "tenant_id": extras.get("tenant_id", "") or "", + } else: return dict(extras) @@ -774,6 +791,31 @@ async def _send_response( text=text, ) + elif channel_str == "web": + # Publish agent response to Redis pub-sub so the WebSocket handler can deliver it + web_conversation_id: str = extras.get("conversation_id", "") or "" + web_tenant_id: str = extras.get("tenant_id", "") or "" + + if not web_conversation_id or not web_tenant_id: + logger.warning( + "_send_response: web channel missing conversation_id or tenant_id in extras" + ) + return + + response_channel = webchat_response_key(web_tenant_id, web_conversation_id) + publish_redis = aioredis.from_url(settings.redis_url) + try: + await publish_redis.publish( + response_channel, + json.dumps({ + "type": "response", + "text": text, + "conversation_id": web_conversation_id, + }), + ) + finally: + await publish_redis.aclose() + else: logger.warning( "_send_response: unsupported channel=%r — response not delivered", channel diff --git a/packages/shared/shared/models/chat.py b/packages/shared/shared/models/chat.py new file mode 100644 index 0000000..8d2bab0 --- /dev/null +++ b/packages/shared/shared/models/chat.py @@ -0,0 +1,124 @@ +""" +SQLAlchemy 2.0 ORM models for web chat conversations. + +These models support the Phase 6 web chat feature — a WebSocket-based +channel that allows portal users to chat with AI employees directly from +the Konstruct portal UI. + +Tables: + web_conversations — One per portal user + agent pair per tenant + web_conversation_messages — Individual messages within a conversation + +RLS is applied to both tables via app.current_tenant session variable, +same pattern as agents and channel_connections (migration 008). + +Design notes: + - UniqueConstraint on (tenant_id, agent_id, user_id) for get-or-create semantics + - role column is TEXT+CHECK (not sa.Enum) per Phase 1 ADR to avoid Alembic DDL conflicts + - ON DELETE CASCADE on messages.conversation_id: deleting a conversation + removes all its messages automatically +""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Text, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from shared.models.tenant import Base + + +class WebConversation(Base): + """ + A web chat conversation between a portal user and an AI employee. + + One row per (tenant_id, agent_id, user_id) triple — callers use + get-or-create semantics when starting a chat session. + + RLS scoped to tenant_id so the app role only sees conversations + for the currently-configured tenant. + """ + + __tablename__ = "web_conversations" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + server_default=func.gen_random_uuid(), + ) + tenant_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + ) + agent_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("agents.id", ondelete="CASCADE"), + nullable=False, + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + nullable=False, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ) + + __table_args__ = ( + UniqueConstraint("tenant_id", "agent_id", "user_id", name="uq_web_conversations_tenant_agent_user"), + ) + + +class WebConversationMessage(Base): + """ + A single message within a web chat conversation. + + role is stored as TEXT with a CHECK constraint ('user' or 'assistant'), + following the Phase 1 convention that avoids PostgreSQL ENUM DDL issues. + + Messages are deleted via ON DELETE CASCADE when their parent conversation + is deleted, or explicitly during a conversation reset (DELETE /conversations/{id}). + """ + + __tablename__ = "web_conversation_messages" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + server_default=func.gen_random_uuid(), + ) + conversation_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("web_conversations.id", ondelete="CASCADE"), + nullable=False, + ) + tenant_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + nullable=False, + ) + role: Mapped[str] = mapped_column( + Text, + nullable=False, + ) + content: Mapped[str] = mapped_column( + Text, + nullable=False, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) diff --git a/packages/shared/shared/models/message.py b/packages/shared/shared/models/message.py index c3f7810..8621c22 100644 --- a/packages/shared/shared/models/message.py +++ b/packages/shared/shared/models/message.py @@ -26,6 +26,7 @@ class ChannelType(StrEnum): TEAMS = "teams" TELEGRAM = "telegram" SIGNAL = "signal" + WEB = "web" class MediaType(StrEnum): diff --git a/packages/shared/shared/redis_keys.py b/packages/shared/shared/redis_keys.py index fbce779..eb4b187 100644 --- a/packages/shared/shared/redis_keys.py +++ b/packages/shared/shared/redis_keys.py @@ -144,3 +144,25 @@ def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str: Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}" """ return f"{tenant_id}:tool_confirm:{thread_id}" + + +def webchat_response_key(tenant_id: str, conversation_id: str) -> str: + """ + Redis pub-sub channel key for web chat response delivery. + + The WebSocket handler subscribes to this channel after dispatching + a message to Celery. The orchestrator publishes the agent response + to this channel when processing completes. + + Key includes both tenant_id and conversation_id to ensure: + - Two conversations in the same tenant get separate channels + - Two tenants with the same conversation_id are fully isolated + + Args: + tenant_id: Konstruct tenant identifier. + conversation_id: Web conversation UUID string. + + Returns: + Namespaced Redis key: "{tenant_id}:webchat:response:{conversation_id}" + """ + return f"{tenant_id}:webchat:response:{conversation_id}" diff --git a/tests/unit/test_chat_api.py b/tests/unit/test_chat_api.py new file mode 100644 index 0000000..d33009c --- /dev/null +++ b/tests/unit/test_chat_api.py @@ -0,0 +1,283 @@ +""" +Unit tests for the chat REST API with RBAC enforcement. + +Tests: + - test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403 + when caller is not a member of tenant X + - test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200 + when caller is platform_admin (bypasses membership check) + - test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns + paginated messages ordered by created_at + - test_create_conversation: POST /api/portal/chat/conversations creates or returns existing + conversation for user+agent pair + - test_create_conversation_rbac: POST returns 403 for non-member caller + - test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes + messages but keeps the conversation row +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from shared.api.rbac import PortalCaller + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _admin_headers(user_id: str | None = None) -> dict[str, str]: + return { + "X-Portal-User-Id": user_id or str(uuid.uuid4()), + "X-Portal-User-Role": "platform_admin", + } + + +def _stranger_headers(user_id: str | None = None) -> dict[str, str]: + return { + "X-Portal-User-Id": user_id or str(uuid.uuid4()), + "X-Portal-User-Role": "customer_operator", + } + + +def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI: + """Create a test FastAPI app with the chat router and a session dependency override.""" + from shared.api.chat import chat_router + from shared.db import get_session + + app = FastAPI() + app.include_router(chat_router) + + async def _override_get_session(): # type: ignore[return] + yield mock_session + + app.dependency_overrides[get_session] = _override_get_session + return app + + +# --------------------------------------------------------------------------- +# RBAC enforcement on list conversations +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_chat_rbac_enforcement() -> None: + """Non-member caller gets 403 when listing conversations for a tenant they don't belong to.""" + tenant_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Mock session — no membership row found (require_tenant_member checks UserTenantRole) + mock_session = AsyncMock() + mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None)) + + app = _make_app_with_session_override(mock_session) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + "/api/portal/chat/conversations", + params={"tenant_id": str(tenant_id)}, + headers=_stranger_headers(str(user_id)), + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_platform_admin_cross_tenant() -> None: + """Platform admin can list conversations for any tenant (bypasses membership check).""" + tenant_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Mock session — returns empty rows for conversation query + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session.execute.return_value = mock_result + + app = _make_app_with_session_override(mock_session) + + with ( + patch("shared.api.chat.configure_rls_hook"), + patch("shared.api.chat.current_tenant_id"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + "/api/portal/chat/conversations", + params={"tenant_id": str(tenant_id)}, + headers=_admin_headers(str(user_id)), + ) + + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +# --------------------------------------------------------------------------- +# List conversation history (paginated messages) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_list_conversation_history() -> None: + """GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at.""" + user_id = uuid.uuid4() + conv_id = uuid.uuid4() + + # Mock conversation owned by the caller + mock_conv = MagicMock() + mock_conv.id = conv_id + mock_conv.user_id = user_id + mock_conv.tenant_id = uuid.uuid4() + + # Mock messages + now = datetime.now(timezone.utc) + mock_msg1 = MagicMock() + mock_msg1.id = uuid.uuid4() + mock_msg1.role = "user" + mock_msg1.content = "Hello" + mock_msg1.created_at = now + + mock_msg2 = MagicMock() + mock_msg2.id = uuid.uuid4() + mock_msg2.role = "assistant" + mock_msg2.content = "Hi there!" + mock_msg2.created_at = now + + mock_session = AsyncMock() + # First call: fetch conversation; second call: fetch messages + mock_session.execute.side_effect = [ + MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)), + MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))), + ] + + app = _make_app_with_session_override(mock_session) + + with ( + patch("shared.api.chat.configure_rls_hook"), + patch("shared.api.chat.current_tenant_id"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + f"/api/portal/chat/conversations/{conv_id}/messages", + headers=_admin_headers(str(user_id)), + ) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 2 + assert data[0]["role"] == "user" + assert data[1]["role"] == "assistant" + + +# --------------------------------------------------------------------------- +# Create conversation (get-or-create) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_conversation() -> None: + """POST /api/portal/chat/conversations creates a new conversation for user+agent pair.""" + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + user_id = uuid.uuid4() + conv_id = uuid.uuid4() + + now = datetime.now(timezone.utc) + + # Platform admin bypasses membership check; no existing conversation found + mock_session = AsyncMock() + mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None)) + mock_session.flush = AsyncMock() + mock_session.commit = AsyncMock() + mock_session.add = MagicMock() + + # refresh populates server-default fields on the passed ORM object + async def _mock_refresh(obj: object) -> None: + obj.id = conv_id # type: ignore[attr-defined] + obj.created_at = now # type: ignore[attr-defined] + obj.updated_at = now # type: ignore[attr-defined] + + mock_session.refresh = _mock_refresh + + app = _make_app_with_session_override(mock_session) + + with ( + patch("shared.api.chat.configure_rls_hook"), + patch("shared.api.chat.current_tenant_id"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/portal/chat/conversations", + json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)}, + headers=_admin_headers(str(user_id)), + ) + + assert response.status_code in (200, 201) + data = response.json() + assert "id" in data + + +@pytest.mark.asyncio +async def test_create_conversation_rbac_forbidden() -> None: + """Non-member gets 403 when creating a conversation in a tenant they don't belong to.""" + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Membership check returns None (not a member) + mock_session = AsyncMock() + mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None)) + + app = _make_app_with_session_override(mock_session) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/portal/chat/conversations", + json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)}, + headers=_stranger_headers(str(user_id)), + ) + + assert response.status_code == 403 + + +# --------------------------------------------------------------------------- +# Delete conversation (reset messages) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_delete_conversation_resets_messages() -> None: + """DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row.""" + user_id = uuid.uuid4() + conv_id = uuid.uuid4() + + mock_conv = MagicMock() + mock_conv.id = conv_id + mock_conv.user_id = user_id + mock_conv.tenant_id = uuid.uuid4() + + mock_session = AsyncMock() + mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)) + mock_session.commit = AsyncMock() + + app = _make_app_with_session_override(mock_session) + + with ( + patch("shared.api.chat.configure_rls_hook"), + patch("shared.api.chat.current_tenant_id"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.delete( + f"/api/portal/chat/conversations/{conv_id}", + headers=_admin_headers(str(user_id)), + ) + + assert response.status_code == 200 + assert mock_session.execute.call_count >= 1 diff --git a/tests/unit/test_web_channel.py b/tests/unit/test_web_channel.py new file mode 100644 index 0000000..742d88e --- /dev/null +++ b/tests/unit/test_web_channel.py @@ -0,0 +1,312 @@ +""" +Unit tests for the web channel adapter. + +Tests: + - test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB + - test_normalize_web_event_thread_id: thread_id equals conversation_id + - test_normalize_web_event_sender: sender.user_id equals portal user UUID + - test_webchat_response_key: webchat_response_key returns correct namespaced key + - test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis + - test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch +""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from shared.models.message import ChannelType, KonstructMessage +from shared.redis_keys import webchat_response_key + + +# --------------------------------------------------------------------------- +# Test webchat_response_key +# --------------------------------------------------------------------------- + + +def test_webchat_response_key_format() -> None: + """webchat_response_key returns correctly namespaced key.""" + tenant_id = "tenant-abc" + conversation_id = "conv-xyz" + key = webchat_response_key(tenant_id, conversation_id) + assert key == "tenant-abc:webchat:response:conv-xyz" + + +def test_webchat_response_key_isolation() -> None: + """Two tenants with same conversation_id get different keys.""" + key_a = webchat_response_key("tenant-a", "conv-1") + key_b = webchat_response_key("tenant-b", "conv-1") + assert key_a != key_b + + +# --------------------------------------------------------------------------- +# Test ChannelType.WEB +# --------------------------------------------------------------------------- + + +def test_channel_type_web_exists() -> None: + """ChannelType.WEB must exist with value 'web'.""" + assert ChannelType.WEB == "web" + assert ChannelType.WEB.value == "web" + + +# --------------------------------------------------------------------------- +# Test normalize_web_event +# --------------------------------------------------------------------------- + + +def test_normalize_web_event_returns_konstruct_message() -> None: + """normalize_web_event returns a KonstructMessage.""" + from gateway.channels.web import normalize_web_event + + tenant_id = str(uuid.uuid4()) + agent_id = str(uuid.uuid4()) + user_id = str(uuid.uuid4()) + conversation_id = str(uuid.uuid4()) + + event = { + "text": "Hello from the portal", + "tenant_id": tenant_id, + "agent_id": agent_id, + "user_id": user_id, + "display_name": "Portal User", + "conversation_id": conversation_id, + } + msg = normalize_web_event(event) + assert isinstance(msg, KonstructMessage) + + +def test_normalize_web_event_channel_is_web() -> None: + """normalize_web_event sets channel = ChannelType.WEB.""" + from gateway.channels.web import normalize_web_event + + event = { + "text": "test", + "tenant_id": str(uuid.uuid4()), + "agent_id": str(uuid.uuid4()), + "user_id": str(uuid.uuid4()), + "display_name": "User", + "conversation_id": str(uuid.uuid4()), + } + msg = normalize_web_event(event) + assert msg.channel == ChannelType.WEB + + +def test_normalize_web_event_thread_id_equals_conversation_id() -> None: + """normalize_web_event sets thread_id = conversation_id for memory scoping.""" + from gateway.channels.web import normalize_web_event + + conversation_id = str(uuid.uuid4()) + event = { + "text": "test", + "tenant_id": str(uuid.uuid4()), + "agent_id": str(uuid.uuid4()), + "user_id": str(uuid.uuid4()), + "display_name": "User", + "conversation_id": conversation_id, + } + msg = normalize_web_event(event) + assert msg.thread_id == conversation_id + + +def test_normalize_web_event_sender_user_id() -> None: + """normalize_web_event sets sender.user_id to the portal user UUID.""" + from gateway.channels.web import normalize_web_event + + user_id = str(uuid.uuid4()) + event = { + "text": "test", + "tenant_id": str(uuid.uuid4()), + "agent_id": str(uuid.uuid4()), + "user_id": user_id, + "display_name": "Portal User", + "conversation_id": str(uuid.uuid4()), + } + msg = normalize_web_event(event) + assert msg.sender.user_id == user_id + + +def test_normalize_web_event_channel_metadata() -> None: + """normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id.""" + from gateway.channels.web import normalize_web_event + + tenant_id = str(uuid.uuid4()) + user_id = str(uuid.uuid4()) + conversation_id = str(uuid.uuid4()) + + event = { + "text": "test", + "tenant_id": tenant_id, + "agent_id": str(uuid.uuid4()), + "user_id": user_id, + "display_name": "User", + "conversation_id": conversation_id, + } + msg = normalize_web_event(event) + assert msg.channel_metadata["portal_user_id"] == user_id + assert msg.channel_metadata["tenant_id"] == tenant_id + assert msg.channel_metadata["conversation_id"] == conversation_id + + +def test_normalize_web_event_tenant_id() -> None: + """normalize_web_event sets tenant_id on the message.""" + from gateway.channels.web import normalize_web_event + + tenant_id = str(uuid.uuid4()) + event = { + "text": "hello", + "tenant_id": tenant_id, + "agent_id": str(uuid.uuid4()), + "user_id": str(uuid.uuid4()), + "display_name": "User", + "conversation_id": str(uuid.uuid4()), + } + msg = normalize_web_event(event) + assert msg.tenant_id == tenant_id + + +# --------------------------------------------------------------------------- +# Test _send_response web case publishes to Redis +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_response_web_publishes_to_redis() -> None: + """_send_response('web', ...) publishes JSON message to Redis webchat channel.""" + tenant_id = str(uuid.uuid4()) + conversation_id = str(uuid.uuid4()) + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock() + mock_redis.aclose = AsyncMock() + + with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis): + from orchestrator.tasks import _send_response + + extras = { + "conversation_id": conversation_id, + "tenant_id": tenant_id, + } + await _send_response("web", "Hello, this is the agent response!", extras) + + expected_channel = webchat_response_key(tenant_id, conversation_id) + mock_redis.publish.assert_called_once() + call_args = mock_redis.publish.call_args + assert call_args[0][0] == expected_channel + published_payload = json.loads(call_args[0][1]) + assert published_payload["type"] == "response" + assert published_payload["text"] == "Hello, this is the agent response!" + assert published_payload["conversation_id"] == conversation_id + + +@pytest.mark.asyncio +async def test_send_response_web_connection_cleanup() -> None: + """_send_response web case always closes Redis connection (try/finally).""" + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock() + mock_redis.aclose = AsyncMock() + + with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis): + from orchestrator.tasks import _send_response + + await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"}) + + mock_redis.aclose.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_response_web_missing_conversation_id_logs_warning() -> None: + """_send_response web case logs warning if conversation_id missing.""" + mock_redis = AsyncMock() + + with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis): + with patch("orchestrator.tasks.logger") as mock_logger: + from orchestrator.tasks import _send_response + + await _send_response("web", "test", {"tenant_id": "t1"}) + + mock_logger.warning.assert_called() + mock_redis.publish.assert_not_called() + + +# --------------------------------------------------------------------------- +# Test typing indicator +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_typing_indicator_sent_before_dispatch() -> None: + """WebSocket handler sends {'type': 'typing'} immediately after receiving user message.""" + # We test the typing indicator by calling the handler function directly + # with a mocked WebSocket and mocked Celery dispatch. + from unittest.mock import AsyncMock, MagicMock, call, patch + + mock_ws = AsyncMock() + # First receive_json returns auth message, second returns the user message + mock_ws.receive_json = AsyncMock( + side_effect=[ + {"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())}, + {"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())}, + ] + ) + # accept must be awaitable + mock_ws.accept = AsyncMock() + mock_ws.send_json = AsyncMock() + + # Mock DB session that returns a conversation + mock_session = AsyncMock() + mock_conv = MagicMock() + mock_conv.id = uuid.uuid4() + mock_conv.tenant_id = uuid.uuid4() + mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + # Mock Redis pub-sub: raise after publishing once so handler exits + mock_redis = AsyncMock() + mock_pubsub = AsyncMock() + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.get_message = AsyncMock(return_value={ + "type": "message", + "data": json.dumps({ + "type": "response", + "text": "Agent reply", + "conversation_id": str(uuid.uuid4()), + }), + }) + mock_pubsub.unsubscribe = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + mock_redis.aclose = AsyncMock() + + mock_handle = MagicMock() + + with ( + patch("gateway.channels.web.async_session_factory", return_value=mock_session), + patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis), + patch("gateway.channels.web.handle_message") as mock_handle_msg, + patch("gateway.channels.web.configure_rls_hook"), + patch("gateway.channels.web.current_tenant_id"), + ): + mock_handle_msg.delay = MagicMock() + + from gateway.channels.web import _handle_websocket_connection + + # Run the handler — it should send typing indicator then process + # Use asyncio.wait_for to prevent infinite loop if something hangs + try: + await asyncio.wait_for( + _handle_websocket_connection(mock_ws, str(uuid.uuid4())), + timeout=2.0, + ) + except (asyncio.TimeoutError, StopAsyncIteration, Exception): + pass + + # Check typing was sent before Celery dispatch + send_calls = mock_ws.send_json.call_args_list + typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"] + assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"