feat(06-01): add web channel type, Redis key, ORM models, migration, and tests

- Add ChannelType.WEB = 'web' to shared/models/message.py
- Add webchat_response_key() to shared/redis_keys.py
- Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0)
- Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update
- Pop conversation_id/portal_user_id extras in handle_message before model_validate
- Add web case to _build_response_extras and _send_response (Redis pub-sub publish)
- Import webchat_response_key in orchestrator/tasks.py
- Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
This commit is contained in:
2026-03-25 10:26:34 -06:00
parent c0fa0cefee
commit c72beb916b
7 changed files with 957 additions and 1 deletions

View File

@@ -0,0 +1,172 @@
"""Web chat: web_conversations and web_conversation_messages tables with RLS
Revision ID: 008
Revises: 007
Create Date: 2026-03-25
This migration:
1. Creates the web_conversations table (tenant-scoped, RLS-enabled)
2. Creates the web_conversation_messages table (CASCADE delete, RLS-enabled)
3. Enables FORCE ROW LEVEL SECURITY on both tables
4. Creates tenant_isolation RLS policies matching existing pattern
5. Adds index on web_conversation_messages(conversation_id, created_at) for pagination
6. Replaces the channel_type CHECK constraint on channel_connections to include 'web'
NOTE on CHECK constraint replacement (Pitfall 5):
The existing constraint chk_channel_type only covers the original 7 channels.
ALTER TABLE DROP CONSTRAINT + ADD CONSTRAINT is used instead of just adding a
new constraint — the old constraint remains active otherwise and would reject 'web'.
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# Alembic migration metadata
revision: str = "008"
down_revision: Union[str, None] = "007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# All valid channel types including new 'web' — must match ChannelType StrEnum in message.py
_CHANNEL_TYPES = (
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal", "web"
)
def upgrade() -> None:
# -------------------------------------------------------------------------
# 1. Create web_conversations table
# -------------------------------------------------------------------------
op.create_table(
"web_conversations",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"tenant_id",
UUID(as_uuid=True),
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"agent_id",
UUID(as_uuid=True),
sa.ForeignKey("agents.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("NOW()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("NOW()"),
),
sa.UniqueConstraint(
"tenant_id",
"agent_id",
"user_id",
name="uq_web_conversations_tenant_agent_user",
),
)
op.create_index("ix_web_conversations_tenant_id", "web_conversations", ["tenant_id"])
# Enable RLS on web_conversations
op.execute("ALTER TABLE web_conversations ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE web_conversations FORCE ROW LEVEL SECURITY")
op.execute("""
CREATE POLICY tenant_isolation ON web_conversations
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
""")
# -------------------------------------------------------------------------
# 2. Create web_conversation_messages table
# -------------------------------------------------------------------------
op.create_table(
"web_conversation_messages",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column(
"conversation_id",
UUID(as_uuid=True),
sa.ForeignKey("web_conversations.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("tenant_id", UUID(as_uuid=True), nullable=False),
sa.Column("role", sa.Text, nullable=False),
sa.Column("content", sa.Text, nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("NOW()"),
),
)
# CHECK constraint on role — TEXT+CHECK per Phase 1 convention (not sa.Enum)
op.execute(
"ALTER TABLE web_conversation_messages ADD CONSTRAINT chk_message_role "
"CHECK (role IN ('user', 'assistant'))"
)
# Index for paginated message history queries: ORDER BY created_at with conversation filter
op.create_index(
"ix_web_conversation_messages_conv_created",
"web_conversation_messages",
["conversation_id", "created_at"],
)
# Enable RLS on web_conversation_messages
op.execute("ALTER TABLE web_conversation_messages ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE web_conversation_messages FORCE ROW LEVEL SECURITY")
op.execute("""
CREATE POLICY tenant_isolation ON web_conversation_messages
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
""")
# -------------------------------------------------------------------------
# 3. Grant permissions to konstruct_app
# -------------------------------------------------------------------------
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversations TO konstruct_app")
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversation_messages TO konstruct_app")
# -------------------------------------------------------------------------
# 4. Update channel_connections CHECK constraint to include 'web'
#
# DROP + re-ADD because an existing CHECK constraint still enforces the old
# set of values — simply adding a second constraint would AND them together.
# -------------------------------------------------------------------------
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
op.execute(
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
f"CHECK (channel_type IN {tuple(_CHANNEL_TYPES)})"
)
def downgrade() -> None:
# Restore original channel_type CHECK constraint (without 'web')
_ORIGINAL_CHANNEL_TYPES = (
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal"
)
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
op.execute(
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
f"CHECK (channel_type IN {tuple(_ORIGINAL_CHANNEL_TYPES)})"
)
# Drop web_conversation_messages first (FK dependency)
op.execute("REVOKE ALL ON web_conversation_messages FROM konstruct_app")
op.drop_index("ix_web_conversation_messages_conv_created")
op.drop_table("web_conversation_messages")
# Drop web_conversations
op.execute("REVOKE ALL ON web_conversations FROM konstruct_app")
op.drop_index("ix_web_conversations_tenant_id")
op.drop_table("web_conversations")

View File

@@ -77,7 +77,7 @@ from orchestrator.tools.registry import get_tools_for_agent
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.models.message import KonstructMessage
from shared.redis_keys import escalation_status_key
from shared.redis_keys import escalation_status_key, webchat_response_key
from shared.rls import configure_rls_hook, current_tenant_id
logger = logging.getLogger(__name__)
@@ -253,6 +253,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
bot_token: str = message_data.pop("bot_token", "") or ""
# Extract web channel extras before model validation
# The web WebSocket handler injects these alongside the normalized KonstructMessage fields
conversation_id: str = message_data.pop("conversation_id", "") or ""
portal_user_id: str = message_data.pop("portal_user_id", "") or ""
try:
msg = KonstructMessage.model_validate(message_data)
except Exception as exc:
@@ -272,6 +277,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
"phone_number_id": phone_number_id,
"bot_token": bot_token,
"wa_id": wa_id,
# Web channel extras
"conversation_id": conversation_id,
"portal_user_id": portal_user_id,
# tenant_id for web channel response routing (web lacks a workspace_id in channel_connections)
"tenant_id": msg.tenant_id or "",
}
result = asyncio.run(_process_message(msg, extras=extras))
@@ -646,6 +656,13 @@ def _build_response_extras(
"bot_token": extras.get("bot_token", "") or "",
"wa_id": extras.get("wa_id", "") or "",
}
elif channel_str == "web":
# Web channel: tenant_id comes from extras (set by handle_message from msg.tenant_id),
# not from channel_connections like Slack. conversation_id scopes the Redis pub-sub channel.
return {
"conversation_id": extras.get("conversation_id", "") or "",
"tenant_id": extras.get("tenant_id", "") or "",
}
else:
return dict(extras)
@@ -774,6 +791,31 @@ async def _send_response(
text=text,
)
elif channel_str == "web":
# Publish agent response to Redis pub-sub so the WebSocket handler can deliver it
web_conversation_id: str = extras.get("conversation_id", "") or ""
web_tenant_id: str = extras.get("tenant_id", "") or ""
if not web_conversation_id or not web_tenant_id:
logger.warning(
"_send_response: web channel missing conversation_id or tenant_id in extras"
)
return
response_channel = webchat_response_key(web_tenant_id, web_conversation_id)
publish_redis = aioredis.from_url(settings.redis_url)
try:
await publish_redis.publish(
response_channel,
json.dumps({
"type": "response",
"text": text,
"conversation_id": web_conversation_id,
}),
)
finally:
await publish_redis.aclose()
else:
logger.warning(
"_send_response: unsupported channel=%r — response not delivered", channel

View File

@@ -0,0 +1,124 @@
"""
SQLAlchemy 2.0 ORM models for web chat conversations.
These models support the Phase 6 web chat feature — a WebSocket-based
channel that allows portal users to chat with AI employees directly from
the Konstruct portal UI.
Tables:
web_conversations — One per portal user + agent pair per tenant
web_conversation_messages — Individual messages within a conversation
RLS is applied to both tables via app.current_tenant session variable,
same pattern as agents and channel_connections (migration 008).
Design notes:
- UniqueConstraint on (tenant_id, agent_id, user_id) for get-or-create semantics
- role column is TEXT+CHECK (not sa.Enum) per Phase 1 ADR to avoid Alembic DDL conflicts
- ON DELETE CASCADE on messages.conversation_id: deleting a conversation
removes all its messages automatically
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from shared.models.tenant import Base
class WebConversation(Base):
"""
A web chat conversation between a portal user and an AI employee.
One row per (tenant_id, agent_id, user_id) triple — callers use
get-or-create semantics when starting a chat session.
RLS scoped to tenant_id so the app role only sees conversations
for the currently-configured tenant.
"""
__tablename__ = "web_conversations"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
server_default=func.gen_random_uuid(),
)
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
)
agent_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("agents.id", ondelete="CASCADE"),
nullable=False,
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
nullable=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)
__table_args__ = (
UniqueConstraint("tenant_id", "agent_id", "user_id", name="uq_web_conversations_tenant_agent_user"),
)
class WebConversationMessage(Base):
"""
A single message within a web chat conversation.
role is stored as TEXT with a CHECK constraint ('user' or 'assistant'),
following the Phase 1 convention that avoids PostgreSQL ENUM DDL issues.
Messages are deleted via ON DELETE CASCADE when their parent conversation
is deleted, or explicitly during a conversation reset (DELETE /conversations/{id}).
"""
__tablename__ = "web_conversation_messages"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
server_default=func.gen_random_uuid(),
)
conversation_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("web_conversations.id", ondelete="CASCADE"),
nullable=False,
)
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
nullable=False,
)
role: Mapped[str] = mapped_column(
Text,
nullable=False,
)
content: Mapped[str] = mapped_column(
Text,
nullable=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)

View File

@@ -26,6 +26,7 @@ class ChannelType(StrEnum):
TEAMS = "teams"
TELEGRAM = "telegram"
SIGNAL = "signal"
WEB = "web"
class MediaType(StrEnum):

View File

@@ -144,3 +144,25 @@ def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str:
Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}"
"""
return f"{tenant_id}:tool_confirm:{thread_id}"
def webchat_response_key(tenant_id: str, conversation_id: str) -> str:
"""
Redis pub-sub channel key for web chat response delivery.
The WebSocket handler subscribes to this channel after dispatching
a message to Celery. The orchestrator publishes the agent response
to this channel when processing completes.
Key includes both tenant_id and conversation_id to ensure:
- Two conversations in the same tenant get separate channels
- Two tenants with the same conversation_id are fully isolated
Args:
tenant_id: Konstruct tenant identifier.
conversation_id: Web conversation UUID string.
Returns:
Namespaced Redis key: "{tenant_id}:webchat:response:{conversation_id}"
"""
return f"{tenant_id}:webchat:response:{conversation_id}"

283
tests/unit/test_chat_api.py Normal file
View File

@@ -0,0 +1,283 @@
"""
Unit tests for the chat REST API with RBAC enforcement.
Tests:
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
when caller is not a member of tenant X
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
when caller is platform_admin (bypasses membership check)
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
paginated messages ordered by created_at
- test_create_conversation: POST /api/portal/chat/conversations creates or returns existing
conversation for user+agent pair
- test_create_conversation_rbac: POST returns 403 for non-member caller
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
messages but keeps the conversation row
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from shared.api.rbac import PortalCaller
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _admin_headers(user_id: str | None = None) -> dict[str, str]:
return {
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
"X-Portal-User-Role": "platform_admin",
}
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
return {
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
"X-Portal-User-Role": "customer_operator",
}
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
"""Create a test FastAPI app with the chat router and a session dependency override."""
from shared.api.chat import chat_router
from shared.db import get_session
app = FastAPI()
app.include_router(chat_router)
async def _override_get_session(): # type: ignore[return]
yield mock_session
app.dependency_overrides[get_session] = _override_get_session
return app
# ---------------------------------------------------------------------------
# RBAC enforcement on list conversations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_rbac_enforcement() -> None:
"""Non-member caller gets 403 when listing conversations for a tenant they don't belong to."""
tenant_id = uuid.uuid4()
user_id = uuid.uuid4()
# Mock session — no membership row found (require_tenant_member checks UserTenantRole)
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
app = _make_app_with_session_override(mock_session)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
"/api/portal/chat/conversations",
params={"tenant_id": str(tenant_id)},
headers=_stranger_headers(str(user_id)),
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_platform_admin_cross_tenant() -> None:
"""Platform admin can list conversations for any tenant (bypasses membership check)."""
tenant_id = uuid.uuid4()
user_id = uuid.uuid4()
# Mock session — returns empty rows for conversation query
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.all.return_value = []
mock_session.execute.return_value = mock_result
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
"/api/portal/chat/conversations",
params={"tenant_id": str(tenant_id)},
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
assert isinstance(response.json(), list)
# ---------------------------------------------------------------------------
# List conversation history (paginated messages)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_conversation_history() -> None:
"""GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at."""
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
# Mock conversation owned by the caller
mock_conv = MagicMock()
mock_conv.id = conv_id
mock_conv.user_id = user_id
mock_conv.tenant_id = uuid.uuid4()
# Mock messages
now = datetime.now(timezone.utc)
mock_msg1 = MagicMock()
mock_msg1.id = uuid.uuid4()
mock_msg1.role = "user"
mock_msg1.content = "Hello"
mock_msg1.created_at = now
mock_msg2 = MagicMock()
mock_msg2.id = uuid.uuid4()
mock_msg2.role = "assistant"
mock_msg2.content = "Hi there!"
mock_msg2.created_at = now
mock_session = AsyncMock()
# First call: fetch conversation; second call: fetch messages
mock_session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
]
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get(
f"/api/portal/chat/conversations/{conv_id}/messages",
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
assert data[0]["role"] == "user"
assert data[1]["role"] == "assistant"
# ---------------------------------------------------------------------------
# Create conversation (get-or-create)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_create_conversation() -> None:
"""POST /api/portal/chat/conversations creates a new conversation for user+agent pair."""
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
now = datetime.now(timezone.utc)
# Platform admin bypasses membership check; no existing conversation found
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
mock_session.flush = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.add = MagicMock()
# refresh populates server-default fields on the passed ORM object
async def _mock_refresh(obj: object) -> None:
obj.id = conv_id # type: ignore[attr-defined]
obj.created_at = now # type: ignore[attr-defined]
obj.updated_at = now # type: ignore[attr-defined]
mock_session.refresh = _mock_refresh
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/api/portal/chat/conversations",
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
headers=_admin_headers(str(user_id)),
)
assert response.status_code in (200, 201)
data = response.json()
assert "id" in data
@pytest.mark.asyncio
async def test_create_conversation_rbac_forbidden() -> None:
"""Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
user_id = uuid.uuid4()
# Membership check returns None (not a member)
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
app = _make_app_with_session_override(mock_session)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/api/portal/chat/conversations",
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
headers=_stranger_headers(str(user_id)),
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Delete conversation (reset messages)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_delete_conversation_resets_messages() -> None:
"""DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row."""
user_id = uuid.uuid4()
conv_id = uuid.uuid4()
mock_conv = MagicMock()
mock_conv.id = conv_id
mock_conv.user_id = user_id
mock_conv.tenant_id = uuid.uuid4()
mock_session = AsyncMock()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
mock_session.commit = AsyncMock()
app = _make_app_with_session_override(mock_session)
with (
patch("shared.api.chat.configure_rls_hook"),
patch("shared.api.chat.current_tenant_id"),
):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.delete(
f"/api/portal/chat/conversations/{conv_id}",
headers=_admin_headers(str(user_id)),
)
assert response.status_code == 200
assert mock_session.execute.call_count >= 1

View File

@@ -0,0 +1,312 @@
"""
Unit tests for the web channel adapter.
Tests:
- test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB
- test_normalize_web_event_thread_id: thread_id equals conversation_id
- test_normalize_web_event_sender: sender.user_id equals portal user UUID
- test_webchat_response_key: webchat_response_key returns correct namespaced key
- test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis
- test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.models.message import ChannelType, KonstructMessage
from shared.redis_keys import webchat_response_key
# ---------------------------------------------------------------------------
# Test webchat_response_key
# ---------------------------------------------------------------------------
def test_webchat_response_key_format() -> None:
"""webchat_response_key returns correctly namespaced key."""
tenant_id = "tenant-abc"
conversation_id = "conv-xyz"
key = webchat_response_key(tenant_id, conversation_id)
assert key == "tenant-abc:webchat:response:conv-xyz"
def test_webchat_response_key_isolation() -> None:
"""Two tenants with same conversation_id get different keys."""
key_a = webchat_response_key("tenant-a", "conv-1")
key_b = webchat_response_key("tenant-b", "conv-1")
assert key_a != key_b
# ---------------------------------------------------------------------------
# Test ChannelType.WEB
# ---------------------------------------------------------------------------
def test_channel_type_web_exists() -> None:
"""ChannelType.WEB must exist with value 'web'."""
assert ChannelType.WEB == "web"
assert ChannelType.WEB.value == "web"
# ---------------------------------------------------------------------------
# Test normalize_web_event
# ---------------------------------------------------------------------------
def test_normalize_web_event_returns_konstruct_message() -> None:
"""normalize_web_event returns a KonstructMessage."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
agent_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
event = {
"text": "Hello from the portal",
"tenant_id": tenant_id,
"agent_id": agent_id,
"user_id": user_id,
"display_name": "Portal User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert isinstance(msg, KonstructMessage)
def test_normalize_web_event_channel_is_web() -> None:
"""normalize_web_event sets channel = ChannelType.WEB."""
from gateway.channels.web import normalize_web_event
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.channel == ChannelType.WEB
def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
"""normalize_web_event sets thread_id = conversation_id for memory scoping."""
from gateway.channels.web import normalize_web_event
conversation_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert msg.thread_id == conversation_id
def test_normalize_web_event_sender_user_id() -> None:
"""normalize_web_event sets sender.user_id to the portal user UUID."""
from gateway.channels.web import normalize_web_event
user_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": str(uuid.uuid4()),
"agent_id": str(uuid.uuid4()),
"user_id": user_id,
"display_name": "Portal User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.sender.user_id == user_id
def test_normalize_web_event_channel_metadata() -> None:
"""normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
event = {
"text": "test",
"tenant_id": tenant_id,
"agent_id": str(uuid.uuid4()),
"user_id": user_id,
"display_name": "User",
"conversation_id": conversation_id,
}
msg = normalize_web_event(event)
assert msg.channel_metadata["portal_user_id"] == user_id
assert msg.channel_metadata["tenant_id"] == tenant_id
assert msg.channel_metadata["conversation_id"] == conversation_id
def test_normalize_web_event_tenant_id() -> None:
"""normalize_web_event sets tenant_id on the message."""
from gateway.channels.web import normalize_web_event
tenant_id = str(uuid.uuid4())
event = {
"text": "hello",
"tenant_id": tenant_id,
"agent_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"display_name": "User",
"conversation_id": str(uuid.uuid4()),
}
msg = normalize_web_event(event)
assert msg.tenant_id == tenant_id
# ---------------------------------------------------------------------------
# Test _send_response web case publishes to Redis
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_response_web_publishes_to_redis() -> None:
"""_send_response('web', ...) publishes JSON message to Redis webchat channel."""
tenant_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
mock_redis.aclose = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
from orchestrator.tasks import _send_response
extras = {
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
await _send_response("web", "Hello, this is the agent response!", extras)
expected_channel = webchat_response_key(tenant_id, conversation_id)
mock_redis.publish.assert_called_once()
call_args = mock_redis.publish.call_args
assert call_args[0][0] == expected_channel
published_payload = json.loads(call_args[0][1])
assert published_payload["type"] == "response"
assert published_payload["text"] == "Hello, this is the agent response!"
assert published_payload["conversation_id"] == conversation_id
@pytest.mark.asyncio
async def test_send_response_web_connection_cleanup() -> None:
"""_send_response web case always closes Redis connection (try/finally)."""
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
mock_redis.aclose = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
from orchestrator.tasks import _send_response
await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
mock_redis.aclose.assert_called_once()
@pytest.mark.asyncio
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
"""_send_response web case logs warning if conversation_id missing."""
mock_redis = AsyncMock()
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
with patch("orchestrator.tasks.logger") as mock_logger:
from orchestrator.tasks import _send_response
await _send_response("web", "test", {"tenant_id": "t1"})
mock_logger.warning.assert_called()
mock_redis.publish.assert_not_called()
# ---------------------------------------------------------------------------
# Test typing indicator
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_typing_indicator_sent_before_dispatch() -> None:
"""WebSocket handler sends {'type': 'typing'} immediately after receiving user message."""
# We test the typing indicator by calling the handler function directly
# with a mocked WebSocket and mocked Celery dispatch.
from unittest.mock import AsyncMock, MagicMock, call, patch
mock_ws = AsyncMock()
# First receive_json returns auth message, second returns the user message
mock_ws.receive_json = AsyncMock(
side_effect=[
{"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
{"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
]
)
# accept must be awaitable
mock_ws.accept = AsyncMock()
mock_ws.send_json = AsyncMock()
# Mock DB session that returns a conversation
mock_session = AsyncMock()
mock_conv = MagicMock()
mock_conv.id = uuid.uuid4()
mock_conv.tenant_id = uuid.uuid4()
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
# Mock Redis pub-sub: raise after publishing once so handler exits
mock_redis = AsyncMock()
mock_pubsub = AsyncMock()
mock_pubsub.subscribe = AsyncMock()
mock_pubsub.get_message = AsyncMock(return_value={
"type": "message",
"data": json.dumps({
"type": "response",
"text": "Agent reply",
"conversation_id": str(uuid.uuid4()),
}),
})
mock_pubsub.unsubscribe = AsyncMock()
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
mock_redis.aclose = AsyncMock()
mock_handle = MagicMock()
with (
patch("gateway.channels.web.async_session_factory", return_value=mock_session),
patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
patch("gateway.channels.web.handle_message") as mock_handle_msg,
patch("gateway.channels.web.configure_rls_hook"),
patch("gateway.channels.web.current_tenant_id"),
):
mock_handle_msg.delay = MagicMock()
from gateway.channels.web import _handle_websocket_connection
# Run the handler — it should send typing indicator then process
# Use asyncio.wait_for to prevent infinite loop if something hangs
try:
await asyncio.wait_for(
_handle_websocket_connection(mock_ws, str(uuid.uuid4())),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
# Check typing was sent before Celery dispatch
send_calls = mock_ws.send_json.call_args_list
typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"]
assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"