- Pub-sub loop now handles 'chunk' and 'done' message types (not just 'response') - 'chunk' messages are forwarded immediately via websocket.send_json - 'done' message breaks the loop and triggers DB persistence of full response - Sends final 'done' JSON to browser to signal stream completion - Legacy 'response' type no longer emitted from orchestrator (now unified as 'done')
370 lines
15 KiB
Python
370 lines
15 KiB
Python
"""
|
|
Web Channel Adapter — WebSocket endpoint and message normalizer.
|
|
|
|
The web channel lets portal users chat with AI employees directly from
|
|
the Konstruct portal UI. Messages flow through the same agent pipeline
|
|
as Slack and WhatsApp — the only difference is the transport layer.
|
|
|
|
Message flow:
|
|
1. Browser opens WebSocket at /chat/ws/{conversation_id}
|
|
2. Client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
|
|
NOTE: Browsers cannot set custom HTTP headers on WebSocket connections,
|
|
so auth credentials are sent as the first JSON message (Pitfall 1).
|
|
3. For each user message (type="message"):
|
|
a. Server immediately sends {"type": "typing"} to client (CHAT-05)
|
|
b. normalize_web_event() converts to KonstructMessage (channel=WEB)
|
|
c. User message saved to web_conversation_messages
|
|
d. handle_message.delay(msg | extras) dispatches to Celery pipeline
|
|
e. Server subscribes to Redis pub-sub channel for the response
|
|
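   f. As the orchestrator streams the reply on that channel:
      - Each {"type": "chunk", "text": ...} is forwarded to the client immediately
      - On {"type": "done", "text": <full reply>}, the assistant message is saved
        to web_conversation_messages and {"type": "done", "text": ...,
        "conversation_id": ...} is sent to the client
      - If no "done" arrives before the timeout, {"type": "error", ...} is sent instead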
|
|
4. On disconnect: unsubscribe and close all Redis connections
|
|
|
|
Design notes:
|
|
- thread_id = conversation_id — scopes agent memory to one conversation (Pitfall 3)
|
|
- Redis pub-sub connections closed in try/finally to prevent leaks (Pitfall 2)
|
|
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
|
|
- WebSocket is a long-lived connection; each message/response cycle is synchronous
|
|
within the connection but non-blocking for other connections
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import redis.asyncio as aioredis
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from sqlalchemy import select, text
|
|
|
|
from orchestrator.tasks import handle_message
|
|
from shared.config import settings
|
|
from shared.db import async_session_factory, engine
|
|
from shared.models.chat import WebConversation, WebConversationMessage
|
|
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
|
|
from shared.redis_keys import webchat_response_key
|
|
from shared.rls import configure_rls_hook, current_tenant_id
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# WebSocket router for the portal chat channel (mounted in gateway/main.py)
# ---------------------------------------------------------------------------
web_chat_router: APIRouter = APIRouter(tags=["web-chat"])

# Upper bound, in seconds, on how long one message/response cycle waits for
# the orchestrator's reply to arrive over Redis pub-sub (three minutes).
_RESPONSE_TIMEOUT_SECONDS: int = 3 * 60
|
|
|
|
|
|
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
    """
    Normalize a web channel event dict into a KonstructMessage.

    The web channel normalizer sets thread_id = conversation_id so that
    the agent memory pipeline scopes context to this conversation (Pitfall 3).

    Missing, None or empty values fall back to "" for identifiers and text,
    and to "Portal User" for the display name. A client that sends an
    explicit null/empty displayName therefore still yields a usable sender.

    Args:
        event: Dict with keys: text, tenant_id, user_id, display_name,
               conversation_id. An ``agent_id`` key may be present (the
               WebSocket handler includes one) but is not read here.

    Returns:
        KonstructMessage with channel=WEB, thread_id=conversation_id and a
        freshly generated UUID4 id.
    """
    # `or` (rather than a .get() default) also catches explicit None / "".
    tenant_id: str = event.get("tenant_id") or ""
    user_id: str = event.get("user_id") or ""
    display_name: str = event.get("display_name") or "Portal User"
    conversation_id: str = event.get("conversation_id") or ""
    text_content: str = event.get("text") or ""

    return KonstructMessage(
        id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        channel=ChannelType.WEB,
        channel_metadata={
            "portal_user_id": user_id,
            "tenant_id": tenant_id,
            "conversation_id": conversation_id,
        },
        sender=SenderInfo(
            user_id=user_id,
            display_name=display_name,
        ),
        content=MessageContent(
            text=text_content,
        ),
        timestamp=datetime.now(timezone.utc),
        # Scope agent memory to this conversation (Pitfall 3).
        thread_id=conversation_id,
        reply_to=None,
        context={},
    )
|
|
|
|
|
|
async def _reject_auth(websocket: WebSocket, message: str) -> None:
    """Send an auth error frame to the client and close with code 4001."""
    await websocket.send_json({"type": "error", "message": message})
    await websocket.close(code=4001)


async def _save_message(
    tenant_uuid: uuid.UUID,
    conversation_id: str,
    role: str,
    content: str,
) -> None:
    """
    Persist one chat message and bump the conversation's updated_at.

    Best-effort: any failure is logged and swallowed so a DB problem does not
    tear down the WebSocket. The insert is skipped (with a warning) when the
    conversation row is not found under the tenant's RLS scope.

    Args:
        tenant_uuid: Tenant used both for the RLS context var and the row.
        conversation_id: Conversation UUID as a string (already validated).
        role: "user" or "assistant".
        content: Message text to store.
    """
    rls_token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            conversation_uuid = uuid.UUID(conversation_id)
            conv_result = await session.execute(
                select(WebConversation).where(WebConversation.id == conversation_uuid)
            )
            if conv_result.scalar_one_or_none() is None:
                logger.warning(
                    "Conversation not found; %s message not saved for conversation=%s",
                    role,
                    conversation_id,
                )
                return

            session.add(
                WebConversationMessage(
                    conversation_id=conversation_uuid,
                    tenant_id=tenant_uuid,
                    role=role,
                    content=content,
                )
            )
            await session.execute(
                text("UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"),
                {"conv_id": conversation_id},
            )
            await session.commit()
    except Exception:
        logger.exception("Failed to save %s message for conversation=%s", role, conversation_id)
    finally:
        current_tenant_id.reset(rls_token)


async def _relay_response_stream(
    websocket: WebSocket,
    pubsub: Any,
) -> tuple[bool, str, bool]:
    """
    Forward streamed agent output from an already-subscribed pub-sub to the client.

    The orchestrator publishes to the response channel:
        {"type": "chunk", "text": "<token>"}                 zero or more times
        {"type": "done", "text": "<full>", ...}              final marker

    Each chunk is forwarded immediately so text appears incrementally. If the
    client goes away mid-stream, forwarding stops but the subscription keeps
    draining until "done" (or the timeout). The caller can then still persist
    the completed reply.

    Returns:
        (done_received, full_text, client_connected)
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _RESPONSE_TIMEOUT_SECONDS
    client_connected = True

    while loop.time() < deadline:
        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if not message or message.get("type") != "message":
            # Small back-off in case get_message returns without waiting.
            await asyncio.sleep(0.05)
            continue

        try:
            payload = json.loads(message["data"])
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
        if not isinstance(payload, dict):
            continue

        msg_type = payload.get("type")
        if msg_type == "chunk":
            token = payload.get("text", "")
            if token and client_connected:
                try:
                    await websocket.send_json({"type": "chunk", "text": token})
                except Exception:
                    # Client disconnected mid-stream: stop forwarding, keep
                    # draining so the final reply can still be saved.
                    client_connected = False
        elif msg_type == "done":
            return True, payload.get("text", "") or "", client_connected

    return False, "", client_connected


async def _handle_websocket_connection(
    websocket: WebSocket,
    conversation_id: str,
) -> None:
    """
    Core WebSocket connection handler — separated for testability.

    Lifecycle:
        1. Accept connection
        2. Wait for auth message (browser cannot send custom headers)
        3. Loop per user message: typing indicator -> save user message ->
           subscribe to the Redis response channel -> Celery dispatch ->
           relay chunks -> save assistant reply -> "done" (or "error" on timeout)

    The subscription is established BEFORE the Celery task is dispatched.
    Redis pub-sub does not buffer messages, so subscribing afterwards could
    miss early chunks or even the final "done" marker.

    Args:
        websocket: The FastAPI WebSocket connection.
        conversation_id: The conversation UUID from the URL path; used when a
            message does not carry its own conversationId.
    """
    await websocket.accept()

    # -------------------------------------------------------------------------
    # Step 1: Auth handshake
    # Browsers cannot send custom HTTP headers on WebSocket connections, so
    # auth credentials are sent as the first JSON message.
    # NOTE(review): userId/tenantId are taken from the client message as-is;
    # nothing here verifies them against a session or token — confirm that an
    # upstream layer does, since RLS scoping below trusts tenantId.
    # -------------------------------------------------------------------------
    try:
        auth_msg = await websocket.receive_json()
    except WebSocketDisconnect:
        return

    if not isinstance(auth_msg, dict) or auth_msg.get("type") != "auth":
        await _reject_auth(websocket, "First message must be auth")
        return

    user_id_str: str = auth_msg.get("userId", "") or ""
    user_role: str = auth_msg.get("role", "") or ""
    tenant_id_str: str = auth_msg.get("tenantId", "") or ""

    if not user_id_str or not tenant_id_str:
        await _reject_auth(websocket, "Missing userId or tenantId in auth")
        return

    try:
        uuid.UUID(user_id_str)
        tenant_uuid = uuid.UUID(tenant_id_str)
    except (ValueError, AttributeError, TypeError):
        await _reject_auth(websocket, "Invalid UUID format in auth")
        return

    logger.info(
        "WebSocket auth: user=%s role=%s tenant=%s conversation=%s",
        user_id_str, user_role, tenant_id_str, conversation_id,
    )

    # -------------------------------------------------------------------------
    # Step 2: Message loop
    # -------------------------------------------------------------------------
    while True:
        try:
            msg_data = await websocket.receive_json()
        except Exception:
            # Disconnect, malformed JSON or an unexpected frame ends the session.
            break

        if not isinstance(msg_data, dict) or msg_data.get("type") != "message":
            continue

        text_content: str = msg_data.get("text", "") or ""
        agent_id_str: str = msg_data.get("agentId", "") or ""
        msg_conversation_id: str = str(msg_data.get("conversationId") or conversation_id)
        display_name: str = msg_data.get("displayName") or "Portal User"

        # Reject non-UUID conversation ids up front instead of dispatching a
        # task whose reply could never be persisted.
        try:
            uuid.UUID(msg_conversation_id)
        except ValueError:
            await websocket.send_json({"type": "error", "message": "Invalid conversation id"})
            continue

        # a. Typing indicator IMMEDIATELY — before any DB or Celery work (CHAT-05)
        await websocket.send_json({"type": "typing"})

        # b. Save user message (best-effort)
        configure_rls_hook(engine)
        await _save_message(tenant_uuid, msg_conversation_id, "user", text_content)

        # c. Normalize for the Celery pipeline.
        # NOTE(review): agent_id is put in the event, but normalize_web_event()
        # does not copy it into the message and it is not in extras — confirm
        # the orchestrator resolves the agent some other way.
        normalized_msg = normalize_web_event({
            "text": text_content,
            "tenant_id": tenant_id_str,
            "agent_id": agent_id_str,
            "user_id": user_id_str,
            "display_name": display_name,
            "conversation_id": msg_conversation_id,
        })
        extras = {
            "conversation_id": msg_conversation_id,
            "portal_user_id": user_id_str,
        }
        task_payload = normalized_msg.model_dump(mode="json") | extras

        # d. Subscribe FIRST, then dispatch, then relay the stream.
        response_channel = webchat_response_key(tenant_id_str, msg_conversation_id)
        redis_client = aioredis.from_url(settings.redis_url)
        pubsub = redis_client.pubsub()
        try:
            await pubsub.subscribe(response_channel)
            handle_message.delay(task_payload)
            done_received, response_text, client_connected = await _relay_response_stream(
                websocket, pubsub
            )
        finally:
            # Closing the pubsub releases its dedicated connection (and with
            # it the subscription); the client close then tears down the pool.
            try:
                await pubsub.aclose()
            finally:
                await redis_client.aclose()

        # e. Persist the assistant reply (even if the client already left).
        if response_text:
            await _save_message(tenant_uuid, msg_conversation_id, "assistant", response_text)

        if not client_connected:
            logger.info("Client disconnected mid-stream for conversation=%s", msg_conversation_id)
            return

        if done_received:
            # Signal stream completion to the client.
            await websocket.send_json({
                "type": "done",
                "text": response_text,
                "conversation_id": msg_conversation_id,
            })
        else:
            logger.warning(
                "No response received within %ds for conversation=%s",
                _RESPONSE_TIMEOUT_SECONDS,
                msg_conversation_id,
            )
            await websocket.send_json({
                "type": "error",
                "message": "Agent did not respond in time. Please try again.",
            })
|
|
|
|
|
|
@web_chat_router.websocket("/chat/ws/{conversation_id}")
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:
    """
    WebSocket endpoint for web chat.

    URL: /chat/ws/{conversation_id}

    Protocol:
        1. Connect
        2. Send: {"type": "auth", "userId": "...", "role": "...", "tenantId": "..."}
        3. Send: {"type": "message", "text": "...", "agentId": "...", "conversationId": "..."}
        4. Receive: {"type": "typing"}
        5. Receive: zero or more {"type": "chunk", "text": "<token>"} as the reply streams
        6. Receive: {"type": "done", "text": "<full reply>", "conversation_id": "..."}
           or {"type": "error", "message": "..."} if the agent does not respond in time

    Closes with code 4001 on auth failure. Any exception escaping the handler
    is logged here rather than propagated to the ASGI server.
    """
    try:
        await _handle_websocket_connection(websocket, conversation_id)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for conversation=%s", conversation_id)
    except Exception:
        logger.exception("Unhandled error in WebSocket handler for conversation=%s", conversation_id)
|