Files
konstruct/packages/gateway/gateway/channels/web.py
Adolfo Delorenzo dd80e2b822 perf: bypass Celery for web chat — stream LLM directly from WebSocket
Eliminates 5-10s of overhead by calling the LLM pool's streaming
endpoint directly from the WebSocket handler instead of going through
Celery queue → worker → asyncio.run() → Redis pub-sub → WebSocket.

New flow: WebSocket → agent lookup → memory → LLM stream → WebSocket
Old flow: WebSocket → Celery → worker → DB → memory → LLM → Redis → WebSocket

Memory still saved (Redis sliding window + fire-and-forget embedding).
Slack/WhatsApp still use Celery (async webhook pattern).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 18:32:16 -06:00

404 lines
16 KiB
Python

"""
Web Channel Adapter — WebSocket endpoint and message normalizer.
The web channel lets portal users chat with AI employees directly from
the Konstruct portal UI. Messages flow through the same agent pipeline
as Slack and WhatsApp — the only difference is the transport layer.
Message flow:
1. Browser opens WebSocket at /chat/ws/{conversation_id}
2. Client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
NOTE: Browsers cannot set custom HTTP headers on WebSocket connections,
so auth credentials are sent as the first JSON message (Pitfall 1).
3. For each user message (type="message"):
a. Server immediately sends {"type": "typing"} to client (CHAT-05)
b. normalize_web_event() converts to KonstructMessage (channel=WEB)
c. User message saved to web_conversation_messages
d. handle_message.delay(msg | extras) dispatches to Celery pipeline
e. Server subscribes to Redis pub-sub channel for the response
f. When orchestrator publishes the response:
- Save assistant message to web_conversation_messages
- Send {"type": "response", "text": ..., "conversation_id": ...} to client
4. On disconnect: unsubscribe and close all Redis connections
Design notes:
- thread_id = conversation_id — scopes agent memory to one conversation (Pitfall 3)
- Redis pub-sub connections closed in try/finally to prevent leaks (Pitfall 2)
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt
from orchestrator.agents.runner import run_agent_streaming
from orchestrator.memory.short_term import get_recent_messages, append_message
from orchestrator.tasks import handle_message, embed_and_store
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.models.chat import WebConversation, WebConversationMessage
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
from shared.models.tenant import Agent
from shared.redis_keys import webchat_response_key
from shared.rls import configure_rls_hook, current_tenant_id
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Router — mounted in gateway/main.py
# ---------------------------------------------------------------------------
web_chat_router = APIRouter(tags=["web-chat"])
# Legacy timeout (seconds) from the old Redis pub-sub response flow. The
# direct-streaming path no longer enforces a wait; this constant is now only
# referenced when logging an empty agent response.
_RESPONSE_TIMEOUT_SECONDS = 180
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
    """
    Convert a raw web-channel event dict into a KonstructMessage.

    thread_id is set to the conversation_id so that the agent memory
    pipeline scopes context to this single conversation (Pitfall 3).

    Args:
        event: Dict with keys: text, tenant_id, agent_id, user_id,
            display_name, conversation_id.

    Returns:
        KonstructMessage with channel=WEB, thread_id=conversation_id.
    """
    tenant = event.get("tenant_id", "") or ""
    portal_user = event.get("user_id", "") or ""
    conversation = event.get("conversation_id", "") or ""

    sender = SenderInfo(
        user_id=portal_user,
        display_name=event.get("display_name", "Portal User"),
    )
    content = MessageContent(text=event.get("text", "") or "")

    return KonstructMessage(
        id=str(uuid.uuid4()),
        tenant_id=tenant,
        channel=ChannelType.WEB,
        channel_metadata={
            "portal_user_id": portal_user,
            "tenant_id": tenant,
            "conversation_id": conversation,
        },
        sender=sender,
        content=content,
        timestamp=datetime.now(timezone.utc),
        thread_id=conversation,
        reply_to=None,
        context={},
    )
async def _ws_auth_handshake(
    websocket: WebSocket,
) -> tuple[str, str, str, uuid.UUID] | None:
    """
    Perform the first-message auth handshake.

    Browsers cannot set custom HTTP headers on WebSocket connections, so
    credentials arrive as the first JSON message (Pitfall 1).

    Returns:
        (user_id_str, user_role, tenant_id_str, tenant_uuid) on success;
        None after closing the socket with code 4001 (or on immediate
        disconnect).
    """
    try:
        auth_msg = await websocket.receive_json()
    except WebSocketDisconnect:
        return None
    if auth_msg.get("type") != "auth":
        await websocket.send_json({"type": "error", "message": "First message must be auth"})
        await websocket.close(code=4001)
        return None
    user_id_str: str = auth_msg.get("userId", "") or ""
    user_role: str = auth_msg.get("role", "") or ""
    tenant_id_str: str = auth_msg.get("tenantId", "") or ""
    if not user_id_str or not tenant_id_str:
        await websocket.send_json({"type": "error", "message": "Missing userId or tenantId in auth"})
        await websocket.close(code=4001)
        return None
    # Validate UUID format before touching the database.
    try:
        uuid.UUID(user_id_str)
        tenant_uuid = uuid.UUID(tenant_id_str)
    except (ValueError, AttributeError):
        await websocket.send_json({"type": "error", "message": "Invalid UUID format in auth"})
        await websocket.close(code=4001)
        return None
    return user_id_str, user_role, tenant_id_str, tenant_uuid


async def _save_user_message(
    tenant_uuid: uuid.UUID,
    conversation_id: str,
    text_content: str,
) -> None:
    """
    Persist the user's message to web_conversation_messages (best-effort).

    Skips the insert if the conversation row does not exist. Failures are
    logged and swallowed so a DB hiccup never kills the chat session.
    """
    configure_rls_hook(engine)
    rls_token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            # Look up the conversation to confirm it exists for this tenant.
            conv_stmt = select(WebConversation).where(
                WebConversation.id == uuid.UUID(conversation_id)
            )
            conv_result = await session.execute(conv_stmt)
            if conv_result.scalar_one_or_none() is not None:
                session.add(
                    WebConversationMessage(
                        conversation_id=uuid.UUID(conversation_id),
                        tenant_id=tenant_uuid,
                        role="user",
                        content=text_content,
                    )
                )
                # Touch the conversation so it sorts to the top of lists.
                await session.execute(
                    text(
                        "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
                    ),
                    {"conv_id": str(conversation_id)},
                )
                await session.commit()
    except Exception:
        logger.exception(
            "Failed to save user message for conversation=%s", conversation_id
        )
    finally:
        current_tenant_id.reset(rls_token)


async def _load_active_agent(tenant_uuid: uuid.UUID) -> Agent | None:
    """Fetch the first active agent for this tenant (RLS-scoped)."""
    rls_token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            agent_stmt = (
                select(Agent)
                .where(
                    Agent.tenant_id == tenant_uuid,
                    Agent.is_active.is_(True),
                )
                .limit(1)
            )
            agent_result = await session.execute(agent_stmt)
            return agent_result.scalar_one_or_none()
    finally:
        current_tenant_id.reset(rls_token)


async def _fetch_recent_messages(
    tenant_id: str, agent_id: str, user_id: str
) -> Any:
    """Read the Redis sliding-window history; connection always closed."""
    redis_mem = aioredis.from_url(settings.redis_url)
    try:
        return await get_recent_messages(redis_mem, tenant_id, agent_id, user_id)
    finally:
        await redis_mem.aclose()


async def _stream_llm_to_client(
    websocket: WebSocket,
    normalized_msg: KonstructMessage,
    agent: Agent,
    enriched_messages: Any,
    conversation_id: str,
) -> tuple[str, bool]:
    """
    Stream LLM tokens directly to the WebSocket — no Celery, no pub-sub.

    Returns:
        (response_text, client_connected). client_connected is False when a
        send failed mid-stream (browser went away); streaming stops there.
        On a streaming error with no partial text, response_text is a
        user-facing fallback message.
    """
    response_text = ""
    client_connected = True
    try:
        async for token in run_agent_streaming(
            msg=normalized_msg,
            agent=agent,
            messages=enriched_messages,
        ):
            response_text += token
            try:
                await websocket.send_json({"type": "chunk", "text": token})
            except Exception:
                client_connected = False
                break  # Client disconnected — stop streaming
    except Exception:
        logger.exception("Direct streaming failed for conversation=%s", conversation_id)
        if not response_text:
            response_text = "I encountered an error processing your message. Please try again."
    return response_text, client_connected


async def _persist_memory(
    tenant_id: str,
    agent_id: str,
    user_id: str,
    user_text: str,
    response_text: str,
) -> None:
    """
    Record the turn in short- and long-term memory.

    Appends both sides of the exchange to the Redis sliding window, then
    fire-and-forgets embedding tasks. Embedding dispatch failures are
    non-fatal — long-term memory rebuilds over time.
    """
    redis_mem = aioredis.from_url(settings.redis_url)
    try:
        await append_message(redis_mem, tenant_id, agent_id, user_id, "user", user_text)
        if response_text:
            await append_message(redis_mem, tenant_id, agent_id, user_id, "assistant", response_text)
    finally:
        await redis_mem.aclose()
    try:
        embed_and_store.delay({
            "tenant_id": tenant_id,
            "agent_id": agent_id,
            "user_id": user_id,
            "role": "user",
            "content": user_text,
        })
        if response_text:
            embed_and_store.delay({
                "tenant_id": tenant_id,
                "agent_id": agent_id,
                "user_id": user_id,
                "role": "assistant",
                "content": response_text,
            })
    except Exception:
        pass  # Non-fatal — memory will rebuild over time


async def _save_assistant_message(
    tenant_uuid: uuid.UUID,
    conversation_id: str,
    response_text: str,
) -> None:
    """Persist the assistant reply (best-effort; failures are logged)."""
    rls_token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            session.add(
                WebConversationMessage(
                    conversation_id=uuid.UUID(conversation_id),
                    tenant_id=tenant_uuid,
                    role="assistant",
                    content=response_text,
                )
            )
            await session.execute(
                text(
                    "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
                ),
                {"conv_id": str(conversation_id)},
            )
            await session.commit()
    except Exception:
        logger.exception(
            "Failed to save assistant message for conversation=%s", conversation_id
        )
    finally:
        current_tenant_id.reset(rls_token)


async def _handle_websocket_connection(
    websocket: WebSocket,
    conversation_id: str,
) -> None:
    """
    Core WebSocket connection handler — separated for testability.

    Lifecycle:
        1. Accept connection
        2. Auth handshake (browser cannot send custom headers)
        3. Loop: receive message → typing indicator → save user message →
           load agent → stream LLM tokens directly → persist memory and
           assistant reply → send "done"

    Args:
        websocket: The FastAPI WebSocket connection.
        conversation_id: The conversation UUID from the URL path.
    """
    await websocket.accept()

    auth = await _ws_auth_handshake(websocket)
    if auth is None:
        return
    user_id_str, user_role, tenant_id_str, tenant_uuid = auth
    logger.info(
        "WebSocket auth: user=%s role=%s tenant=%s conversation=%s",
        user_id_str, user_role, tenant_id_str, conversation_id,
    )

    while True:
        try:
            msg_data = await websocket.receive_json()
        except Exception:  # WebSocketDisconnect or malformed frame — end session
            break
        if msg_data.get("type") != "message":
            continue

        text_content: str = msg_data.get("text", "") or ""
        agent_id_str: str = msg_data.get("agentId", "") or ""
        msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
        display_name: str = msg_data.get("displayName", "Portal User")

        # a. Typing indicator IMMEDIATELY — before any DB or LLM work (CHAT-05)
        await websocket.send_json({"type": "typing"})

        # b. Best-effort save of the user message.
        await _save_user_message(tenant_uuid, msg_conversation_id, text_content)

        # c. Normalize the event and resolve the agent for this tenant.
        normalized_msg = normalize_web_event({
            "text": text_content,
            "tenant_id": tenant_id_str,
            "agent_id": agent_id_str,
            "user_id": user_id_str,
            "display_name": display_name,
            "conversation_id": msg_conversation_id,
        })
        agent = await _load_active_agent(tenant_uuid)
        if agent is None:
            await websocket.send_json({
                "type": "done",
                "text": "No active AI employee is configured for this workspace.",
                "conversation_id": msg_conversation_id,
            })
            continue

        # Build memory-enriched messages (Redis sliding window only — fast).
        recent_messages = await _fetch_recent_messages(
            tenant_id_str, str(agent.id), user_id_str
        )
        enriched_messages = build_messages_with_memory(
            agent=agent,
            current_message=text_content,
            recent_messages=recent_messages,
            relevant_context=[],
            channel="web",
        )

        # d. Stream LLM tokens directly to the client.
        response_text, client_connected = await _stream_llm_to_client(
            websocket, normalized_msg, agent, enriched_messages, msg_conversation_id
        )

        # Short-term (Redis) + long-term (embedding) memory bookkeeping.
        await _persist_memory(
            tenant_id_str, str(agent.id), user_id_str, text_content, response_text
        )

        # e. Persist assistant message and signal completion — but only
        #    send if the client is still connected (a send to a closed
        #    socket would raise and tear down the whole handler).
        if response_text:
            await _save_assistant_message(tenant_uuid, msg_conversation_id, response_text)
            if client_connected:
                await websocket.send_json({
                    "type": "done",
                    "text": response_text,
                    "conversation_id": msg_conversation_id,
                })
        else:
            # Direct streaming yielded nothing — there is no pub-sub timeout
            # anymore, so log the actual condition.
            logger.warning(
                "Agent returned an empty response for conversation=%s",
                msg_conversation_id,
            )
            if client_connected:
                await websocket.send_json({
                    "type": "error",
                    "message": "Agent did not respond in time. Please try again.",
                })
@web_chat_router.websocket("/chat/ws/{conversation_id}")
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:
    """
    WebSocket endpoint for web chat.

    URL: /chat/ws/{conversation_id}

    Protocol:
        1. Connect
        2. Send: {"type": "auth", "userId": "...", "role": "...", "tenantId": "..."}
        3. Send: {"type": "message", "text": "...", "agentId": "...", "conversationId": "..."}
        4. Receive: {"type": "typing"}
        5. Receive: {"type": "chunk", "text": "..."} — one per streamed token
        6. Receive: {"type": "done", "text": "...", "conversation_id": "..."}
           (or {"type": "error", "message": "..."} on failure)

    Closes with code 4001 on auth failure. This is a top-level boundary:
    all exceptions are caught and logged so a failing connection never
    crashes the server task.
    """
    try:
        await _handle_websocket_connection(websocket, conversation_id)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for conversation=%s", conversation_id)
    except Exception:
        logger.exception("Unhandled error in WebSocket handler for conversation=%s", conversation_id)