Files
konstruct/packages/gateway/gateway/channels/web.py
Adolfo Delorenzo 61b8762bac feat(streaming): update WebSocket handler to forward streaming chunks to browser
- Pub-sub loop now handles 'chunk' and 'done' message types (not just 'response')
- 'chunk' messages are forwarded immediately via websocket.send_json
- 'done' message breaks the loop and triggers DB persistence of full response
- Sends final 'done' JSON to browser to signal stream completion
- Legacy 'response' type no longer emitted from orchestrator (now unified as 'done')
2026-03-25 17:57:08 -06:00

370 lines
15 KiB
Python

"""
Web Channel Adapter — WebSocket endpoint and message normalizer.
The web channel lets portal users chat with AI employees directly from
the Konstruct portal UI. Messages flow through the same agent pipeline
as Slack and WhatsApp — the only difference is the transport layer.
Message flow:
1. Browser opens WebSocket at /chat/ws/{conversation_id}
2. Client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
NOTE: Browsers cannot set custom HTTP headers on WebSocket connections,
so auth credentials are sent as the first JSON message (Pitfall 1).
3. For each user message (type="message"):
a. Server immediately sends {"type": "typing"} to client (CHAT-05)
b. normalize_web_event() converts to KonstructMessage (channel=WEB)
c. User message saved to web_conversation_messages
d. handle_message.delay(msg | extras) dispatches to Celery pipeline
e. Server subscribes to Redis pub-sub channel for the response
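f. As the orchestrator publishes to that channel:
- Each {"type": "chunk", "text": ...} is forwarded to the client immediately
- On the final {"type": "done", "text": ...}: save the full assistant message to
web_conversation_messages and send {"type": "done", "text": ..., "conversation_id": ...}
to the client ({"type": "error", ...} is sent instead if nothing arrives in time)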
4. On disconnect: unsubscribe and close all Redis connections
Design notes:
- thread_id = conversation_id — scopes agent memory to one conversation (Pitfall 3)
- Redis pub-sub connections closed in try/finally to prevent leaks (Pitfall 2)
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from orchestrator.tasks import handle_message
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.models.chat import WebConversation, WebConversationMessage
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
from shared.redis_keys import webchat_response_key
from shared.rls import configure_rls_hook, current_tenant_id
logger: logging.Logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Router — mounted in gateway/main.py
# ---------------------------------------------------------------------------

web_chat_router: APIRouter = APIRouter(tags=["web-chat"])

# Maximum time (seconds) to wait for the agent's reply on the Redis pub-sub
# response channel before telling the client it timed out (3 minutes).
_RESPONSE_TIMEOUT_SECONDS: int = 3 * 60
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
    """
    Normalize a web channel event dict into a KonstructMessage.

    The web channel normalizer sets thread_id = conversation_id so that
    the agent memory pipeline scopes context to this conversation (Pitfall 3).

    Missing *or null/empty* fields are coerced to defaults: empty strings for
    identifiers and text, and "Portal User" for the display name.

    Args:
        event: Dict with keys: text, tenant_id, agent_id, user_id,
            display_name, conversation_id. ``agent_id`` is accepted but is
            not copied onto the returned message by this function.

    Returns:
        KonstructMessage with channel=WEB, thread_id=conversation_id, a fresh
        random id and a UTC timestamp.
    """
    tenant_id: str = event.get("tenant_id") or ""
    user_id: str = event.get("user_id") or ""
    # Use ``or`` rather than a .get() default so an explicit None/"" sent by
    # the browser also falls back — consistent with the fields above/below.
    display_name: str = event.get("display_name") or "Portal User"
    conversation_id: str = event.get("conversation_id") or ""
    text_content: str = event.get("text") or ""

    return KonstructMessage(
        id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        channel=ChannelType.WEB,
        channel_metadata={
            "portal_user_id": user_id,
            "tenant_id": tenant_id,
            "conversation_id": conversation_id,
        },
        sender=SenderInfo(
            user_id=user_id,
            display_name=display_name,
        ),
        content=MessageContent(
            text=text_content,
        ),
        timestamp=datetime.now(timezone.utc),
        thread_id=conversation_id,
        reply_to=None,
        context={},
    )
async def _reject_auth(websocket: WebSocket, message: str) -> None:
    """Send an ``error`` frame carrying *message*, then close with code 4001."""
    await websocket.send_json({"type": "error", "message": message})
    await websocket.close(code=4001)


async def _authenticate(websocket: WebSocket) -> tuple[str, str, uuid.UUID] | None:
    """
    Run the first-frame auth handshake on an already-accepted WebSocket.

    Browsers cannot send custom HTTP headers on WebSocket connections, so the
    client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
    as its first JSON frame (Pitfall 1).

    NOTE(review): userId/tenantId come straight from the client frame and are
    only checked for UUID format here — nothing in this module verifies them
    against a session or token. Confirm this is enforced upstream; otherwise a
    client can claim any tenant.

    Returns:
        (user_id_str, tenant_id_str, tenant_uuid) on success. None if the
        client disconnected before authenticating, or if it was rejected —
        in the latter case an error frame was sent and the socket closed
        with code 4001.
    """
    try:
        auth_msg: Any = await websocket.receive_json()
    except WebSocketDisconnect:
        return None
    except ValueError:
        # Non-JSON first frame: treat it like any other non-auth first message
        # instead of letting it escape as an unhandled error.
        auth_msg = None

    if not isinstance(auth_msg, dict) or auth_msg.get("type") != "auth":
        await _reject_auth(websocket, "First message must be auth")
        return None

    user_id_str: str = auth_msg.get("userId", "") or ""
    user_role: str = auth_msg.get("role", "") or ""
    tenant_id_str: str = auth_msg.get("tenantId", "") or ""

    if not user_id_str or not tenant_id_str:
        await _reject_auth(websocket, "Missing userId or tenantId in auth")
        return None

    # Validate UUID format (non-string JSON values raise AttributeError/TypeError).
    try:
        uuid.UUID(user_id_str)
        tenant_uuid = uuid.UUID(tenant_id_str)
    except (ValueError, AttributeError, TypeError):
        await _reject_auth(websocket, "Invalid UUID format in auth")
        return None

    logger.info(
        "WebSocket auth: user=%s role=%s tenant=%s conversation=%s",
        user_id_str, user_role, tenant_id_str, "<per-message>",
    ) if False else None  # placeholder never executed; real log below
    return user_id_str, tenant_id_str, tenant_uuid


async def _persist_web_message(
    tenant_uuid: uuid.UUID,
    conversation_id: str,
    role: str,
    content: str,
    *,
    require_existing_conversation: bool,
) -> None:
    """
    Best-effort insert of one row into web_conversation_messages.

    Sets the ``current_tenant_id`` context var for the duration of the write
    (the project's RLS pattern), adds the message, bumps
    ``web_conversations.updated_at`` and commits. Any failure — including an
    invalid conversation id — is logged and swallowed, so a DB problem never
    tears down the live chat session.

    Args:
        tenant_uuid: Tenant the message belongs to.
        conversation_id: Conversation id as received from the client/URL.
        role: "user" or "assistant".
        content: Message text.
        require_existing_conversation: When True, the conversation row is
            looked up first and nothing is written if it does not exist.
    """
    rls_token = current_tenant_id.set(tenant_uuid)
    try:
        conv_uuid = uuid.UUID(conversation_id)
        async with async_session_factory() as session:
            if require_existing_conversation:
                conv_result = await session.execute(
                    select(WebConversation).where(WebConversation.id == conv_uuid)
                )
                if conv_result.scalar_one_or_none() is None:
                    return
            session.add(
                WebConversationMessage(
                    conversation_id=conv_uuid,
                    tenant_id=tenant_uuid,
                    role=role,
                    content=content,
                )
            )
            await session.execute(
                text("UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"),
                {"conv_id": str(conv_uuid)},
            )
            await session.commit()
    except Exception:
        logger.exception(
            "Failed to save %s message for conversation=%s", role, conversation_id
        )
    finally:
        current_tenant_id.reset(rls_token)


async def _relay_agent_stream(
    websocket: WebSocket,
    pubsub: aioredis.client.PubSub,
) -> tuple[str | None, bool]:
    """
    Forward streamed agent output from an already-subscribed pub-sub to the client.

    The orchestrator publishes two message types to the response channel:
        {"type": "chunk", "text": "<token>"}                        — zero or more
        {"type": "done", "text": "<full>", "conversation_id": "..."} — final marker

    Each non-empty chunk is sent to the browser immediately. If a send fails
    (client gone), forwarding stops but the channel keeps being drained so the
    final "done" text can still be persisted by the caller.

    Returns:
        (final_text, client_connected). ``final_text`` is the "done" text
        ("" if it carried none), or None if no "done" arrived within
        _RESPONSE_TIMEOUT_SECONDS.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _RESPONSE_TIMEOUT_SECONDS
    client_connected = True

    while loop.time() < deadline:
        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if not message or message.get("type") != "message":
            await asyncio.sleep(0.05)
            continue

        try:
            payload = json.loads(message["data"])
        except (ValueError, KeyError, TypeError):
            # Undecodable publish — skip it, keep waiting for valid frames.
            continue
        if not isinstance(payload, dict):
            continue

        msg_type = payload.get("type")
        if msg_type == "chunk":
            token = payload.get("text", "")
            if token and client_connected:
                try:
                    await websocket.send_json({"type": "chunk", "text": token})
                except Exception:
                    # Client disconnected mid-stream: stop forwarding, but keep
                    # draining so the full reply can still be saved.
                    client_connected = False
                    logger.info("Client disconnected mid-stream; draining until done")
        elif msg_type == "done":
            return payload.get("text", "") or "", client_connected

    return None, client_connected


async def _handle_websocket_connection(
    websocket: WebSocket,
    conversation_id: str,
) -> None:
    """
    Core WebSocket connection handler — separated for testability.

    Lifecycle:
        1. Accept connection.
        2. Auth handshake via the first JSON frame (see ``_authenticate``).
        3. For each {"type": "message"} frame: typing indicator → save user
           message → subscribe to the Redis response channel → Celery
           dispatch → forward streamed chunks → save assistant reply →
           final "done" (or "error" on timeout).

    Subscribing *before* dispatching matters: Redis pub-sub does not buffer,
    so anything the orchestrator publishes before the subscription exists is
    lost.

    Args:
        websocket: The FastAPI WebSocket connection.
        conversation_id: The conversation UUID from the URL path; used when a
            message frame carries no ``conversationId`` of its own.
    """
    await websocket.accept()

    auth = await _authenticate(websocket)
    if auth is None:
        return
    user_id_str, tenant_id_str, tenant_uuid = auth
    logger.info(
        "WebSocket auth: user=%s tenant=%s conversation=%s",
        user_id_str, tenant_id_str, conversation_id,
    )

    while True:
        try:
            msg_data: Any = await websocket.receive_json()
        except WebSocketDisconnect:
            break
        except ValueError:
            # One malformed JSON frame should not end the whole session.
            logger.warning(
                "Ignoring non-JSON WebSocket frame for conversation=%s", conversation_id
            )
            continue
        except Exception:
            # e.g. RuntimeError when receiving on an already-closed socket.
            logger.debug(
                "WebSocket receive failed for conversation=%s",
                conversation_id,
                exc_info=True,
            )
            break

        if not isinstance(msg_data, dict) or msg_data.get("type") != "message":
            continue

        text_content: str = msg_data.get("text", "") or ""
        agent_id_str: str = msg_data.get("agentId", "") or ""
        msg_conversation_id: str = str(msg_data.get("conversationId") or conversation_id)
        display_name: str = msg_data.get("displayName") or "Portal User"

        # a. Typing indicator IMMEDIATELY — before any DB, Celery or Redis work.
        await websocket.send_json({"type": "typing"})

        # b. Persist the user's message (best effort; only if the conversation exists).
        configure_rls_hook(engine)
        await _persist_web_message(
            tenant_uuid,
            msg_conversation_id,
            "user",
            text_content,
            require_existing_conversation=True,
        )

        # c. Normalize into the Celery task payload.
        normalized_msg = normalize_web_event(
            {
                "text": text_content,
                "tenant_id": tenant_id_str,
                "agent_id": agent_id_str,
                "user_id": user_id_str,
                "display_name": display_name,
                "conversation_id": msg_conversation_id,
            }
        )
        task_payload = normalized_msg.model_dump(mode="json") | {
            "conversation_id": msg_conversation_id,
            "portal_user_id": user_id_str,
        }

        # d. Subscribe FIRST, then dispatch, then relay the stream. Both the
        #    pub-sub connection and the client are closed in finally (Pitfall 2).
        response_channel = webchat_response_key(tenant_id_str, msg_conversation_id)
        redis_client = aioredis.from_url(settings.redis_url)
        try:
            pubsub = redis_client.pubsub()
            try:
                await pubsub.subscribe(response_channel)
                handle_message.delay(task_payload)
                final_text, client_connected = await _relay_agent_stream(websocket, pubsub)
            finally:
                await pubsub.aclose()
        finally:
            await redis_client.aclose()

        # e. Timeout: no "done" marker arrived.
        if final_text is None:
            logger.warning(
                "No response received within %ds for conversation=%s",
                _RESPONSE_TIMEOUT_SECONDS,
                msg_conversation_id,
            )
            if not client_connected:
                return
            await websocket.send_json({
                "type": "error",
                "message": "Agent did not respond in time. Please try again.",
            })
            continue

        # f. Persist the full reply (even if the client already left), then
        #    signal completion to the client if it is still there.
        if final_text:
            await _persist_web_message(
                tenant_uuid,
                msg_conversation_id,
                "assistant",
                final_text,
                require_existing_conversation=False,
            )
        if not client_connected:
            return
        await websocket.send_json({
            "type": "done",
            "text": final_text,
            "conversation_id": msg_conversation_id,
        })
@web_chat_router.websocket("/chat/ws/{conversation_id}")
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:
    """
    WebSocket endpoint for web chat.

    URL: /chat/ws/{conversation_id}

    Protocol:
        1. Connect.
        2. Send:    {"type": "auth", "userId": "...", "role": "...", "tenantId": "..."}
        3. Send:    {"type": "message", "text": "...", "agentId": "...",
                     "conversationId": "...", "displayName": "..."}
        4. Receive: {"type": "typing"}
        5. Receive: zero or more {"type": "chunk", "text": "..."} frames while
                    the agent streams its reply.
        6. Receive: {"type": "done", "text": "<full reply>", "conversation_id": "..."}
                    or {"type": "error", "message": "..."} if no reply arrives in time.
        Steps 3-6 repeat for each user message on the same connection.

    Closes with code 4001 on auth failure. A WebSocketDisconnect escaping the
    handler is logged at INFO; any other exception is logged with a traceback
    and not re-raised.
    """
    try:
        await _handle_websocket_connection(websocket, conversation_id)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for conversation=%s", conversation_id)
    except Exception:
        logger.exception("Unhandled error in WebSocket handler for conversation=%s", conversation_id)