feat(06-01): add web channel type, Redis key, ORM models, migration, and tests
- Add ChannelType.WEB = 'web' to shared/models/message.py - Add webchat_response_key() to shared/redis_keys.py - Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0) - Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update - Pop conversation_id/portal_user_id extras in handle_message before model_validate - Add web case to _build_response_extras and _send_response (Redis pub-sub publish) - Import webchat_response_key in orchestrator/tasks.py - Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
This commit is contained in:
172
migrations/versions/008_web_chat.py
Normal file
172
migrations/versions/008_web_chat.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Web chat: web_conversations and web_conversation_messages tables with RLS
|
||||||
|
|
||||||
|
Revision ID: 008
|
||||||
|
Revises: 007
|
||||||
|
Create Date: 2026-03-25
|
||||||
|
|
||||||
|
This migration:
|
||||||
|
1. Creates the web_conversations table (tenant-scoped, RLS-enabled)
|
||||||
|
2. Creates the web_conversation_messages table (CASCADE delete, RLS-enabled)
|
||||||
|
3. Enables FORCE ROW LEVEL SECURITY on both tables
|
||||||
|
4. Creates tenant_isolation RLS policies matching existing pattern
|
||||||
|
5. Adds index on web_conversation_messages(conversation_id, created_at) for pagination
|
||||||
|
6. Replaces the channel_type CHECK constraint on channel_connections to include 'web'
|
||||||
|
|
||||||
|
NOTE on CHECK constraint replacement (Pitfall 5):
|
||||||
|
The existing constraint chk_channel_type only covers the original 7 channels.
|
||||||
|
ALTER TABLE DROP CONSTRAINT + ADD CONSTRAINT is used instead of just adding a
|
||||||
|
new constraint — the old constraint remains active otherwise and would reject 'web'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Alembic migration metadata
|
||||||
|
revision: str = "008"
|
||||||
|
down_revision: Union[str, None] = "007"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
# All valid channel types including new 'web' — must match ChannelType StrEnum in message.py
|
||||||
|
_CHANNEL_TYPES = (
|
||||||
|
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal", "web"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# 1. Create web_conversations table
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
op.create_table(
|
||||||
|
"web_conversations",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column(
|
||||||
|
"tenant_id",
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"agent_id",
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sa.ForeignKey("agents.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("NOW()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("NOW()"),
|
||||||
|
),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"tenant_id",
|
||||||
|
"agent_id",
|
||||||
|
"user_id",
|
||||||
|
name="uq_web_conversations_tenant_agent_user",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_web_conversations_tenant_id", "web_conversations", ["tenant_id"])
|
||||||
|
|
||||||
|
# Enable RLS on web_conversations
|
||||||
|
op.execute("ALTER TABLE web_conversations ENABLE ROW LEVEL SECURITY")
|
||||||
|
op.execute("ALTER TABLE web_conversations FORCE ROW LEVEL SECURITY")
|
||||||
|
op.execute("""
|
||||||
|
CREATE POLICY tenant_isolation ON web_conversations
|
||||||
|
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# 2. Create web_conversation_messages table
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
op.create_table(
|
||||||
|
"web_conversation_messages",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column(
|
||||||
|
"conversation_id",
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sa.ForeignKey("web_conversations.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("tenant_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("role", sa.Text, nullable=False),
|
||||||
|
sa.Column("content", sa.Text, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("NOW()"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# CHECK constraint on role — TEXT+CHECK per Phase 1 convention (not sa.Enum)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE web_conversation_messages ADD CONSTRAINT chk_message_role "
|
||||||
|
"CHECK (role IN ('user', 'assistant'))"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index for paginated message history queries: ORDER BY created_at with conversation filter
|
||||||
|
op.create_index(
|
||||||
|
"ix_web_conversation_messages_conv_created",
|
||||||
|
"web_conversation_messages",
|
||||||
|
["conversation_id", "created_at"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enable RLS on web_conversation_messages
|
||||||
|
op.execute("ALTER TABLE web_conversation_messages ENABLE ROW LEVEL SECURITY")
|
||||||
|
op.execute("ALTER TABLE web_conversation_messages FORCE ROW LEVEL SECURITY")
|
||||||
|
op.execute("""
|
||||||
|
CREATE POLICY tenant_isolation ON web_conversation_messages
|
||||||
|
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# 3. Grant permissions to konstruct_app
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversations TO konstruct_app")
|
||||||
|
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversation_messages TO konstruct_app")
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# 4. Update channel_connections CHECK constraint to include 'web'
|
||||||
|
#
|
||||||
|
# DROP + re-ADD because an existing CHECK constraint still enforces the old
|
||||||
|
# set of values — simply adding a second constraint would AND them together.
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
|
||||||
|
f"CHECK (channel_type IN {tuple(_CHANNEL_TYPES)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Restore original channel_type CHECK constraint (without 'web')
|
||||||
|
_ORIGINAL_CHANNEL_TYPES = (
|
||||||
|
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal"
|
||||||
|
)
|
||||||
|
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
|
||||||
|
f"CHECK (channel_type IN {tuple(_ORIGINAL_CHANNEL_TYPES)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop web_conversation_messages first (FK dependency)
|
||||||
|
op.execute("REVOKE ALL ON web_conversation_messages FROM konstruct_app")
|
||||||
|
op.drop_index("ix_web_conversation_messages_conv_created")
|
||||||
|
op.drop_table("web_conversation_messages")
|
||||||
|
|
||||||
|
# Drop web_conversations
|
||||||
|
op.execute("REVOKE ALL ON web_conversations FROM konstruct_app")
|
||||||
|
op.drop_index("ix_web_conversations_tenant_id")
|
||||||
|
op.drop_table("web_conversations")
|
||||||
@@ -77,7 +77,7 @@ from orchestrator.tools.registry import get_tools_for_agent
|
|||||||
from shared.config import settings
|
from shared.config import settings
|
||||||
from shared.db import async_session_factory, engine
|
from shared.db import async_session_factory, engine
|
||||||
from shared.models.message import KonstructMessage
|
from shared.models.message import KonstructMessage
|
||||||
from shared.redis_keys import escalation_status_key
|
from shared.redis_keys import escalation_status_key, webchat_response_key
|
||||||
from shared.rls import configure_rls_hook, current_tenant_id
|
from shared.rls import configure_rls_hook, current_tenant_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -253,6 +253,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
|||||||
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
|
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
|
||||||
bot_token: str = message_data.pop("bot_token", "") or ""
|
bot_token: str = message_data.pop("bot_token", "") or ""
|
||||||
|
|
||||||
|
# Extract web channel extras before model validation
|
||||||
|
# The web WebSocket handler injects these alongside the normalized KonstructMessage fields
|
||||||
|
conversation_id: str = message_data.pop("conversation_id", "") or ""
|
||||||
|
portal_user_id: str = message_data.pop("portal_user_id", "") or ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = KonstructMessage.model_validate(message_data)
|
msg = KonstructMessage.model_validate(message_data)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -272,6 +277,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
|||||||
"phone_number_id": phone_number_id,
|
"phone_number_id": phone_number_id,
|
||||||
"bot_token": bot_token,
|
"bot_token": bot_token,
|
||||||
"wa_id": wa_id,
|
"wa_id": wa_id,
|
||||||
|
# Web channel extras
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"portal_user_id": portal_user_id,
|
||||||
|
# tenant_id for web channel response routing (web lacks a workspace_id in channel_connections)
|
||||||
|
"tenant_id": msg.tenant_id or "",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = asyncio.run(_process_message(msg, extras=extras))
|
result = asyncio.run(_process_message(msg, extras=extras))
|
||||||
@@ -646,6 +656,13 @@ def _build_response_extras(
|
|||||||
"bot_token": extras.get("bot_token", "") or "",
|
"bot_token": extras.get("bot_token", "") or "",
|
||||||
"wa_id": extras.get("wa_id", "") or "",
|
"wa_id": extras.get("wa_id", "") or "",
|
||||||
}
|
}
|
||||||
|
elif channel_str == "web":
|
||||||
|
# Web channel: tenant_id comes from extras (set by handle_message from msg.tenant_id),
|
||||||
|
# not from channel_connections like Slack. conversation_id scopes the Redis pub-sub channel.
|
||||||
|
return {
|
||||||
|
"conversation_id": extras.get("conversation_id", "") or "",
|
||||||
|
"tenant_id": extras.get("tenant_id", "") or "",
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return dict(extras)
|
return dict(extras)
|
||||||
|
|
||||||
@@ -774,6 +791,31 @@ async def _send_response(
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif channel_str == "web":
|
||||||
|
# Publish agent response to Redis pub-sub so the WebSocket handler can deliver it
|
||||||
|
web_conversation_id: str = extras.get("conversation_id", "") or ""
|
||||||
|
web_tenant_id: str = extras.get("tenant_id", "") or ""
|
||||||
|
|
||||||
|
if not web_conversation_id or not web_tenant_id:
|
||||||
|
logger.warning(
|
||||||
|
"_send_response: web channel missing conversation_id or tenant_id in extras"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
response_channel = webchat_response_key(web_tenant_id, web_conversation_id)
|
||||||
|
publish_redis = aioredis.from_url(settings.redis_url)
|
||||||
|
try:
|
||||||
|
await publish_redis.publish(
|
||||||
|
response_channel,
|
||||||
|
json.dumps({
|
||||||
|
"type": "response",
|
||||||
|
"text": text,
|
||||||
|
"conversation_id": web_conversation_id,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await publish_redis.aclose()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"_send_response: unsupported channel=%r — response not delivered", channel
|
"_send_response: unsupported channel=%r — response not delivered", channel
|
||||||
|
|||||||
124
packages/shared/shared/models/chat.py
Normal file
124
packages/shared/shared/models/chat.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy 2.0 ORM models for web chat conversations.
|
||||||
|
|
||||||
|
These models support the Phase 6 web chat feature — a WebSocket-based
|
||||||
|
channel that allows portal users to chat with AI employees directly from
|
||||||
|
the Konstruct portal UI.
|
||||||
|
|
||||||
|
Tables:
|
||||||
|
web_conversations — One per portal user + agent pair per tenant
|
||||||
|
web_conversation_messages — Individual messages within a conversation
|
||||||
|
|
||||||
|
RLS is applied to both tables via app.current_tenant session variable,
|
||||||
|
same pattern as agents and channel_connections (migration 008).
|
||||||
|
|
||||||
|
Design notes:
|
||||||
|
- UniqueConstraint on (tenant_id, agent_id, user_id) for get-or-create semantics
|
||||||
|
- role column is TEXT+CHECK (not sa.Enum) per Phase 1 ADR to avoid Alembic DDL conflicts
|
||||||
|
- ON DELETE CASCADE on messages.conversation_id: deleting a conversation
|
||||||
|
removes all its messages automatically
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Text, UniqueConstraint, func
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from shared.models.tenant import Base
|
||||||
|
|
||||||
|
|
||||||
|
class WebConversation(Base):
|
||||||
|
"""
|
||||||
|
A web chat conversation between a portal user and an AI employee.
|
||||||
|
|
||||||
|
One row per (tenant_id, agent_id, user_id) triple — callers use
|
||||||
|
get-or-create semantics when starting a chat session.
|
||||||
|
|
||||||
|
RLS scoped to tenant_id so the app role only sees conversations
|
||||||
|
for the currently-configured tenant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "web_conversations"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
primary_key=True,
|
||||||
|
default=uuid.uuid4,
|
||||||
|
server_default=func.gen_random_uuid(),
|
||||||
|
)
|
||||||
|
tenant_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
agent_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("agents.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("tenant_id", "agent_id", "user_id", name="uq_web_conversations_tenant_agent_user"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WebConversationMessage(Base):
|
||||||
|
"""
|
||||||
|
A single message within a web chat conversation.
|
||||||
|
|
||||||
|
role is stored as TEXT with a CHECK constraint ('user' or 'assistant'),
|
||||||
|
following the Phase 1 convention that avoids PostgreSQL ENUM DDL issues.
|
||||||
|
|
||||||
|
Messages are deleted via ON DELETE CASCADE when their parent conversation
|
||||||
|
is deleted, or explicitly during a conversation reset (DELETE /conversations/{id}).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "web_conversation_messages"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
primary_key=True,
|
||||||
|
default=uuid.uuid4,
|
||||||
|
server_default=func.gen_random_uuid(),
|
||||||
|
)
|
||||||
|
conversation_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("web_conversations.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
tenant_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
role: Mapped[str] = mapped_column(
|
||||||
|
Text,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
content: Mapped[str] = mapped_column(
|
||||||
|
Text,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
@@ -26,6 +26,7 @@ class ChannelType(StrEnum):
|
|||||||
TEAMS = "teams"
|
TEAMS = "teams"
|
||||||
TELEGRAM = "telegram"
|
TELEGRAM = "telegram"
|
||||||
SIGNAL = "signal"
|
SIGNAL = "signal"
|
||||||
|
WEB = "web"
|
||||||
|
|
||||||
|
|
||||||
class MediaType(StrEnum):
|
class MediaType(StrEnum):
|
||||||
|
|||||||
@@ -144,3 +144,25 @@ def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str:
|
|||||||
Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}"
|
Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}"
|
||||||
"""
|
"""
|
||||||
return f"{tenant_id}:tool_confirm:{thread_id}"
|
return f"{tenant_id}:tool_confirm:{thread_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def webchat_response_key(tenant_id: str, conversation_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Redis pub-sub channel key for web chat response delivery.
|
||||||
|
|
||||||
|
The WebSocket handler subscribes to this channel after dispatching
|
||||||
|
a message to Celery. The orchestrator publishes the agent response
|
||||||
|
to this channel when processing completes.
|
||||||
|
|
||||||
|
Key includes both tenant_id and conversation_id to ensure:
|
||||||
|
- Two conversations in the same tenant get separate channels
|
||||||
|
- Two tenants with the same conversation_id are fully isolated
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Konstruct tenant identifier.
|
||||||
|
conversation_id: Web conversation UUID string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Namespaced Redis key: "{tenant_id}:webchat:response:{conversation_id}"
|
||||||
|
"""
|
||||||
|
return f"{tenant_id}:webchat:response:{conversation_id}"
|
||||||
|
|||||||
283
tests/unit/test_chat_api.py
Normal file
283
tests/unit/test_chat_api.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the chat REST API with RBAC enforcement.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
|
||||||
|
when caller is not a member of tenant X
|
||||||
|
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
|
||||||
|
when caller is platform_admin (bypasses membership check)
|
||||||
|
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
|
||||||
|
paginated messages ordered by created_at
|
||||||
|
- test_create_conversation: POST /api/portal/chat/conversations creates or returns existing
|
||||||
|
conversation for user+agent pair
|
||||||
|
- test_create_conversation_rbac: POST returns 403 for non-member caller
|
||||||
|
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
|
||||||
|
messages but keeps the conversation row
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from shared.api.rbac import PortalCaller
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _admin_headers(user_id: str | None = None) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||||
|
"X-Portal-User-Role": "platform_admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||||
|
"X-Portal-User-Role": "customer_operator",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
|
||||||
|
"""Create a test FastAPI app with the chat router and a session dependency override."""
|
||||||
|
from shared.api.chat import chat_router
|
||||||
|
from shared.db import get_session
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(chat_router)
|
||||||
|
|
||||||
|
async def _override_get_session(): # type: ignore[return]
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _override_get_session
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RBAC enforcement on list conversations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_rbac_enforcement() -> None:
|
||||||
|
"""Non-member caller gets 403 when listing conversations for a tenant they don't belong to."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Mock session — no membership row found (require_tenant_member checks UserTenantRole)
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.get(
|
||||||
|
"/api/portal/chat/conversations",
|
||||||
|
params={"tenant_id": str(tenant_id)},
|
||||||
|
headers=_stranger_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_platform_admin_cross_tenant() -> None:
|
||||||
|
"""Platform admin can list conversations for any tenant (bypasses membership check)."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Mock session — returns empty rows for conversation query
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.all.return_value = []
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("shared.api.chat.configure_rls_hook"),
|
||||||
|
patch("shared.api.chat.current_tenant_id"),
|
||||||
|
):
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.get(
|
||||||
|
"/api/portal/chat/conversations",
|
||||||
|
params={"tenant_id": str(tenant_id)},
|
||||||
|
headers=_admin_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert isinstance(response.json(), list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# List conversation history (paginated messages)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_conversation_history() -> None:
|
||||||
|
"""GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at."""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
conv_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Mock conversation owned by the caller
|
||||||
|
mock_conv = MagicMock()
|
||||||
|
mock_conv.id = conv_id
|
||||||
|
mock_conv.user_id = user_id
|
||||||
|
mock_conv.tenant_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Mock messages
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
mock_msg1 = MagicMock()
|
||||||
|
mock_msg1.id = uuid.uuid4()
|
||||||
|
mock_msg1.role = "user"
|
||||||
|
mock_msg1.content = "Hello"
|
||||||
|
mock_msg1.created_at = now
|
||||||
|
|
||||||
|
mock_msg2 = MagicMock()
|
||||||
|
mock_msg2.id = uuid.uuid4()
|
||||||
|
mock_msg2.role = "assistant"
|
||||||
|
mock_msg2.content = "Hi there!"
|
||||||
|
mock_msg2.created_at = now
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
# First call: fetch conversation; second call: fetch messages
|
||||||
|
mock_session.execute.side_effect = [
|
||||||
|
MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
|
||||||
|
MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
|
||||||
|
]
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("shared.api.chat.configure_rls_hook"),
|
||||||
|
patch("shared.api.chat.current_tenant_id"),
|
||||||
|
):
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/portal/chat/conversations/{conv_id}/messages",
|
||||||
|
headers=_admin_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["role"] == "user"
|
||||||
|
assert data[1]["role"] == "assistant"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Create conversation (get-or-create)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_conversation() -> None:
|
||||||
|
"""POST /api/portal/chat/conversations creates a new conversation for user+agent pair."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
agent_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
conv_id = uuid.uuid4()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Platform admin bypasses membership check; no existing conversation found
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||||
|
mock_session.flush = AsyncMock()
|
||||||
|
mock_session.commit = AsyncMock()
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
|
||||||
|
# refresh populates server-default fields on the passed ORM object
|
||||||
|
async def _mock_refresh(obj: object) -> None:
|
||||||
|
obj.id = conv_id # type: ignore[attr-defined]
|
||||||
|
obj.created_at = now # type: ignore[attr-defined]
|
||||||
|
obj.updated_at = now # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
mock_session.refresh = _mock_refresh
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("shared.api.chat.configure_rls_hook"),
|
||||||
|
patch("shared.api.chat.current_tenant_id"),
|
||||||
|
):
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/portal/chat/conversations",
|
||||||
|
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||||
|
headers=_admin_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code in (200, 201)
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_conversation_rbac_forbidden() -> None:
|
||||||
|
"""Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
|
||||||
|
tenant_id = uuid.uuid4()
|
||||||
|
agent_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
|
# Membership check returns None (not a member)
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/portal/chat/conversations",
|
||||||
|
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||||
|
headers=_stranger_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Delete conversation (reset messages)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_conversation_resets_messages() -> None:
|
||||||
|
"""DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row."""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
conv_id = uuid.uuid4()
|
||||||
|
|
||||||
|
mock_conv = MagicMock()
|
||||||
|
mock_conv.id = conv_id
|
||||||
|
mock_conv.user_id = user_id
|
||||||
|
mock_conv.tenant_id = uuid.uuid4()
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||||
|
mock_session.commit = AsyncMock()
|
||||||
|
|
||||||
|
app = _make_app_with_session_override(mock_session)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("shared.api.chat.configure_rls_hook"),
|
||||||
|
patch("shared.api.chat.current_tenant_id"),
|
||||||
|
):
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/portal/chat/conversations/{conv_id}",
|
||||||
|
headers=_admin_headers(str(user_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_session.execute.call_count >= 1
|
||||||
312
tests/unit/test_web_channel.py
Normal file
312
tests/unit/test_web_channel.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the web channel adapter.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB
|
||||||
|
- test_normalize_web_event_thread_id: thread_id equals conversation_id
|
||||||
|
- test_normalize_web_event_sender: sender.user_id equals portal user UUID
|
||||||
|
- test_webchat_response_key: webchat_response_key returns correct namespaced key
|
||||||
|
- test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis
|
||||||
|
- test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.models.message import ChannelType, KonstructMessage
|
||||||
|
from shared.redis_keys import webchat_response_key
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test webchat_response_key
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_webchat_response_key_format() -> None:
|
||||||
|
"""webchat_response_key returns correctly namespaced key."""
|
||||||
|
tenant_id = "tenant-abc"
|
||||||
|
conversation_id = "conv-xyz"
|
||||||
|
key = webchat_response_key(tenant_id, conversation_id)
|
||||||
|
assert key == "tenant-abc:webchat:response:conv-xyz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_webchat_response_key_isolation() -> None:
|
||||||
|
"""Two tenants with same conversation_id get different keys."""
|
||||||
|
key_a = webchat_response_key("tenant-a", "conv-1")
|
||||||
|
key_b = webchat_response_key("tenant-b", "conv-1")
|
||||||
|
assert key_a != key_b
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test ChannelType.WEB
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_type_web_exists() -> None:
|
||||||
|
"""ChannelType.WEB must exist with value 'web'."""
|
||||||
|
assert ChannelType.WEB == "web"
|
||||||
|
assert ChannelType.WEB.value == "web"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test normalize_web_event
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_returns_konstruct_message() -> None:
|
||||||
|
"""normalize_web_event returns a KonstructMessage."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
tenant_id = str(uuid.uuid4())
|
||||||
|
agent_id = str(uuid.uuid4())
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"text": "Hello from the portal",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"display_name": "Portal User",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert isinstance(msg, KonstructMessage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_channel_is_web() -> None:
|
||||||
|
"""normalize_web_event sets channel = ChannelType.WEB."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"text": "test",
|
||||||
|
"tenant_id": str(uuid.uuid4()),
|
||||||
|
"agent_id": str(uuid.uuid4()),
|
||||||
|
"user_id": str(uuid.uuid4()),
|
||||||
|
"display_name": "User",
|
||||||
|
"conversation_id": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert msg.channel == ChannelType.WEB
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
|
||||||
|
"""normalize_web_event sets thread_id = conversation_id for memory scoping."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
event = {
|
||||||
|
"text": "test",
|
||||||
|
"tenant_id": str(uuid.uuid4()),
|
||||||
|
"agent_id": str(uuid.uuid4()),
|
||||||
|
"user_id": str(uuid.uuid4()),
|
||||||
|
"display_name": "User",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert msg.thread_id == conversation_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_sender_user_id() -> None:
|
||||||
|
"""normalize_web_event sets sender.user_id to the portal user UUID."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
event = {
|
||||||
|
"text": "test",
|
||||||
|
"tenant_id": str(uuid.uuid4()),
|
||||||
|
"agent_id": str(uuid.uuid4()),
|
||||||
|
"user_id": user_id,
|
||||||
|
"display_name": "Portal User",
|
||||||
|
"conversation_id": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert msg.sender.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_channel_metadata() -> None:
|
||||||
|
"""normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
tenant_id = str(uuid.uuid4())
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"text": "test",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"agent_id": str(uuid.uuid4()),
|
||||||
|
"user_id": user_id,
|
||||||
|
"display_name": "User",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert msg.channel_metadata["portal_user_id"] == user_id
|
||||||
|
assert msg.channel_metadata["tenant_id"] == tenant_id
|
||||||
|
assert msg.channel_metadata["conversation_id"] == conversation_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_web_event_tenant_id() -> None:
|
||||||
|
"""normalize_web_event sets tenant_id on the message."""
|
||||||
|
from gateway.channels.web import normalize_web_event
|
||||||
|
|
||||||
|
tenant_id = str(uuid.uuid4())
|
||||||
|
event = {
|
||||||
|
"text": "hello",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"agent_id": str(uuid.uuid4()),
|
||||||
|
"user_id": str(uuid.uuid4()),
|
||||||
|
"display_name": "User",
|
||||||
|
"conversation_id": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
msg = normalize_web_event(event)
|
||||||
|
assert msg.tenant_id == tenant_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test _send_response web case publishes to Redis
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_response_web_publishes_to_redis() -> None:
|
||||||
|
"""_send_response('web', ...) publishes JSON message to Redis webchat channel."""
|
||||||
|
tenant_id = str(uuid.uuid4())
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.publish = AsyncMock()
|
||||||
|
mock_redis.aclose = AsyncMock()
|
||||||
|
|
||||||
|
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||||
|
from orchestrator.tasks import _send_response
|
||||||
|
|
||||||
|
extras = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
}
|
||||||
|
await _send_response("web", "Hello, this is the agent response!", extras)
|
||||||
|
|
||||||
|
expected_channel = webchat_response_key(tenant_id, conversation_id)
|
||||||
|
mock_redis.publish.assert_called_once()
|
||||||
|
call_args = mock_redis.publish.call_args
|
||||||
|
assert call_args[0][0] == expected_channel
|
||||||
|
published_payload = json.loads(call_args[0][1])
|
||||||
|
assert published_payload["type"] == "response"
|
||||||
|
assert published_payload["text"] == "Hello, this is the agent response!"
|
||||||
|
assert published_payload["conversation_id"] == conversation_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_response_web_connection_cleanup() -> None:
|
||||||
|
"""_send_response web case always closes Redis connection (try/finally)."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.publish = AsyncMock()
|
||||||
|
mock_redis.aclose = AsyncMock()
|
||||||
|
|
||||||
|
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||||
|
from orchestrator.tasks import _send_response
|
||||||
|
|
||||||
|
await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
|
||||||
|
|
||||||
|
mock_redis.aclose.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
|
||||||
|
"""_send_response web case logs warning if conversation_id missing."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||||
|
with patch("orchestrator.tasks.logger") as mock_logger:
|
||||||
|
from orchestrator.tasks import _send_response
|
||||||
|
|
||||||
|
await _send_response("web", "test", {"tenant_id": "t1"})
|
||||||
|
|
||||||
|
mock_logger.warning.assert_called()
|
||||||
|
mock_redis.publish.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test typing indicator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_typing_indicator_sent_before_dispatch() -> None:
|
||||||
|
"""WebSocket handler sends {'type': 'typing'} immediately after receiving user message."""
|
||||||
|
# We test the typing indicator by calling the handler function directly
|
||||||
|
# with a mocked WebSocket and mocked Celery dispatch.
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
# First receive_json returns auth message, second returns the user message
|
||||||
|
mock_ws.receive_json = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
|
||||||
|
{"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# accept must be awaitable
|
||||||
|
mock_ws.accept = AsyncMock()
|
||||||
|
mock_ws.send_json = AsyncMock()
|
||||||
|
|
||||||
|
# Mock DB session that returns a conversation
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_conv = MagicMock()
|
||||||
|
mock_conv.id = uuid.uuid4()
|
||||||
|
mock_conv.tenant_id = uuid.uuid4()
|
||||||
|
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||||
|
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
# Mock Redis pub-sub: raise after publishing once so handler exits
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_pubsub = AsyncMock()
|
||||||
|
mock_pubsub.subscribe = AsyncMock()
|
||||||
|
mock_pubsub.get_message = AsyncMock(return_value={
|
||||||
|
"type": "message",
|
||||||
|
"data": json.dumps({
|
||||||
|
"type": "response",
|
||||||
|
"text": "Agent reply",
|
||||||
|
"conversation_id": str(uuid.uuid4()),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
mock_pubsub.unsubscribe = AsyncMock()
|
||||||
|
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||||
|
mock_redis.aclose = AsyncMock()
|
||||||
|
|
||||||
|
mock_handle = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("gateway.channels.web.async_session_factory", return_value=mock_session),
|
||||||
|
patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
|
||||||
|
patch("gateway.channels.web.handle_message") as mock_handle_msg,
|
||||||
|
patch("gateway.channels.web.configure_rls_hook"),
|
||||||
|
patch("gateway.channels.web.current_tenant_id"),
|
||||||
|
):
|
||||||
|
mock_handle_msg.delay = MagicMock()
|
||||||
|
|
||||||
|
from gateway.channels.web import _handle_websocket_connection
|
||||||
|
|
||||||
|
# Run the handler — it should send typing indicator then process
|
||||||
|
# Use asyncio.wait_for to prevent infinite loop if something hangs
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
_handle_websocket_connection(mock_ws, str(uuid.uuid4())),
|
||||||
|
timeout=2.0,
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check typing was sent before Celery dispatch
|
||||||
|
send_calls = mock_ws.send_json.call_args_list
|
||||||
|
typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"]
|
||||||
|
assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"
|
||||||
Reference in New Issue
Block a user