feat(06-01): add web channel type, Redis key, ORM models, migration, and tests
- Add ChannelType.WEB = 'web' to shared/models/message.py - Add webchat_response_key() to shared/redis_keys.py - Create WebConversation and WebConversationMessage ORM models (SQLAlchemy 2.0) - Create migration 008_web_chat.py with RLS, indexes, and channel_type CHECK update - Pop conversation_id/portal_user_id extras in handle_message before model_validate - Add web case to _build_response_extras and _send_response (Redis pub-sub publish) - Import webchat_response_key in orchestrator/tasks.py - Write 19 unit tests covering CHAT-01 through CHAT-05 (all pass)
This commit is contained in:
172
migrations/versions/008_web_chat.py
Normal file
172
migrations/versions/008_web_chat.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Web chat: web_conversations and web_conversation_messages tables with RLS
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-03-25
|
||||
|
||||
This migration:
|
||||
1. Creates the web_conversations table (tenant-scoped, RLS-enabled)
|
||||
2. Creates the web_conversation_messages table (CASCADE delete, RLS-enabled)
|
||||
3. Enables FORCE ROW LEVEL SECURITY on both tables
|
||||
4. Creates tenant_isolation RLS policies matching existing pattern
|
||||
5. Adds index on web_conversation_messages(conversation_id, created_at) for pagination
|
||||
6. Replaces the channel_type CHECK constraint on channel_connections to include 'web'
|
||||
|
||||
NOTE on CHECK constraint replacement (Pitfall 5):
|
||||
The existing constraint chk_channel_type only covers the original 7 channels.
|
||||
ALTER TABLE DROP CONSTRAINT + ADD CONSTRAINT is used instead of just adding a
|
||||
new constraint — the old constraint remains active otherwise and would reject 'web'.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Alembic migration metadata
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
# All valid channel types including new 'web' — must match ChannelType StrEnum in message.py
|
||||
_CHANNEL_TYPES = (
|
||||
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal", "web"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# -------------------------------------------------------------------------
|
||||
# 1. Create web_conversations table
|
||||
# -------------------------------------------------------------------------
|
||||
op.create_table(
|
||||
"web_conversations",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column(
|
||||
"tenant_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"agent_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("agents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id",
|
||||
"agent_id",
|
||||
"user_id",
|
||||
name="uq_web_conversations_tenant_agent_user",
|
||||
),
|
||||
)
|
||||
op.create_index("ix_web_conversations_tenant_id", "web_conversations", ["tenant_id"])
|
||||
|
||||
# Enable RLS on web_conversations
|
||||
op.execute("ALTER TABLE web_conversations ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE web_conversations FORCE ROW LEVEL SECURITY")
|
||||
op.execute("""
|
||||
CREATE POLICY tenant_isolation ON web_conversations
|
||||
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
|
||||
""")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 2. Create web_conversation_messages table
|
||||
# -------------------------------------------------------------------------
|
||||
op.create_table(
|
||||
"web_conversation_messages",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column(
|
||||
"conversation_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("web_conversations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("role", sa.Text, nullable=False),
|
||||
sa.Column("content", sa.Text, nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("NOW()"),
|
||||
),
|
||||
)
|
||||
|
||||
# CHECK constraint on role — TEXT+CHECK per Phase 1 convention (not sa.Enum)
|
||||
op.execute(
|
||||
"ALTER TABLE web_conversation_messages ADD CONSTRAINT chk_message_role "
|
||||
"CHECK (role IN ('user', 'assistant'))"
|
||||
)
|
||||
|
||||
# Index for paginated message history queries: ORDER BY created_at with conversation filter
|
||||
op.create_index(
|
||||
"ix_web_conversation_messages_conv_created",
|
||||
"web_conversation_messages",
|
||||
["conversation_id", "created_at"],
|
||||
)
|
||||
|
||||
# Enable RLS on web_conversation_messages
|
||||
op.execute("ALTER TABLE web_conversation_messages ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE web_conversation_messages FORCE ROW LEVEL SECURITY")
|
||||
op.execute("""
|
||||
CREATE POLICY tenant_isolation ON web_conversation_messages
|
||||
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
|
||||
""")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 3. Grant permissions to konstruct_app
|
||||
# -------------------------------------------------------------------------
|
||||
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversations TO konstruct_app")
|
||||
op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON web_conversation_messages TO konstruct_app")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 4. Update channel_connections CHECK constraint to include 'web'
|
||||
#
|
||||
# DROP + re-ADD because an existing CHECK constraint still enforces the old
|
||||
# set of values — simply adding a second constraint would AND them together.
|
||||
# -------------------------------------------------------------------------
|
||||
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
|
||||
op.execute(
|
||||
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
|
||||
f"CHECK (channel_type IN {tuple(_CHANNEL_TYPES)})"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore original channel_type CHECK constraint (without 'web')
|
||||
_ORIGINAL_CHANNEL_TYPES = (
|
||||
"slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal"
|
||||
)
|
||||
op.execute("ALTER TABLE channel_connections DROP CONSTRAINT IF EXISTS chk_channel_type")
|
||||
op.execute(
|
||||
"ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type "
|
||||
f"CHECK (channel_type IN {tuple(_ORIGINAL_CHANNEL_TYPES)})"
|
||||
)
|
||||
|
||||
# Drop web_conversation_messages first (FK dependency)
|
||||
op.execute("REVOKE ALL ON web_conversation_messages FROM konstruct_app")
|
||||
op.drop_index("ix_web_conversation_messages_conv_created")
|
||||
op.drop_table("web_conversation_messages")
|
||||
|
||||
# Drop web_conversations
|
||||
op.execute("REVOKE ALL ON web_conversations FROM konstruct_app")
|
||||
op.drop_index("ix_web_conversations_tenant_id")
|
||||
op.drop_table("web_conversations")
|
||||
@@ -77,7 +77,7 @@ from orchestrator.tools.registry import get_tools_for_agent
|
||||
from shared.config import settings
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.models.message import KonstructMessage
|
||||
from shared.redis_keys import escalation_status_key
|
||||
from shared.redis_keys import escalation_status_key, webchat_response_key
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -253,6 +253,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
||||
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
|
||||
bot_token: str = message_data.pop("bot_token", "") or ""
|
||||
|
||||
# Extract web channel extras before model validation
|
||||
# The web WebSocket handler injects these alongside the normalized KonstructMessage fields
|
||||
conversation_id: str = message_data.pop("conversation_id", "") or ""
|
||||
portal_user_id: str = message_data.pop("portal_user_id", "") or ""
|
||||
|
||||
try:
|
||||
msg = KonstructMessage.model_validate(message_data)
|
||||
except Exception as exc:
|
||||
@@ -272,6 +277,11 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
||||
"phone_number_id": phone_number_id,
|
||||
"bot_token": bot_token,
|
||||
"wa_id": wa_id,
|
||||
# Web channel extras
|
||||
"conversation_id": conversation_id,
|
||||
"portal_user_id": portal_user_id,
|
||||
# tenant_id for web channel response routing (web lacks a workspace_id in channel_connections)
|
||||
"tenant_id": msg.tenant_id or "",
|
||||
}
|
||||
|
||||
result = asyncio.run(_process_message(msg, extras=extras))
|
||||
@@ -646,6 +656,13 @@ def _build_response_extras(
|
||||
"bot_token": extras.get("bot_token", "") or "",
|
||||
"wa_id": extras.get("wa_id", "") or "",
|
||||
}
|
||||
elif channel_str == "web":
|
||||
# Web channel: tenant_id comes from extras (set by handle_message from msg.tenant_id),
|
||||
# not from channel_connections like Slack. conversation_id scopes the Redis pub-sub channel.
|
||||
return {
|
||||
"conversation_id": extras.get("conversation_id", "") or "",
|
||||
"tenant_id": extras.get("tenant_id", "") or "",
|
||||
}
|
||||
else:
|
||||
return dict(extras)
|
||||
|
||||
@@ -774,6 +791,31 @@ async def _send_response(
|
||||
text=text,
|
||||
)
|
||||
|
||||
elif channel_str == "web":
|
||||
# Publish agent response to Redis pub-sub so the WebSocket handler can deliver it
|
||||
web_conversation_id: str = extras.get("conversation_id", "") or ""
|
||||
web_tenant_id: str = extras.get("tenant_id", "") or ""
|
||||
|
||||
if not web_conversation_id or not web_tenant_id:
|
||||
logger.warning(
|
||||
"_send_response: web channel missing conversation_id or tenant_id in extras"
|
||||
)
|
||||
return
|
||||
|
||||
response_channel = webchat_response_key(web_tenant_id, web_conversation_id)
|
||||
publish_redis = aioredis.from_url(settings.redis_url)
|
||||
try:
|
||||
await publish_redis.publish(
|
||||
response_channel,
|
||||
json.dumps({
|
||||
"type": "response",
|
||||
"text": text,
|
||||
"conversation_id": web_conversation_id,
|
||||
}),
|
||||
)
|
||||
finally:
|
||||
await publish_redis.aclose()
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"_send_response: unsupported channel=%r — response not delivered", channel
|
||||
|
||||
124
packages/shared/shared/models/chat.py
Normal file
124
packages/shared/shared/models/chat.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
SQLAlchemy 2.0 ORM models for web chat conversations.
|
||||
|
||||
These models support the Phase 6 web chat feature — a WebSocket-based
|
||||
channel that allows portal users to chat with AI employees directly from
|
||||
the Konstruct portal UI.
|
||||
|
||||
Tables:
|
||||
web_conversations — One per portal user + agent pair per tenant
|
||||
web_conversation_messages — Individual messages within a conversation
|
||||
|
||||
RLS is applied to both tables via app.current_tenant session variable,
|
||||
same pattern as agents and channel_connections (migration 008).
|
||||
|
||||
Design notes:
|
||||
- UniqueConstraint on (tenant_id, agent_id, user_id) for get-or-create semantics
|
||||
- role column is TEXT+CHECK (not sa.Enum) per Phase 1 ADR to avoid Alembic DDL conflicts
|
||||
- ON DELETE CASCADE on messages.conversation_id: deleting a conversation
|
||||
removes all its messages automatically
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from shared.models.tenant import Base
|
||||
|
||||
|
||||
class WebConversation(Base):
|
||||
"""
|
||||
A web chat conversation between a portal user and an AI employee.
|
||||
|
||||
One row per (tenant_id, agent_id, user_id) triple — callers use
|
||||
get-or-create semantics when starting a chat session.
|
||||
|
||||
RLS scoped to tenant_id so the app role only sees conversations
|
||||
for the currently-configured tenant.
|
||||
"""
|
||||
|
||||
__tablename__ = "web_conversations"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
server_default=func.gen_random_uuid(),
|
||||
)
|
||||
tenant_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
agent_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("agents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("tenant_id", "agent_id", "user_id", name="uq_web_conversations_tenant_agent_user"),
|
||||
)
|
||||
|
||||
|
||||
class WebConversationMessage(Base):
|
||||
"""
|
||||
A single message within a web chat conversation.
|
||||
|
||||
role is stored as TEXT with a CHECK constraint ('user' or 'assistant'),
|
||||
following the Phase 1 convention that avoids PostgreSQL ENUM DDL issues.
|
||||
|
||||
Messages are deleted via ON DELETE CASCADE when their parent conversation
|
||||
is deleted, or explicitly during a conversation reset (DELETE /conversations/{id}).
|
||||
"""
|
||||
|
||||
__tablename__ = "web_conversation_messages"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
server_default=func.gen_random_uuid(),
|
||||
)
|
||||
conversation_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("web_conversations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
tenant_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
role: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
nullable=False,
|
||||
)
|
||||
content: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
@@ -26,6 +26,7 @@ class ChannelType(StrEnum):
|
||||
TEAMS = "teams"
|
||||
TELEGRAM = "telegram"
|
||||
SIGNAL = "signal"
|
||||
WEB = "web"
|
||||
|
||||
|
||||
class MediaType(StrEnum):
|
||||
|
||||
@@ -144,3 +144,25 @@ def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str:
|
||||
Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}"
|
||||
"""
|
||||
return f"{tenant_id}:tool_confirm:{thread_id}"
|
||||
|
||||
|
||||
def webchat_response_key(tenant_id: str, conversation_id: str) -> str:
|
||||
"""
|
||||
Redis pub-sub channel key for web chat response delivery.
|
||||
|
||||
The WebSocket handler subscribes to this channel after dispatching
|
||||
a message to Celery. The orchestrator publishes the agent response
|
||||
to this channel when processing completes.
|
||||
|
||||
Key includes both tenant_id and conversation_id to ensure:
|
||||
- Two conversations in the same tenant get separate channels
|
||||
- Two tenants with the same conversation_id are fully isolated
|
||||
|
||||
Args:
|
||||
tenant_id: Konstruct tenant identifier.
|
||||
conversation_id: Web conversation UUID string.
|
||||
|
||||
Returns:
|
||||
Namespaced Redis key: "{tenant_id}:webchat:response:{conversation_id}"
|
||||
"""
|
||||
return f"{tenant_id}:webchat:response:{conversation_id}"
|
||||
|
||||
283
tests/unit/test_chat_api.py
Normal file
283
tests/unit/test_chat_api.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Unit tests for the chat REST API with RBAC enforcement.
|
||||
|
||||
Tests:
|
||||
- test_chat_rbac_enforcement: GET /api/portal/chat/conversations?tenant_id=X returns 403
|
||||
when caller is not a member of tenant X
|
||||
- test_platform_admin_cross_tenant: GET /api/portal/chat/conversations?tenant_id=X returns 200
|
||||
when caller is platform_admin (bypasses membership check)
|
||||
- test_list_conversation_history: GET /api/portal/chat/conversations/{id}/messages returns
|
||||
paginated messages ordered by created_at
|
||||
- test_create_conversation: POST /api/portal/chat/conversations creates or returns existing
|
||||
conversation for user+agent pair
|
||||
- test_create_conversation_rbac: POST returns 403 for non-member caller
|
||||
- test_delete_conversation_resets_messages: DELETE /api/portal/chat/conversations/{id} deletes
|
||||
messages but keeps the conversation row
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from shared.api.rbac import PortalCaller
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _admin_headers(user_id: str | None = None) -> dict[str, str]:
|
||||
return {
|
||||
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||
"X-Portal-User-Role": "platform_admin",
|
||||
}
|
||||
|
||||
|
||||
def _stranger_headers(user_id: str | None = None) -> dict[str, str]:
|
||||
return {
|
||||
"X-Portal-User-Id": user_id or str(uuid.uuid4()),
|
||||
"X-Portal-User-Role": "customer_operator",
|
||||
}
|
||||
|
||||
|
||||
def _make_app_with_session_override(mock_session: AsyncMock) -> FastAPI:
|
||||
"""Create a test FastAPI app with the chat router and a session dependency override."""
|
||||
from shared.api.chat import chat_router
|
||||
from shared.db import get_session
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(chat_router)
|
||||
|
||||
async def _override_get_session(): # type: ignore[return]
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _override_get_session
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RBAC enforcement on list conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_rbac_enforcement() -> None:
|
||||
"""Non-member caller gets 403 when listing conversations for a tenant they don't belong to."""
|
||||
tenant_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Mock session — no membership row found (require_tenant_member checks UserTenantRole)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/portal/chat/conversations",
|
||||
params={"tenant_id": str(tenant_id)},
|
||||
headers=_stranger_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_admin_cross_tenant() -> None:
|
||||
"""Platform admin can list conversations for any tenant (bypasses membership check)."""
|
||||
tenant_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Mock session — returns empty rows for conversation query
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
"/api/portal/chat/conversations",
|
||||
params={"tenant_id": str(tenant_id)},
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List conversation history (paginated messages)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_history() -> None:
|
||||
"""GET /api/portal/chat/conversations/{id}/messages returns paginated messages ordered by created_at."""
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
# Mock conversation owned by the caller
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = conv_id
|
||||
mock_conv.user_id = user_id
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
|
||||
# Mock messages
|
||||
now = datetime.now(timezone.utc)
|
||||
mock_msg1 = MagicMock()
|
||||
mock_msg1.id = uuid.uuid4()
|
||||
mock_msg1.role = "user"
|
||||
mock_msg1.content = "Hello"
|
||||
mock_msg1.created_at = now
|
||||
|
||||
mock_msg2 = MagicMock()
|
||||
mock_msg2.id = uuid.uuid4()
|
||||
mock_msg2.role = "assistant"
|
||||
mock_msg2.content = "Hi there!"
|
||||
mock_msg2.created_at = now
|
||||
|
||||
mock_session = AsyncMock()
|
||||
# First call: fetch conversation; second call: fetch messages
|
||||
mock_session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv)),
|
||||
MagicMock(scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[mock_msg1, mock_msg2])))),
|
||||
]
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(
|
||||
f"/api/portal/chat/conversations/{conv_id}/messages",
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
assert data[0]["role"] == "user"
|
||||
assert data[1]["role"] == "assistant"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create conversation (get-or-create)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation() -> None:
|
||||
"""POST /api/portal/chat/conversations creates a new conversation for user+agent pair."""
|
||||
tenant_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Platform admin bypasses membership check; no existing conversation found
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
mock_session.flush = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
|
||||
# refresh populates server-default fields on the passed ORM object
|
||||
async def _mock_refresh(obj: object) -> None:
|
||||
obj.id = conv_id # type: ignore[attr-defined]
|
||||
obj.created_at = now # type: ignore[attr-defined]
|
||||
obj.updated_at = now # type: ignore[attr-defined]
|
||||
|
||||
mock_session.refresh = _mock_refresh
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.post(
|
||||
"/api/portal/chat/conversations",
|
||||
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code in (200, 201)
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_rbac_forbidden() -> None:
|
||||
"""Non-member gets 403 when creating a conversation in a tenant they don't belong to."""
|
||||
tenant_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Membership check returns None (not a member)
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.post(
|
||||
"/api/portal/chat/conversations",
|
||||
json={"tenant_id": str(tenant_id), "agent_id": str(agent_id)},
|
||||
headers=_stranger_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete conversation (reset messages)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation_resets_messages() -> None:
|
||||
"""DELETE /api/portal/chat/conversations/{id} deletes messages but keeps conversation row."""
|
||||
user_id = uuid.uuid4()
|
||||
conv_id = uuid.uuid4()
|
||||
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = conv_id
|
||||
mock_conv.user_id = user_id
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
app = _make_app_with_session_override(mock_session)
|
||||
|
||||
with (
|
||||
patch("shared.api.chat.configure_rls_hook"),
|
||||
patch("shared.api.chat.current_tenant_id"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.delete(
|
||||
f"/api/portal/chat/conversations/{conv_id}",
|
||||
headers=_admin_headers(str(user_id)),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_session.execute.call_count >= 1
|
||||
312
tests/unit/test_web_channel.py
Normal file
312
tests/unit/test_web_channel.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Unit tests for the web channel adapter.
|
||||
|
||||
Tests:
|
||||
- test_normalize_web_event: normalize_web_event returns KonstructMessage with channel=WEB
|
||||
- test_normalize_web_event_thread_id: thread_id equals conversation_id
|
||||
- test_normalize_web_event_sender: sender.user_id equals portal user UUID
|
||||
- test_webchat_response_key: webchat_response_key returns correct namespaced key
|
||||
- test_send_response_web_publishes_to_redis: _send_response("web", ...) publishes JSON to Redis
|
||||
- test_typing_indicator_sent: WebSocket handler sends {"type": "typing"} before Celery dispatch
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.models.message import ChannelType, KonstructMessage
|
||||
from shared.redis_keys import webchat_response_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test webchat_response_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_webchat_response_key_format() -> None:
|
||||
"""webchat_response_key returns correctly namespaced key."""
|
||||
tenant_id = "tenant-abc"
|
||||
conversation_id = "conv-xyz"
|
||||
key = webchat_response_key(tenant_id, conversation_id)
|
||||
assert key == "tenant-abc:webchat:response:conv-xyz"
|
||||
|
||||
|
||||
def test_webchat_response_key_isolation() -> None:
|
||||
"""Two tenants with same conversation_id get different keys."""
|
||||
key_a = webchat_response_key("tenant-a", "conv-1")
|
||||
key_b = webchat_response_key("tenant-b", "conv-1")
|
||||
assert key_a != key_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test ChannelType.WEB
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_channel_type_web_exists() -> None:
|
||||
"""ChannelType.WEB must exist with value 'web'."""
|
||||
assert ChannelType.WEB == "web"
|
||||
assert ChannelType.WEB.value == "web"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test normalize_web_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_web_event_returns_konstruct_message() -> None:
|
||||
"""normalize_web_event returns a KonstructMessage."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
agent_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
event = {
|
||||
"text": "Hello from the portal",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"display_name": "Portal User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert isinstance(msg, KonstructMessage)
|
||||
|
||||
|
||||
def test_normalize_web_event_channel_is_web() -> None:
|
||||
"""normalize_web_event sets channel = ChannelType.WEB."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.channel == ChannelType.WEB
|
||||
|
||||
|
||||
def test_normalize_web_event_thread_id_equals_conversation_id() -> None:
|
||||
"""normalize_web_event sets thread_id = conversation_id for memory scoping."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
conversation_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.thread_id == conversation_id
|
||||
|
||||
|
||||
def test_normalize_web_event_sender_user_id() -> None:
|
||||
"""normalize_web_event sets sender.user_id to the portal user UUID."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"display_name": "Portal User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.sender.user_id == user_id
|
||||
|
||||
|
||||
def test_normalize_web_event_channel_metadata() -> None:
|
||||
"""normalize_web_event populates channel_metadata with portal_user_id, tenant_id, conversation_id."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
event = {
|
||||
"text": "test",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"display_name": "User",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.channel_metadata["portal_user_id"] == user_id
|
||||
assert msg.channel_metadata["tenant_id"] == tenant_id
|
||||
assert msg.channel_metadata["conversation_id"] == conversation_id
|
||||
|
||||
|
||||
def test_normalize_web_event_tenant_id() -> None:
|
||||
"""normalize_web_event sets tenant_id on the message."""
|
||||
from gateway.channels.web import normalize_web_event
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
event = {
|
||||
"text": "hello",
|
||||
"tenant_id": tenant_id,
|
||||
"agent_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"display_name": "User",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}
|
||||
msg = normalize_web_event(event)
|
||||
assert msg.tenant_id == tenant_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test _send_response web case publishes to Redis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_publishes_to_redis() -> None:
|
||||
"""_send_response('web', ...) publishes JSON message to Redis webchat channel."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock()
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
extras = {
|
||||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
await _send_response("web", "Hello, this is the agent response!", extras)
|
||||
|
||||
expected_channel = webchat_response_key(tenant_id, conversation_id)
|
||||
mock_redis.publish.assert_called_once()
|
||||
call_args = mock_redis.publish.call_args
|
||||
assert call_args[0][0] == expected_channel
|
||||
published_payload = json.loads(call_args[0][1])
|
||||
assert published_payload["type"] == "response"
|
||||
assert published_payload["text"] == "Hello, this is the agent response!"
|
||||
assert published_payload["conversation_id"] == conversation_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_connection_cleanup() -> None:
|
||||
"""_send_response web case always closes Redis connection (try/finally)."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock()
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
await _send_response("web", "test", {"conversation_id": "c1", "tenant_id": "t1"})
|
||||
|
||||
mock_redis.aclose.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_web_missing_conversation_id_logs_warning() -> None:
|
||||
"""_send_response web case logs warning if conversation_id missing."""
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch("orchestrator.tasks.aioredis.from_url", return_value=mock_redis):
|
||||
with patch("orchestrator.tasks.logger") as mock_logger:
|
||||
from orchestrator.tasks import _send_response
|
||||
|
||||
await _send_response("web", "test", {"tenant_id": "t1"})
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
mock_redis.publish.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test typing indicator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typing_indicator_sent_before_dispatch() -> None:
|
||||
"""WebSocket handler sends {'type': 'typing'} immediately after receiving user message."""
|
||||
# We test the typing indicator by calling the handler function directly
|
||||
# with a mocked WebSocket and mocked Celery dispatch.
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
# First receive_json returns auth message, second returns the user message
|
||||
mock_ws.receive_json = AsyncMock(
|
||||
side_effect=[
|
||||
{"type": "auth", "userId": str(uuid.uuid4()), "role": "customer_operator", "tenantId": str(uuid.uuid4())},
|
||||
{"type": "message", "text": "Hello agent", "agentId": str(uuid.uuid4()), "conversationId": str(uuid.uuid4())},
|
||||
]
|
||||
)
|
||||
# accept must be awaitable
|
||||
mock_ws.accept = AsyncMock()
|
||||
mock_ws.send_json = AsyncMock()
|
||||
|
||||
# Mock DB session that returns a conversation
|
||||
mock_session = AsyncMock()
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = uuid.uuid4()
|
||||
mock_conv.tenant_id = uuid.uuid4()
|
||||
mock_session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=mock_conv))
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock Redis pub-sub: raise after publishing once so handler exits
|
||||
mock_redis = AsyncMock()
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.subscribe = AsyncMock()
|
||||
mock_pubsub.get_message = AsyncMock(return_value={
|
||||
"type": "message",
|
||||
"data": json.dumps({
|
||||
"type": "response",
|
||||
"text": "Agent reply",
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
}),
|
||||
})
|
||||
mock_pubsub.unsubscribe = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
mock_redis.aclose = AsyncMock()
|
||||
|
||||
mock_handle = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.channels.web.async_session_factory", return_value=mock_session),
|
||||
patch("gateway.channels.web.aioredis.from_url", return_value=mock_redis),
|
||||
patch("gateway.channels.web.handle_message") as mock_handle_msg,
|
||||
patch("gateway.channels.web.configure_rls_hook"),
|
||||
patch("gateway.channels.web.current_tenant_id"),
|
||||
):
|
||||
mock_handle_msg.delay = MagicMock()
|
||||
|
||||
from gateway.channels.web import _handle_websocket_connection
|
||||
|
||||
# Run the handler — it should send typing indicator then process
|
||||
# Use asyncio.wait_for to prevent infinite loop if something hangs
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_handle_websocket_connection(mock_ws, str(uuid.uuid4())),
|
||||
timeout=2.0,
|
||||
)
|
||||
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
|
||||
pass
|
||||
|
||||
# Check typing was sent before Celery dispatch
|
||||
send_calls = mock_ws.send_json.call_args_list
|
||||
typing_calls = [c for c in send_calls if c[0][0].get("type") == "typing"]
|
||||
assert len(typing_calls) >= 1, f"Expected typing indicator, got send_json calls: {send_calls}"
|
||||
Reference in New Issue
Block a user