feat(streaming): update WebSocket handler to forward streaming chunks to browser

- Pub-sub loop now handles 'chunk' and 'done' message types (not just 'response')
- 'chunk' messages are forwarded immediately via websocket.send_json
- 'done' message breaks the loop and triggers DB persistence of full response
- Sends final 'done' JSON to browser to signal stream completion
- Legacy 'response' type no longer emitted from orchestrator (now unified as 'done')
This commit is contained in:
2026-03-25 17:57:08 -06:00
parent 5fb79beb76
commit 61b8762bac

View File

@@ -245,15 +245,22 @@ async def _handle_websocket_connection(
handle_message.delay(task_payload) handle_message.delay(task_payload)
# ------------------------------------------------------------------- # -------------------------------------------------------------------
# d. Subscribe to Redis pub-sub and wait for agent response # d. Subscribe to Redis pub-sub and forward streaming chunks to client
#
# The orchestrator publishes two message types to the response channel:
# {"type": "chunk", "text": "<token>"} — zero or more times (streaming)
# {"type": "done", "text": "<full>", "conversation_id": "..."} — final marker
#
# We forward each "chunk" immediately to the browser so text appears
# word-by-word. On "done" we save the full response to the DB.
# ------------------------------------------------------------------- # -------------------------------------------------------------------
response_channel = webchat_response_key(tenant_id_str, saved_conversation_id) response_channel = webchat_response_key(tenant_id_str, saved_conversation_id)
subscribe_redis = aioredis.from_url(settings.redis_url) subscribe_redis = aioredis.from_url(settings.redis_url)
response_text: str = ""
try: try:
pubsub = subscribe_redis.pubsub() pubsub = subscribe_redis.pubsub()
await pubsub.subscribe(response_channel) await pubsub.subscribe(response_channel)
response_text: str = ""
deadline = asyncio.get_event_loop().time() + _RESPONSE_TIMEOUT_SECONDS deadline = asyncio.get_event_loop().time() + _RESPONSE_TIMEOUT_SECONDS
while asyncio.get_event_loop().time() < deadline: while asyncio.get_event_loop().time() < deadline:
@@ -261,10 +268,31 @@ async def _handle_websocket_connection(
if message and message.get("type") == "message": if message and message.get("type") == "message":
try: try:
payload = json.loads(message["data"]) payload = json.loads(message["data"])
response_text = payload.get("text", "")
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass await asyncio.sleep(0.01)
continue
msg_type = payload.get("type")
if msg_type == "chunk":
# Forward token immediately — do not break the loop
token = payload.get("text", "")
if token:
try:
await websocket.send_json({
"type": "chunk",
"text": token,
})
except Exception:
# Client disconnected mid-stream — exit cleanly
break break
elif msg_type == "done":
# Final marker — full text for DB persistence
response_text = payload.get("text", "")
break
else:
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
await pubsub.unsubscribe(response_channel) await pubsub.unsubscribe(response_channel)
@@ -272,7 +300,7 @@ async def _handle_websocket_connection(
await subscribe_redis.aclose() await subscribe_redis.aclose()
# ------------------------------------------------------------------- # -------------------------------------------------------------------
# e. Save assistant message and send response to client # e. Save assistant message and send final "done" to client
# ------------------------------------------------------------------- # -------------------------------------------------------------------
if response_text: if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid) rls_token2 = current_tenant_id.set(tenant_uuid)
@@ -299,8 +327,9 @@ async def _handle_websocket_connection(
finally: finally:
current_tenant_id.reset(rls_token2) current_tenant_id.reset(rls_token2)
# Signal stream completion to the client
await websocket.send_json({ await websocket.send_json({
"type": "response", "type": "done",
"text": response_text, "text": response_text,
"conversation_id": saved_conversation_id, "conversation_id": saved_conversation_id,
}) })