perf: bypass Celery for web chat — stream LLM directly from WebSocket

Eliminates 5-10s of overhead by calling the LLM pool's streaming
endpoint directly from the WebSocket handler instead of going through
Celery queue → worker → asyncio.run() → Redis pub-sub → WebSocket.

New flow: WebSocket → agent lookup → memory → LLM stream → WebSocket
Old flow: WebSocket → Celery → worker → DB → memory → LLM → Redis → WebSocket

Memory still saved (Redis sliding window + fire-and-forget embedding).
Slack/WhatsApp still use Celery (async webhook pattern).
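
In sketch form, the new hot path is just receive frame → stream tokens →
send frames. Below is a self-contained toy of the pattern (hypothetical
route and stub generator: stream_tokens stands in for run_agent_streaming,
and the real handler also loads the agent, builds memory-enriched
messages, and persists the transcript, as the diff shows):

    import asyncio
    from typing import AsyncIterator

    from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    app = FastAPI()

    async def stream_tokens(prompt: str) -> AsyncIterator[str]:
        # Stand-in for the LLM pool's streaming endpoint.
        for word in f"Echo: {prompt}".split():
            await asyncio.sleep(0.05)  # simulate per-token latency
            yield word + " "

    @app.websocket("/ws/chat")  # hypothetical route for illustration
    async def chat(websocket: WebSocket) -> None:
        await websocket.accept()
        try:
            while True:
                frame = await websocket.receive_json()
                full = ""
                # Forward each token the moment the model produces it:
                # no Celery queue, no worker, no Redis pub-sub hop.
                async for token in stream_tokens(frame["text"]):
                    full += token
                    await websocket.send_json({"type": "chunk", "text": token})
                await websocket.send_json({"type": "done", "text": full})
        except WebSocketDisconnect:
            pass  # client disconnected mid-stream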

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
commit dd80e2b822
parent 2116059157
Date:   2026-03-25 18:32:16 -06:00


@@ -42,11 +42,15 @@ import redis.asyncio as aioredis
 from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 from sqlalchemy import select, text
-from orchestrator.tasks import handle_message
+from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt
+from orchestrator.agents.runner import run_agent_streaming
+from orchestrator.memory.short_term import get_recent_messages, append_message
+from orchestrator.tasks import handle_message, embed_and_store
 from shared.config import settings
 from shared.db import async_session_factory, engine
 from shared.models.chat import WebConversation, WebConversationMessage
 from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
+from shared.models.tenant import Agent
 from shared.redis_keys import webchat_response_key
 from shared.rls import configure_rls_hook, current_tenant_id
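
For orientation: the call site in the last hunk consumes run_agent_streaming
as an async generator of text tokens. A hypothetical stub matching the
signature implied there (not the actual orchestrator/agents/runner.py
implementation):

    from typing import Any, AsyncIterator

    async def run_agent_streaming(
        msg: Any,               # normalized KonstructMessage
        agent: Any,             # the tenant's active Agent row
        messages: list[dict],   # memory-enriched chat messages
    ) -> AsyncIterator[str]:
        # Yields response text token by token as the LLM pool streams it.
        yield "..."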
@@ -225,7 +229,12 @@ async def _handle_websocket_connection(
             current_tenant_id.reset(rls_token)
         # -------------------------------------------------------------------
-        # c. Normalize and dispatch to Celery pipeline
+        # c. Build KonstructMessage and stream LLM response DIRECTLY
+        #
+        # Bypasses Celery entirely for web chat — calls the LLM pool's
+        # streaming endpoint from the WebSocket handler. This eliminates
+        # ~5-10s of Celery queue + Redis pub-sub round-trip overhead.
+        # Slack/WhatsApp still use Celery (async webhook pattern).
         # -------------------------------------------------------------------
         event = {
             "text": text_content,
@@ -237,67 +246,92 @@ async def _handle_websocket_connection(
         }
         normalized_msg = normalize_web_event(event)
-        extras = {
-            "conversation_id": saved_conversation_id,
-            "portal_user_id": user_id_str,
-        }
-        task_payload = normalized_msg.model_dump(mode="json") | extras
-        handle_message.delay(task_payload)
-
-        # -------------------------------------------------------------------
-        # d. Subscribe to Redis pub-sub and forward streaming chunks to client
-        #
-        # The orchestrator publishes two message types to the response channel:
-        #   {"type": "chunk", "text": "<token>"} — zero or more times (streaming)
-        #   {"type": "done", "text": "<full>", "conversation_id": "..."} — final marker
-        #
-        # We forward each "chunk" immediately to the browser so text appears
-        # word-by-word. On "done" we save the full response to the DB.
-        # -------------------------------------------------------------------
-        response_channel = webchat_response_key(tenant_id_str, saved_conversation_id)
-        subscribe_redis = aioredis.from_url(settings.redis_url)
-        response_text: str = ""
-        try:
-            pubsub = subscribe_redis.pubsub()
-            await pubsub.subscribe(response_channel)
-            deadline = asyncio.get_event_loop().time() + _RESPONSE_TIMEOUT_SECONDS
-            while asyncio.get_event_loop().time() < deadline:
-                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
-                if message and message.get("type") == "message":
-                    try:
-                        payload = json.loads(message["data"])
-                    except (json.JSONDecodeError, KeyError):
-                        await asyncio.sleep(0.01)
-                        continue
-                    msg_type = payload.get("type")
-                    if msg_type == "chunk":
-                        # Forward token immediately — do not break the loop
-                        token = payload.get("text", "")
-                        if token:
-                            try:
-                                await websocket.send_json({
-                                    "type": "chunk",
-                                    "text": token,
-                                })
-                            except Exception:
-                                # Client disconnected mid-stream — exit cleanly
-                                break
-                    elif msg_type == "done":
-                        # Final marker — full text for DB persistence
-                        response_text = payload.get("text", "")
-                        break
-                else:
-                    await asyncio.sleep(0.05)
-            await pubsub.unsubscribe(response_channel)
-        finally:
-            await subscribe_redis.aclose()
+        # Load agent for this tenant
+        agent: Agent | None = None
+        rls_token3 = current_tenant_id.set(tenant_uuid)
+        try:
+            async with async_session_factory() as session:
+                from sqlalchemy import select as sa_select
+                agent_stmt = sa_select(Agent).where(
+                    Agent.tenant_id == tenant_uuid,
+                    Agent.is_active == True,
+                ).limit(1)
+                agent_result = await session.execute(agent_stmt)
+                agent = agent_result.scalar_one_or_none()
+        finally:
+            current_tenant_id.reset(rls_token3)
+
+        if agent is None:
+            await websocket.send_json({
+                "type": "done",
+                "text": "No active AI employee is configured for this workspace.",
+                "conversation_id": saved_conversation_id,
+            })
+            continue
+
+        # Build memory-enriched messages (Redis sliding window only — fast)
+        redis_mem = aioredis.from_url(settings.redis_url)
+        try:
+            recent_messages = await get_recent_messages(
+                redis_mem, tenant_id_str, str(agent.id), user_id_str
+            )
+        finally:
+            await redis_mem.aclose()
+
+        enriched_messages = build_messages_with_memory(
+            agent=agent,
+            current_message=text_content,
+            recent_messages=recent_messages,
+            relevant_context=[],
+            channel="web",
+        )
+
+        # Stream LLM response directly to WebSocket — no Celery, no pub-sub
+        response_text = ""
+        try:
+            async for token in run_agent_streaming(
+                msg=normalized_msg,
+                agent=agent,
+                messages=enriched_messages,
+            ):
+                response_text += token
+                try:
+                    await websocket.send_json({"type": "chunk", "text": token})
+                except Exception:
+                    break  # Client disconnected
+        except Exception:
+            logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
+            if not response_text:
+                response_text = "I encountered an error processing your message. Please try again."
+
+        # Save to Redis sliding window (fire-and-forget, non-blocking)
+        redis_mem2 = aioredis.from_url(settings.redis_url)
+        try:
+            await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content)
+            if response_text:
+                await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text)
+        finally:
+            await redis_mem2.aclose()
+
+        # Fire-and-forget embedding for long-term memory
+        try:
+            embed_and_store.delay({
+                "tenant_id": tenant_id_str,
+                "agent_id": str(agent.id),
+                "user_id": user_id_str,
+                "role": "user",
+                "content": text_content,
+            })
+            if response_text:
+                embed_and_store.delay({
+                    "tenant_id": tenant_id_str,
+                    "agent_id": str(agent.id),
+                    "user_id": user_id_str,
+                    "role": "assistant",
+                    "content": response_text,
+                })
+        except Exception:
+            pass  # Non-fatal — memory will rebuild over time
 
         # -------------------------------------------------------------------
         # e. Save assistant message and send final "done" to client
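
On the wire, clients still receive the same frames as before: zero or more
{"type": "chunk", "text": "<token>"} followed by a final {"type": "done", ...},
so no browser-side changes are needed. A minimal smoke-test client (the URL
and the inbound frame shape are assumptions for illustration):

    import asyncio
    import json

    import websockets  # pip install websockets

    async def main() -> None:
        async with websockets.connect("ws://localhost:8000/ws/chat") as ws:
            await ws.send(json.dumps({"text": "Hello!"}))
            while True:
                frame = json.loads(await ws.recv())
                if frame["type"] == "chunk":
                    print(frame["text"], end="", flush=True)  # word-by-word
                elif frame["type"] == "done":
                    print()  # also carries the full text and conversation_id
                    break

    asyncio.run(main())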