feat(06-01): WebSocket endpoint, chat REST API, orchestrator wiring, gateway mounting

- Create gateway/channels/web.py with normalize_web_event() and /chat/ws/{conversation_id}
  WebSocket endpoint (auth via first JSON message, typing indicator, Redis pub-sub response)
- Create shared/api/chat.py with GET/POST/DELETE /api/portal/chat/conversations* REST API
  with require_tenant_member RBAC enforcement and RLS context var setup
- Add chat_router to shared/api/__init__.py exports
- Mount chat_router and web_chat_router in gateway/main.py (Phase 6 Web Chat routers)
- All 19 unit tests pass; full 313-test suite green
This commit is contained in:
2026-03-25 10:26:54 -06:00
parent c72beb916b
commit 56c11a0f1a
4 changed files with 706 additions and 0 deletions

View File

@@ -0,0 +1,340 @@
"""
Web Channel Adapter — WebSocket endpoint and message normalizer.
The web channel lets portal users chat with AI employees directly from
the Konstruct portal UI. Messages flow through the same agent pipeline
as Slack and WhatsApp — the only difference is the transport layer.
Message flow:
1. Browser opens WebSocket at /chat/ws/{conversation_id}
2. Client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
NOTE: Browsers cannot set custom HTTP headers on WebSocket connections,
so auth credentials are sent as the first JSON message (Pitfall 1).
3. For each user message (type="message"):
a. Server immediately sends {"type": "typing"} to client (CHAT-05)
b. normalize_web_event() converts to KonstructMessage (channel=WEB)
c. User message saved to web_conversation_messages
d. handle_message.delay(msg | extras) dispatches to Celery pipeline
e. Server subscribes to Redis pub-sub channel for the response
f. When orchestrator publishes the response:
- Save assistant message to web_conversation_messages
- Send {"type": "response", "text": ..., "conversation_id": ...} to client
4. On disconnect: unsubscribe and close all Redis connections
Design notes:
- thread_id = conversation_id — scopes agent memory to one conversation (Pitfall 3)
- Redis pub-sub connections closed in try/finally to prevent leaks (Pitfall 2)
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from orchestrator.tasks import handle_message
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.models.chat import WebConversation, WebConversationMessage
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
from shared.redis_keys import webchat_response_key
from shared.rls import configure_rls_hook, current_tenant_id
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Router — mounted in gateway/main.py
# ---------------------------------------------------------------------------
web_chat_router = APIRouter(tags=["web-chat"])
# Timeout for waiting for an agent response via Redis pub-sub (seconds)
_RESPONSE_TIMEOUT_SECONDS = 60
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
"""
Normalize a web channel event dict into a KonstructMessage.
The web channel normalizer sets thread_id = conversation_id so that
the agent memory pipeline scopes context to this conversation (Pitfall 3).
Args:
event: Dict with keys: text, tenant_id, agent_id, user_id,
display_name, conversation_id.
Returns:
KonstructMessage with channel=WEB, thread_id=conversation_id.
"""
tenant_id: str = event.get("tenant_id", "") or ""
user_id: str = event.get("user_id", "") or ""
display_name: str = event.get("display_name", "Portal User")
conversation_id: str = event.get("conversation_id", "") or ""
text_content: str = event.get("text", "") or ""
return KonstructMessage(
id=str(uuid.uuid4()),
tenant_id=tenant_id,
channel=ChannelType.WEB,
channel_metadata={
"portal_user_id": user_id,
"tenant_id": tenant_id,
"conversation_id": conversation_id,
},
sender=SenderInfo(
user_id=user_id,
display_name=display_name,
),
content=MessageContent(
text=text_content,
),
timestamp=datetime.now(timezone.utc),
thread_id=conversation_id,
reply_to=None,
context={},
)
async def _handle_websocket_connection(
websocket: WebSocket,
conversation_id: str,
) -> None:
"""
Core WebSocket connection handler — separated for testability.
Lifecycle:
1. Accept connection
2. Wait for auth message (browser cannot send custom headers)
3. Loop: receive messages → type indicator → Celery dispatch → Redis subscribe → response
Args:
websocket: The FastAPI WebSocket connection.
conversation_id: The conversation UUID from the URL path.
"""
await websocket.accept()
# -------------------------------------------------------------------------
# Step 1: Auth handshake
# Browsers cannot send custom HTTP headers on WebSocket connections.
# Auth credentials are sent as the first JSON message.
# -------------------------------------------------------------------------
try:
auth_msg = await websocket.receive_json()
except WebSocketDisconnect:
return
if auth_msg.get("type") != "auth":
await websocket.send_json({"type": "error", "message": "First message must be auth"})
await websocket.close(code=4001)
return
user_id_str: str = auth_msg.get("userId", "") or ""
user_role: str = auth_msg.get("role", "") or ""
tenant_id_str: str = auth_msg.get("tenantId", "") or ""
if not user_id_str or not tenant_id_str:
await websocket.send_json({"type": "error", "message": "Missing userId or tenantId in auth"})
await websocket.close(code=4001)
return
# Validate UUID format
try:
uuid.UUID(user_id_str)
tenant_uuid = uuid.UUID(tenant_id_str)
except (ValueError, AttributeError):
await websocket.send_json({"type": "error", "message": "Invalid UUID format in auth"})
await websocket.close(code=4001)
return
logger.info(
"WebSocket auth: user=%s role=%s tenant=%s conversation=%s",
user_id_str, user_role, tenant_id_str, conversation_id,
)
# -------------------------------------------------------------------------
# Step 2: Message loop
# -------------------------------------------------------------------------
while True:
try:
msg_data = await websocket.receive_json()
except (WebSocketDisconnect, Exception):
break
if msg_data.get("type") != "message":
continue
text_content: str = msg_data.get("text", "") or ""
agent_id_str: str = msg_data.get("agentId", "") or ""
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
display_name: str = msg_data.get("displayName", "Portal User")
# -------------------------------------------------------------------
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
# -------------------------------------------------------------------
await websocket.send_json({"type": "typing"})
# -------------------------------------------------------------------
# b. Save user message to web_conversation_messages
# -------------------------------------------------------------------
configure_rls_hook(engine)
rls_token = current_tenant_id.set(tenant_uuid)
saved_conversation_id = msg_conversation_id
try:
async with async_session_factory() as session:
# Look up the conversation to get tenant-scoped context
conv_stmt = select(WebConversation).where(
WebConversation.id == uuid.UUID(msg_conversation_id)
)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
if conversation is not None:
# Save user message
user_msg = WebConversationMessage(
conversation_id=uuid.UUID(msg_conversation_id),
tenant_id=tenant_uuid,
role="user",
content=text_content,
)
session.add(user_msg)
# Update conversation timestamp
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(msg_conversation_id)},
)
await session.commit()
saved_conversation_id = msg_conversation_id
except Exception:
logger.exception(
"Failed to save user message for conversation=%s", msg_conversation_id
)
finally:
current_tenant_id.reset(rls_token)
# -------------------------------------------------------------------
# c. Normalize and dispatch to Celery pipeline
# -------------------------------------------------------------------
event = {
"text": text_content,
"tenant_id": tenant_id_str,
"agent_id": agent_id_str,
"user_id": user_id_str,
"display_name": display_name,
"conversation_id": saved_conversation_id,
}
normalized_msg = normalize_web_event(event)
extras = {
"conversation_id": saved_conversation_id,
"portal_user_id": user_id_str,
}
task_payload = normalized_msg.model_dump(mode="json") | extras
handle_message.delay(task_payload)
# -------------------------------------------------------------------
# d. Subscribe to Redis pub-sub and wait for agent response
# -------------------------------------------------------------------
response_channel = webchat_response_key(tenant_id_str, saved_conversation_id)
subscribe_redis = aioredis.from_url(settings.redis_url)
try:
pubsub = subscribe_redis.pubsub()
await pubsub.subscribe(response_channel)
response_text: str = ""
deadline = asyncio.get_event_loop().time() + _RESPONSE_TIMEOUT_SECONDS
while asyncio.get_event_loop().time() < deadline:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message and message.get("type") == "message":
try:
payload = json.loads(message["data"])
response_text = payload.get("text", "")
except (json.JSONDecodeError, KeyError):
pass
break
await asyncio.sleep(0.05)
await pubsub.unsubscribe(response_channel)
finally:
await subscribe_redis.aclose()
# -------------------------------------------------------------------
# e. Save assistant message and send response to client
# -------------------------------------------------------------------
if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
assistant_msg = WebConversationMessage(
conversation_id=uuid.UUID(saved_conversation_id),
tenant_id=tenant_uuid,
role="assistant",
content=response_text,
)
session.add(assistant_msg)
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(saved_conversation_id)},
)
await session.commit()
except Exception:
logger.exception(
"Failed to save assistant message for conversation=%s", saved_conversation_id
)
finally:
current_tenant_id.reset(rls_token2)
await websocket.send_json({
"type": "response",
"text": response_text,
"conversation_id": saved_conversation_id,
})
else:
logger.warning(
"No response received within %ds for conversation=%s",
_RESPONSE_TIMEOUT_SECONDS,
saved_conversation_id,
)
await websocket.send_json({
"type": "error",
"message": "Agent did not respond in time. Please try again.",
})
@web_chat_router.websocket("/chat/ws/{conversation_id}")
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:
"""
WebSocket endpoint for web chat.
URL: /chat/ws/{conversation_id}
Protocol:
1. Connect
2. Send: {"type": "auth", "userId": "...", "role": "...", "tenantId": "..."}
3. Send: {"type": "message", "text": "...", "agentId": "...", "conversationId": "..."}
4. Receive: {"type": "typing"}
5. Receive: {"type": "response", "text": "...", "conversation_id": "..."}
Closes with code 4001 on auth failure.
"""
try:
await _handle_websocket_connection(websocket, conversation_id)
except WebSocketDisconnect:
logger.info("WebSocket disconnected for conversation=%s", conversation_id)
except Exception:
logger.exception("Unhandled error in WebSocket handler for conversation=%s", conversation_id)

View File

@@ -39,10 +39,12 @@ from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
from slack_bolt.async_app import AsyncApp from slack_bolt.async_app import AsyncApp
from gateway.channels.slack import register_slack_handlers from gateway.channels.slack import register_slack_handlers
from gateway.channels.web import web_chat_router
from gateway.channels.whatsapp import whatsapp_router from gateway.channels.whatsapp import whatsapp_router
from shared.api import ( from shared.api import (
billing_router, billing_router,
channels_router, channels_router,
chat_router,
invitations_router, invitations_router,
llm_keys_router, llm_keys_router,
portal_router, portal_router,
@@ -146,6 +148,12 @@ app.include_router(invitations_router)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
app.include_router(templates_router) app.include_router(templates_router)
# ---------------------------------------------------------------------------
# Phase 6 Web Chat routers
# ---------------------------------------------------------------------------
app.include_router(chat_router) # REST: /api/portal/chat/*
app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Routes # Routes

View File

@@ -6,6 +6,7 @@ Import and mount these routers in service main.py files.
from shared.api.billing import billing_router, webhook_router from shared.api.billing import billing_router, webhook_router
from shared.api.channels import channels_router from shared.api.channels import channels_router
from shared.api.chat import chat_router
from shared.api.invitations import invitations_router from shared.api.invitations import invitations_router
from shared.api.llm_keys import llm_keys_router from shared.api.llm_keys import llm_keys_router
from shared.api.portal import portal_router from shared.api.portal import portal_router
@@ -21,4 +22,5 @@ __all__ = [
"usage_router", "usage_router",
"invitations_router", "invitations_router",
"templates_router", "templates_router",
"chat_router",
] ]

View File

@@ -0,0 +1,356 @@
"""
FastAPI chat REST API — conversation CRUD with RBAC.
Provides conversation management for the Phase 6 web chat feature.
All endpoints require portal authentication via X-Portal-User-Id headers
and enforce tenant membership (or platform_admin bypass).
Endpoints:
GET /api/portal/chat/conversations — list conversations
POST /api/portal/chat/conversations — create or get-or-create
GET /api/portal/chat/conversations/{id}/messages — paginated history
DELETE /api/portal/chat/conversations/{id} — reset conversation
RBAC:
- platform_admin: can access any tenant's conversations
- customer_admin / customer_operator: must be a member of the target tenant
- Other roles: 403
RLS:
All DB queries set current_tenant_id context var before executing so
PostgreSQL's FORCE ROW LEVEL SECURITY policy is applied automatically.
"""
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import delete, select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from shared.api.rbac import PortalCaller, get_portal_caller, require_tenant_member
from shared.db import get_session, engine
from shared.models.chat import WebConversation, WebConversationMessage
from shared.models.tenant import Agent
from shared.rls import configure_rls_hook, current_tenant_id
chat_router = APIRouter(prefix="/api/portal/chat", tags=["chat"])
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class ConversationOut(BaseModel):
id: str
tenant_id: str
agent_id: str
agent_name: str | None = None
user_id: str
created_at: datetime
updated_at: datetime
last_message_preview: str | None = None
class ConversationCreate(BaseModel):
tenant_id: uuid.UUID
agent_id: uuid.UUID
class MessageOut(BaseModel):
id: str
role: str
content: str
created_at: datetime
class DeleteResult(BaseModel):
deleted: bool
conversation_id: str
# ---------------------------------------------------------------------------
# Helper: configure RLS and set context var
# ---------------------------------------------------------------------------
def _rls_set(engine_: Any, tenant_uuid: uuid.UUID) -> Any:
"""Configure RLS hook and set the tenant context variable."""
configure_rls_hook(engine_)
return current_tenant_id.set(tenant_uuid)
# ---------------------------------------------------------------------------
# GET /api/portal/chat/conversations
# ---------------------------------------------------------------------------
@chat_router.get("/conversations", response_model=list[ConversationOut])
async def list_conversations(
tenant_id: uuid.UUID = Query(...),
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> list[ConversationOut]:
"""
List conversations for the authenticated user within a tenant.
Platform admins can see all conversations for the tenant.
Other users see only their own conversations.
"""
# RBAC — raises 403 if caller is not a member (platform_admin bypasses)
await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
token = _rls_set(engine, tenant_id)
try:
stmt = (
select(WebConversation, Agent.name.label("agent_name"))
.join(Agent, WebConversation.agent_id == Agent.id, isouter=True)
.where(WebConversation.tenant_id == tenant_id)
)
# Non-admins only see their own conversations
if caller.role != "platform_admin":
stmt = stmt.where(WebConversation.user_id == caller.user_id)
stmt = stmt.order_by(WebConversation.updated_at.desc())
result = await session.execute(stmt)
rows = result.all()
conversations: list[ConversationOut] = []
for row in rows:
conv = row[0]
agent_name = row[1] if len(row) > 1 else None
conversations.append(
ConversationOut(
id=str(conv.id),
tenant_id=str(conv.tenant_id),
agent_id=str(conv.agent_id),
agent_name=agent_name,
user_id=str(conv.user_id),
created_at=conv.created_at,
updated_at=conv.updated_at,
)
)
return conversations
finally:
current_tenant_id.reset(token)
# ---------------------------------------------------------------------------
# POST /api/portal/chat/conversations
# ---------------------------------------------------------------------------
@chat_router.post("/conversations", response_model=ConversationOut, status_code=status.HTTP_200_OK)
async def create_conversation(
body: ConversationCreate,
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> ConversationOut:
"""
Create or get an existing conversation for the caller + agent pair.
Uses get-or-create semantics: if a conversation already exists for this
(tenant_id, agent_id, user_id) triple, it is returned rather than creating
a duplicate.
"""
# RBAC
await require_tenant_member(tenant_id=body.tenant_id, caller=caller, session=session)
token = _rls_set(engine, body.tenant_id)
try:
# Check for existing conversation
existing_stmt = select(WebConversation).where(
WebConversation.tenant_id == body.tenant_id,
WebConversation.agent_id == body.agent_id,
WebConversation.user_id == caller.user_id,
)
existing_result = await session.execute(existing_stmt)
existing = existing_result.scalar_one_or_none()
if existing is not None:
return ConversationOut(
id=str(existing.id),
tenant_id=str(existing.tenant_id),
agent_id=str(existing.agent_id),
user_id=str(existing.user_id),
created_at=existing.created_at,
updated_at=existing.updated_at,
)
# Create new conversation
new_conv = WebConversation(
id=uuid.uuid4(),
tenant_id=body.tenant_id,
agent_id=body.agent_id,
user_id=caller.user_id,
)
session.add(new_conv)
try:
await session.flush()
await session.commit()
await session.refresh(new_conv)
except IntegrityError:
# Race condition: another request created it between our SELECT and INSERT
await session.rollback()
existing_result2 = await session.execute(existing_stmt)
existing2 = existing_result2.scalar_one_or_none()
if existing2 is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create conversation",
)
return ConversationOut(
id=str(existing2.id),
tenant_id=str(existing2.tenant_id),
agent_id=str(existing2.agent_id),
user_id=str(existing2.user_id),
created_at=existing2.created_at,
updated_at=existing2.updated_at,
)
return ConversationOut(
id=str(new_conv.id),
tenant_id=str(new_conv.tenant_id),
agent_id=str(new_conv.agent_id),
user_id=str(new_conv.user_id),
created_at=new_conv.created_at,
updated_at=new_conv.updated_at,
)
finally:
current_tenant_id.reset(token)
# ---------------------------------------------------------------------------
# GET /api/portal/chat/conversations/{id}/messages
# ---------------------------------------------------------------------------
@chat_router.get("/conversations/{conversation_id}/messages", response_model=list[MessageOut])
async def list_messages(
conversation_id: uuid.UUID,
limit: int = Query(default=50, ge=1, le=200),
before: str | None = Query(default=None),
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> list[MessageOut]:
"""
Return paginated message history for a conversation.
Messages ordered by created_at ASC (oldest first).
Cursor pagination via `before` parameter (message ID).
Ownership enforced: caller must own the conversation OR be platform_admin.
"""
# Fetch conversation first to verify ownership and get tenant_id
conv_stmt = select(WebConversation).where(WebConversation.id == conversation_id)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
if conversation is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
# Ownership check: caller owns the conversation or is platform_admin
if caller.role != "platform_admin" and conversation.user_id != caller.user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have access to this conversation",
)
token = _rls_set(engine, conversation.tenant_id)
try:
msg_stmt = (
select(WebConversationMessage)
.where(WebConversationMessage.conversation_id == conversation_id)
.order_by(WebConversationMessage.created_at.asc())
.limit(limit)
)
if before:
try:
before_uuid = uuid.UUID(before)
# Get the cursor message's created_at
cursor_stmt = select(WebConversationMessage.created_at).where(
WebConversationMessage.id == before_uuid
)
cursor_result = await session.execute(cursor_stmt)
cursor_ts = cursor_result.scalar_one_or_none()
if cursor_ts is not None:
msg_stmt = msg_stmt.where(WebConversationMessage.created_at < cursor_ts)
except (ValueError, AttributeError):
pass # Invalid cursor — ignore and return from start
msg_result = await session.execute(msg_stmt)
messages = msg_result.scalars().all()
return [
MessageOut(
id=str(m.id),
role=m.role,
content=m.content,
created_at=m.created_at,
)
for m in messages
]
finally:
current_tenant_id.reset(token)
# ---------------------------------------------------------------------------
# DELETE /api/portal/chat/conversations/{id}
# ---------------------------------------------------------------------------
@chat_router.delete("/conversations/{conversation_id}", response_model=DeleteResult)
async def reset_conversation(
conversation_id: uuid.UUID,
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> DeleteResult:
"""
Reset a conversation by deleting all messages.
The conversation row is kept but all messages are deleted.
Updates updated_at on the conversation.
Ownership enforced: caller must own the conversation OR be platform_admin.
"""
# Fetch conversation
conv_stmt = select(WebConversation).where(WebConversation.id == conversation_id)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
if conversation is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
# Ownership check
if caller.role != "platform_admin" and conversation.user_id != caller.user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have access to this conversation",
)
token = _rls_set(engine, conversation.tenant_id)
try:
# Delete all messages for this conversation
delete_stmt = delete(WebConversationMessage).where(
WebConversationMessage.conversation_id == conversation_id
)
await session.execute(delete_stmt)
# Update conversation timestamp
await session.execute(
text("UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"),
{"conv_id": str(conversation_id)},
)
await session.commit()
return DeleteResult(deleted=True, conversation_id=str(conversation_id))
finally:
current_tenant_id.reset(token)