feat(08-03): push notification backend — DB model, migration, API router, VAPID setup

- Add PushSubscription ORM model with unique(user_id, endpoint) constraint
- Add Alembic migration 012 for push_subscriptions table
- Add push router (subscribe, unsubscribe, send) in shared/api/push.py
- Mount push router in gateway/main.py
- Add pywebpush to gateway dependencies for server-side VAPID delivery
- Wire push trigger into WebSocket handler (fires when client disconnects mid-stream)
- Add VAPID keys to .env / .env.example
- Add push/install i18n keys in en/es/pt message files
This commit is contained in:
2026-03-25 21:26:51 -06:00
parent 5c30651754
commit 7d3a393758
9 changed files with 774 additions and 199 deletions

View File

@@ -27,6 +27,11 @@ Design notes:
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
Push notifications:
- Connected users are tracked in _connected_users (in-memory dict)
- When the WebSocket send for "done" raises (client disconnected mid-stream),
a push notification is fired so the user sees the response on their device.
"""
from __future__ import annotations
@@ -40,7 +45,7 @@ from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from sqlalchemy import delete, select, text
from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt
from orchestrator.agents.runner import run_agent_streaming
@@ -64,6 +69,89 @@ web_chat_router = APIRouter(tags=["web-chat"])
# Timeout for waiting for an agent response via Redis pub-sub (seconds)
_RESPONSE_TIMEOUT_SECONDS = 180
# ---------------------------------------------------------------------------
# Connected user tracking — used to decide whether to send push notifications
# ---------------------------------------------------------------------------
# Maps user_id -> set of conversation_ids with active WebSocket connections.
# When a user disconnects, their entry is removed. If the agent response
# finishes after disconnect, a push notification is sent.
_connected_users: dict[str, set[str]] = {}
def _mark_connected(user_id: str, conversation_id: str) -> None:
"""Record that user_id has an active WebSocket for conversation_id."""
if user_id not in _connected_users:
_connected_users[user_id] = set()
_connected_users[user_id].add(conversation_id)
def _mark_disconnected(user_id: str, conversation_id: str) -> None:
"""Remove the active WebSocket record for user_id + conversation_id."""
if user_id in _connected_users:
_connected_users[user_id].discard(conversation_id)
if not _connected_users[user_id]:
del _connected_users[user_id]
def is_user_connected(user_id: str) -> bool:
"""Return True if the user has any active WebSocket connection."""
return user_id in _connected_users and bool(_connected_users[user_id])
async def _send_push_notification(
    user_id: str,
    title: str,
    body: str,
    conversation_id: str | None = None,
) -> None:
    """
    Best-effort push notification delivery for a user.

    Loads every push subscription registered for *user_id* and delivers
    the payload to each via the shared _send_push helper. Subscriptions
    whose delivery reports failure are collected and deleted in a single
    batch afterwards. All errors are logged and swallowed — push must
    never break the caller.
    """
    # NOTE(review): imported locally rather than at module top — presumably
    # to avoid an import cycle with shared.api.push; confirm.
    from shared.models.push import PushSubscription
    from shared.api.push import _send_push

    try:
        user_uuid = uuid.UUID(user_id)
        payload = {
            "title": title,
            "body": body,
            "data": {"conversationId": conversation_id},
        }
        async with async_session_factory() as session:
            rows = await session.execute(
                select(PushSubscription).where(PushSubscription.user_id == user_uuid)
            )
            subscriptions = rows.scalars().all()
            if not subscriptions:
                # Nothing registered for this user — nothing to do.
                return
            stale: list[str] = []
            for subscription in subscriptions:
                try:
                    delivered = await _send_push(subscription, payload)
                except Exception as exc:
                    # Per-subscription failures are logged but do not stop
                    # delivery to the user's remaining subscriptions.
                    logger.warning("Push delivery failed for user=%s: %s", user_id, exc)
                else:
                    if not delivered:
                        # _send_push returned falsy — treat as a dead endpoint.
                        stale.append(subscription.endpoint)
            if stale:
                await session.execute(
                    delete(PushSubscription).where(
                        PushSubscription.user_id == user_uuid,
                        PushSubscription.endpoint.in_(stale),
                    )
                )
                await session.commit()
    except Exception as exc:
        logger.warning("Push notification send error for user=%s: %s", user_id, exc)
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
"""
@@ -164,223 +252,248 @@ async def _handle_websocket_connection(
user_id_str, user_role, tenant_id_str, conversation_id,
)
# Track this user as connected (for push notification gating)
_mark_connected(user_id_str, conversation_id)
# -------------------------------------------------------------------------
# Step 2: Message loop
# -------------------------------------------------------------------------
while True:
try:
msg_data = await websocket.receive_json()
except (WebSocketDisconnect, Exception):
break
try:
while True:
try:
msg_data = await websocket.receive_json()
except (WebSocketDisconnect, Exception):
break
if msg_data.get("type") != "message":
continue
if msg_data.get("type") != "message":
continue
text_content: str = msg_data.get("text", "") or ""
agent_id_str: str = msg_data.get("agentId", "") or ""
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
display_name: str = msg_data.get("displayName", "Portal User")
text_content: str = msg_data.get("text", "") or ""
agent_id_str: str = msg_data.get("agentId", "") or ""
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
display_name: str = msg_data.get("displayName", "Portal User")
# -------------------------------------------------------------------
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
# -------------------------------------------------------------------
await websocket.send_json({"type": "typing"})
# -------------------------------------------------------------------
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
# -------------------------------------------------------------------
await websocket.send_json({"type": "typing"})
# -------------------------------------------------------------------
# b. Save user message to web_conversation_messages
# -------------------------------------------------------------------
configure_rls_hook(engine)
rls_token = current_tenant_id.set(tenant_uuid)
saved_conversation_id = msg_conversation_id
# -------------------------------------------------------------------
# b. Save user message to web_conversation_messages
# -------------------------------------------------------------------
configure_rls_hook(engine)
rls_token = current_tenant_id.set(tenant_uuid)
saved_conversation_id = msg_conversation_id
try:
async with async_session_factory() as session:
# Look up the conversation to get tenant-scoped context
conv_stmt = select(WebConversation).where(
WebConversation.id == uuid.UUID(msg_conversation_id)
try:
async with async_session_factory() as session:
# Look up the conversation to get tenant-scoped context
conv_stmt = select(WebConversation).where(
WebConversation.id == uuid.UUID(msg_conversation_id)
)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
if conversation is not None:
# Save user message
user_msg = WebConversationMessage(
conversation_id=uuid.UUID(msg_conversation_id),
tenant_id=tenant_uuid,
role="user",
content=text_content,
)
session.add(user_msg)
# Update conversation timestamp
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(msg_conversation_id)},
)
await session.commit()
saved_conversation_id = msg_conversation_id
except Exception:
logger.exception(
"Failed to save user message for conversation=%s", msg_conversation_id
)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token)
if conversation is not None:
# Save user message
user_msg = WebConversationMessage(
conversation_id=uuid.UUID(msg_conversation_id),
tenant_id=tenant_uuid,
role="user",
content=text_content,
)
session.add(user_msg)
# Update conversation timestamp
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(msg_conversation_id)},
)
await session.commit()
saved_conversation_id = msg_conversation_id
except Exception:
logger.exception(
"Failed to save user message for conversation=%s", msg_conversation_id
)
finally:
current_tenant_id.reset(rls_token)
# -------------------------------------------------------------------
# c. Build KonstructMessage and stream LLM response DIRECTLY
#
# Bypasses Celery entirely for web chat — calls the LLM pool's
# streaming endpoint from the WebSocket handler. This eliminates
# ~5-10s of Celery queue + Redis pub-sub round-trip overhead.
# Slack/WhatsApp still use Celery (async webhook pattern).
# -------------------------------------------------------------------
event = {
"text": text_content,
"tenant_id": tenant_id_str,
"agent_id": agent_id_str,
"user_id": user_id_str,
"display_name": display_name,
"conversation_id": saved_conversation_id,
}
normalized_msg = normalize_web_event(event)
# Load agent for this tenant
agent: Agent | None = None
rls_token3 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
from sqlalchemy import select as sa_select
agent_stmt = sa_select(Agent).where(
Agent.tenant_id == tenant_uuid,
Agent.is_active == True,
).limit(1)
agent_result = await session.execute(agent_stmt)
agent = agent_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token3)
if agent is None:
await websocket.send_json({
"type": "done",
"text": "No active AI employee is configured for this workspace.",
"conversation_id": saved_conversation_id,
})
continue
# Build memory-enriched messages (Redis sliding window only — fast)
redis_mem = aioredis.from_url(settings.redis_url)
try:
recent_messages = await get_recent_messages(
redis_mem, tenant_id_str, str(agent.id), user_id_str
)
finally:
await redis_mem.aclose()
enriched_messages = build_messages_with_memory(
agent=agent,
current_message=text_content,
recent_messages=recent_messages,
relevant_context=[],
channel="web",
)
# Stream LLM response directly to WebSocket — no Celery, no pub-sub
response_text = ""
try:
async for token in run_agent_streaming(
msg=normalized_msg,
agent=agent,
messages=enriched_messages,
):
response_text += token
try:
await websocket.send_json({"type": "chunk", "text": token})
except Exception:
break # Client disconnected
except Exception:
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
if not response_text:
response_text = "I encountered an error processing your message. Please try again."
# Save to Redis sliding window (fire-and-forget, non-blocking)
redis_mem2 = aioredis.from_url(settings.redis_url)
try:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content)
if response_text:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text)
finally:
await redis_mem2.aclose()
# Fire-and-forget embedding for long-term memory
try:
embed_and_store.delay({
# -------------------------------------------------------------------
# c. Build KonstructMessage and stream LLM response DIRECTLY
#
# Bypasses Celery entirely for web chat — calls the LLM pool's
# streaming endpoint from the WebSocket handler. This eliminates
# ~5-10s of Celery queue + Redis pub-sub round-trip overhead.
# Slack/WhatsApp still use Celery (async webhook pattern).
# -------------------------------------------------------------------
event = {
"text": text_content,
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"agent_id": agent_id_str,
"user_id": user_id_str,
"role": "user",
"content": text_content,
})
if response_text:
"display_name": display_name,
"conversation_id": saved_conversation_id,
}
normalized_msg = normalize_web_event(event)
# Load agent for this tenant
agent: Agent | None = None
rls_token3 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
from sqlalchemy import select as sa_select
agent_stmt = sa_select(Agent).where(
Agent.tenant_id == tenant_uuid,
Agent.is_active == True,
).limit(1)
agent_result = await session.execute(agent_stmt)
agent = agent_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token3)
if agent is None:
await websocket.send_json({
"type": "done",
"text": "No active AI employee is configured for this workspace.",
"conversation_id": saved_conversation_id,
})
continue
# Build memory-enriched messages (Redis sliding window only — fast)
redis_mem = aioredis.from_url(settings.redis_url)
try:
recent_messages = await get_recent_messages(
redis_mem, tenant_id_str, str(agent.id), user_id_str
)
finally:
await redis_mem.aclose()
enriched_messages = build_messages_with_memory(
agent=agent,
current_message=text_content,
recent_messages=recent_messages,
relevant_context=[],
channel="web",
)
# Stream LLM response directly to WebSocket — no Celery, no pub-sub
response_text = ""
ws_disconnected_during_stream = False
try:
async for token in run_agent_streaming(
msg=normalized_msg,
agent=agent,
messages=enriched_messages,
):
response_text += token
try:
await websocket.send_json({"type": "chunk", "text": token})
except Exception:
ws_disconnected_during_stream = True
break # Client disconnected
except Exception:
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
if not response_text:
response_text = "I encountered an error processing your message. Please try again."
# Save to Redis sliding window (fire-and-forget, non-blocking)
redis_mem2 = aioredis.from_url(settings.redis_url)
try:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content)
if response_text:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text)
finally:
await redis_mem2.aclose()
# Fire-and-forget embedding for long-term memory
try:
embed_and_store.delay({
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"user_id": user_id_str,
"role": "assistant",
"content": response_text,
"role": "user",
"content": text_content,
})
except Exception:
pass # Non-fatal — memory will rebuild over time
# -------------------------------------------------------------------
# e. Save assistant message and send final "done" to client
# -------------------------------------------------------------------
if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
assistant_msg = WebConversationMessage(
conversation_id=uuid.UUID(saved_conversation_id),
tenant_id=tenant_uuid,
role="assistant",
content=response_text,
)
session.add(assistant_msg)
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(saved_conversation_id)},
)
await session.commit()
if response_text:
embed_and_store.delay({
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"user_id": user_id_str,
"role": "assistant",
"content": response_text,
})
except Exception:
logger.exception(
"Failed to save assistant message for conversation=%s", saved_conversation_id
pass # Non-fatal — memory will rebuild over time
# -------------------------------------------------------------------
# e. Save assistant message and send final "done" to client
# -------------------------------------------------------------------
if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
assistant_msg = WebConversationMessage(
conversation_id=uuid.UUID(saved_conversation_id),
tenant_id=tenant_uuid,
role="assistant",
content=response_text,
)
session.add(assistant_msg)
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(saved_conversation_id)},
)
await session.commit()
except Exception:
logger.exception(
"Failed to save assistant message for conversation=%s", saved_conversation_id
)
finally:
current_tenant_id.reset(rls_token2)
# If user disconnected during streaming, send push notification
if ws_disconnected_during_stream or not is_user_connected(user_id_str):
agent_name = agent.name if hasattr(agent, "name") and agent.name else "Your AI Employee"
preview = response_text[:100] + ("..." if len(response_text) > 100 else "")
asyncio.create_task(
_send_push_notification(
user_id=user_id_str,
title=f"{agent_name} replied",
body=preview,
conversation_id=saved_conversation_id,
)
)
if ws_disconnected_during_stream:
break # Stop the message loop — WS is gone
# Signal stream completion to the client
try:
await websocket.send_json({
"type": "done",
"text": response_text,
"conversation_id": saved_conversation_id,
})
except Exception:
pass # Client already disconnected
else:
logger.warning(
"No response received for conversation=%s", saved_conversation_id,
)
finally:
current_tenant_id.reset(rls_token2)
try:
await websocket.send_json({
"type": "error",
"message": "I'm having trouble responding right now. Please try again.",
})
except Exception:
pass # Client already disconnected
# Signal stream completion to the client
try:
await websocket.send_json({
"type": "done",
"text": response_text,
"conversation_id": saved_conversation_id,
})
except Exception:
pass # Client already disconnected
else:
logger.warning(
"No response received for conversation=%s", saved_conversation_id,
)
try:
await websocket.send_json({
"type": "error",
"message": "I'm having trouble responding right now. Please try again.",
})
except Exception:
pass # Client already disconnected
finally:
# Always untrack this user when connection ends
_mark_disconnected(user_id_str, conversation_id)
@web_chat_router.websocket("/chat/ws/{conversation_id}")

View File

@@ -48,6 +48,7 @@ from shared.api import (
invitations_router,
llm_keys_router,
portal_router,
push_router,
templates_router,
usage_router,
webhook_router,
@@ -158,6 +159,11 @@ app.include_router(templates_router)
app.include_router(chat_router) # REST: /api/portal/chat/*
app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id}
# ---------------------------------------------------------------------------
# Phase 8 Push Notification router
# ---------------------------------------------------------------------------
app.include_router(push_router) # Push subscribe/unsubscribe/send
# ---------------------------------------------------------------------------
# Routes

View File

@@ -18,6 +18,7 @@ dependencies = [
"httpx>=0.28.0",
"redis>=5.0.0",
"boto3>=1.35.0",
"pywebpush>=2.0.0",
]
[tool.uv.sources]