feat(08-03): push notification backend — DB model, migration, API router, VAPID setup

- Add PushSubscription ORM model with unique(user_id, endpoint) constraint
- Add Alembic migration 012 for push_subscriptions table
- Add push router (subscribe, unsubscribe, send) in shared/api/push.py
- Mount push router in gateway/main.py
- Add pywebpush to gateway dependencies for server-side VAPID delivery
- Wire push trigger into WebSocket handler (fires when client disconnects mid-stream)
- Add VAPID keys to .env / .env.example
- Add push/install i18n keys in en/es/pt message files
This commit is contained in:
2026-03-25 21:26:51 -06:00
parent 5c30651754
commit 7d3a393758
9 changed files with 774 additions and 199 deletions

View File

@@ -61,3 +61,11 @@ DEBUG=false
# Tenant rate limits (requests per minute defaults)
DEFAULT_RATE_LIMIT_RPM=60
# -----------------------------------------------------------------------------
# Web Push Notifications (VAPID keys)
# Generate with: cd packages/portal && npx web-push generate-vapid-keys
# -----------------------------------------------------------------------------
NEXT_PUBLIC_VAPID_PUBLIC_KEY=your-vapid-public-key
VAPID_PRIVATE_KEY=your-vapid-private-key
VAPID_CLAIMS_EMAIL=admin@yourdomain.com

View File

@@ -0,0 +1,91 @@
"""Push subscriptions table for Web Push notifications
Revision ID: 012
Revises: 011
Create Date: 2026-03-26
Creates the push_subscriptions table so the gateway can store browser
push subscriptions and deliver Web Push notifications when an AI employee
responds and the user's WebSocket is not connected.
No RLS policy is applied — the API filters by user_id at the application
layer (push subscriptions are portal-user-scoped, not tenant-scoped).
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "012"
down_revision: Union[str, None] = "011"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the push_subscriptions table and its lookup indexes."""
    op.create_table(
        "push_subscriptions",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            # gen_random_uuid() is built in on PostgreSQL 13+; older versions need pgcrypto.
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
        ),
        sa.Column(
            "user_id",
            postgresql.UUID(as_uuid=True),
            # Subscriptions are deleted together with the portal user.
            sa.ForeignKey("portal_users.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "tenant_id",
            postgresql.UUID(as_uuid=True),
            # SET NULL: the subscription is user-scoped and survives tenant removal.
            sa.ForeignKey("tenants.id", ondelete="SET NULL"),
            nullable=True,
            comment="Optional tenant scope for notification routing",
        ),
        sa.Column(
            "endpoint",
            sa.Text,
            nullable=False,
            comment="Push service URL (browser-provided)",
        ),
        sa.Column(
            "p256dh",
            sa.Text,
            nullable=False,
            comment="ECDH public key for payload encryption",
        ),
        sa.Column(
            "auth",
            sa.Text,
            nullable=False,
            comment="Auth secret for payload encryption",
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        # One row per browser endpoint per user; the subscribe API upserts on this constraint.
        sa.UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"),
    )
    # Delivery queries filter by user_id; tenant-scoped routing filters by tenant_id.
    op.create_index("ix_push_subscriptions_user_id", "push_subscriptions", ["user_id"])
    op.create_index("ix_push_subscriptions_tenant_id", "push_subscriptions", ["tenant_id"])
def downgrade() -> None:
    """Drop the push_subscriptions indexes, then the table itself."""
    for index_name in (
        "ix_push_subscriptions_tenant_id",
        "ix_push_subscriptions_user_id",
    ):
        op.drop_index(index_name, table_name="push_subscriptions")
    op.drop_table("push_subscriptions")

View File

@@ -27,6 +27,11 @@ Design notes:
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
Push notifications:
- Connected users are tracked in _connected_users (in-memory dict)
- When the WebSocket send for "done" raises (client disconnected mid-stream),
a push notification is fired so the user sees the response on their device.
"""
from __future__ import annotations
@@ -40,7 +45,7 @@ from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from sqlalchemy import delete, select, text
from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt
from orchestrator.agents.runner import run_agent_streaming
@@ -64,6 +69,89 @@ web_chat_router = APIRouter(tags=["web-chat"])
# Timeout for waiting for an agent response via Redis pub-sub (seconds)
_RESPONSE_TIMEOUT_SECONDS = 180
# ---------------------------------------------------------------------------
# Connected user tracking — used to decide whether to send push notifications
# ---------------------------------------------------------------------------
# Maps user_id -> set of conversation_ids with active WebSocket connections.
# When a user disconnects, their entry is removed. If the agent response
# finishes after disconnect, a push notification is sent.
_connected_users: dict[str, set[str]] = {}
def _mark_connected(user_id: str, conversation_id: str) -> None:
"""Record that user_id has an active WebSocket for conversation_id."""
if user_id not in _connected_users:
_connected_users[user_id] = set()
_connected_users[user_id].add(conversation_id)
def _mark_disconnected(user_id: str, conversation_id: str) -> None:
"""Remove the active WebSocket record for user_id + conversation_id."""
if user_id in _connected_users:
_connected_users[user_id].discard(conversation_id)
if not _connected_users[user_id]:
del _connected_users[user_id]
def is_user_connected(user_id: str) -> bool:
"""Return True if the user has any active WebSocket connection."""
return user_id in _connected_users and bool(_connected_users[user_id])
async def _send_push_notification(
    user_id: str,
    title: str,
    body: str,
    conversation_id: str | None = None,
) -> None:
    """
    Fire-and-forget push notification delivery.

    Queries push_subscriptions for the user and calls pywebpush directly.
    Deletes stale (410 Gone) subscriptions automatically.
    Silently ignores errors — push is best-effort.
    """
    # Imported lazily to avoid a module-level dependency cycle with shared.api.push.
    from shared.api.push import _send_push
    from shared.models.push import PushSubscription

    try:
        target_uuid = uuid.UUID(user_id)
        notification = {
            "title": title,
            "body": body,
            "data": {"conversationId": conversation_id},
        }
        async with async_session_factory() as session:
            rows = await session.execute(
                select(PushSubscription).where(PushSubscription.user_id == target_uuid)
            )
            subs = rows.scalars().all()
            if not subs:
                return

            gone_endpoints: list[str] = []
            for subscription in subs:
                try:
                    delivered = await _send_push(subscription, notification)
                    if not delivered:
                        gone_endpoints.append(subscription.endpoint)
                except Exception as exc:
                    logger.warning("Push delivery failed for user=%s: %s", user_id, exc)

            # Prune subscriptions the push service reported as gone (410).
            if gone_endpoints:
                await session.execute(
                    delete(PushSubscription).where(
                        PushSubscription.user_id == target_uuid,
                        PushSubscription.endpoint.in_(gone_endpoints),
                    )
                )
                await session.commit()
    except Exception as exc:
        logger.warning("Push notification send error for user=%s: %s", user_id, exc)
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
"""
@@ -164,223 +252,248 @@ async def _handle_websocket_connection(
user_id_str, user_role, tenant_id_str, conversation_id,
)
# Track this user as connected (for push notification gating)
_mark_connected(user_id_str, conversation_id)
# -------------------------------------------------------------------------
# Step 2: Message loop
# -------------------------------------------------------------------------
while True:
try:
msg_data = await websocket.receive_json()
except (WebSocketDisconnect, Exception):
break
try:
while True:
try:
msg_data = await websocket.receive_json()
except (WebSocketDisconnect, Exception):
break
if msg_data.get("type") != "message":
continue
if msg_data.get("type") != "message":
continue
text_content: str = msg_data.get("text", "") or ""
agent_id_str: str = msg_data.get("agentId", "") or ""
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
display_name: str = msg_data.get("displayName", "Portal User")
text_content: str = msg_data.get("text", "") or ""
agent_id_str: str = msg_data.get("agentId", "") or ""
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
display_name: str = msg_data.get("displayName", "Portal User")
# -------------------------------------------------------------------
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
# -------------------------------------------------------------------
await websocket.send_json({"type": "typing"})
# -------------------------------------------------------------------
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
# -------------------------------------------------------------------
await websocket.send_json({"type": "typing"})
# -------------------------------------------------------------------
# b. Save user message to web_conversation_messages
# -------------------------------------------------------------------
configure_rls_hook(engine)
rls_token = current_tenant_id.set(tenant_uuid)
saved_conversation_id = msg_conversation_id
# -------------------------------------------------------------------
# b. Save user message to web_conversation_messages
# -------------------------------------------------------------------
configure_rls_hook(engine)
rls_token = current_tenant_id.set(tenant_uuid)
saved_conversation_id = msg_conversation_id
try:
async with async_session_factory() as session:
# Look up the conversation to get tenant-scoped context
conv_stmt = select(WebConversation).where(
WebConversation.id == uuid.UUID(msg_conversation_id)
try:
async with async_session_factory() as session:
# Look up the conversation to get tenant-scoped context
conv_stmt = select(WebConversation).where(
WebConversation.id == uuid.UUID(msg_conversation_id)
)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
if conversation is not None:
# Save user message
user_msg = WebConversationMessage(
conversation_id=uuid.UUID(msg_conversation_id),
tenant_id=tenant_uuid,
role="user",
content=text_content,
)
session.add(user_msg)
# Update conversation timestamp
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(msg_conversation_id)},
)
await session.commit()
saved_conversation_id = msg_conversation_id
except Exception:
logger.exception(
"Failed to save user message for conversation=%s", msg_conversation_id
)
conv_result = await session.execute(conv_stmt)
conversation = conv_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token)
if conversation is not None:
# Save user message
user_msg = WebConversationMessage(
conversation_id=uuid.UUID(msg_conversation_id),
tenant_id=tenant_uuid,
role="user",
content=text_content,
)
session.add(user_msg)
# Update conversation timestamp
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(msg_conversation_id)},
)
await session.commit()
saved_conversation_id = msg_conversation_id
except Exception:
logger.exception(
"Failed to save user message for conversation=%s", msg_conversation_id
)
finally:
current_tenant_id.reset(rls_token)
# -------------------------------------------------------------------
# c. Build KonstructMessage and stream LLM response DIRECTLY
#
# Bypasses Celery entirely for web chat — calls the LLM pool's
# streaming endpoint from the WebSocket handler. This eliminates
# ~5-10s of Celery queue + Redis pub-sub round-trip overhead.
# Slack/WhatsApp still use Celery (async webhook pattern).
# -------------------------------------------------------------------
event = {
"text": text_content,
"tenant_id": tenant_id_str,
"agent_id": agent_id_str,
"user_id": user_id_str,
"display_name": display_name,
"conversation_id": saved_conversation_id,
}
normalized_msg = normalize_web_event(event)
# Load agent for this tenant
agent: Agent | None = None
rls_token3 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
from sqlalchemy import select as sa_select
agent_stmt = sa_select(Agent).where(
Agent.tenant_id == tenant_uuid,
Agent.is_active == True,
).limit(1)
agent_result = await session.execute(agent_stmt)
agent = agent_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token3)
if agent is None:
await websocket.send_json({
"type": "done",
"text": "No active AI employee is configured for this workspace.",
"conversation_id": saved_conversation_id,
})
continue
# Build memory-enriched messages (Redis sliding window only — fast)
redis_mem = aioredis.from_url(settings.redis_url)
try:
recent_messages = await get_recent_messages(
redis_mem, tenant_id_str, str(agent.id), user_id_str
)
finally:
await redis_mem.aclose()
enriched_messages = build_messages_with_memory(
agent=agent,
current_message=text_content,
recent_messages=recent_messages,
relevant_context=[],
channel="web",
)
# Stream LLM response directly to WebSocket — no Celery, no pub-sub
response_text = ""
try:
async for token in run_agent_streaming(
msg=normalized_msg,
agent=agent,
messages=enriched_messages,
):
response_text += token
try:
await websocket.send_json({"type": "chunk", "text": token})
except Exception:
break # Client disconnected
except Exception:
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
if not response_text:
response_text = "I encountered an error processing your message. Please try again."
# Save to Redis sliding window (fire-and-forget, non-blocking)
redis_mem2 = aioredis.from_url(settings.redis_url)
try:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content)
if response_text:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text)
finally:
await redis_mem2.aclose()
# Fire-and-forget embedding for long-term memory
try:
embed_and_store.delay({
# -------------------------------------------------------------------
# c. Build KonstructMessage and stream LLM response DIRECTLY
#
# Bypasses Celery entirely for web chat — calls the LLM pool's
# streaming endpoint from the WebSocket handler. This eliminates
# ~5-10s of Celery queue + Redis pub-sub round-trip overhead.
# Slack/WhatsApp still use Celery (async webhook pattern).
# -------------------------------------------------------------------
event = {
"text": text_content,
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"agent_id": agent_id_str,
"user_id": user_id_str,
"role": "user",
"content": text_content,
})
if response_text:
"display_name": display_name,
"conversation_id": saved_conversation_id,
}
normalized_msg = normalize_web_event(event)
# Load agent for this tenant
agent: Agent | None = None
rls_token3 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
from sqlalchemy import select as sa_select
agent_stmt = sa_select(Agent).where(
Agent.tenant_id == tenant_uuid,
Agent.is_active == True,
).limit(1)
agent_result = await session.execute(agent_stmt)
agent = agent_result.scalar_one_or_none()
finally:
current_tenant_id.reset(rls_token3)
if agent is None:
await websocket.send_json({
"type": "done",
"text": "No active AI employee is configured for this workspace.",
"conversation_id": saved_conversation_id,
})
continue
# Build memory-enriched messages (Redis sliding window only — fast)
redis_mem = aioredis.from_url(settings.redis_url)
try:
recent_messages = await get_recent_messages(
redis_mem, tenant_id_str, str(agent.id), user_id_str
)
finally:
await redis_mem.aclose()
enriched_messages = build_messages_with_memory(
agent=agent,
current_message=text_content,
recent_messages=recent_messages,
relevant_context=[],
channel="web",
)
# Stream LLM response directly to WebSocket — no Celery, no pub-sub
response_text = ""
ws_disconnected_during_stream = False
try:
async for token in run_agent_streaming(
msg=normalized_msg,
agent=agent,
messages=enriched_messages,
):
response_text += token
try:
await websocket.send_json({"type": "chunk", "text": token})
except Exception:
ws_disconnected_during_stream = True
break # Client disconnected
except Exception:
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
if not response_text:
response_text = "I encountered an error processing your message. Please try again."
# Save to Redis sliding window (fire-and-forget, non-blocking)
redis_mem2 = aioredis.from_url(settings.redis_url)
try:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content)
if response_text:
await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text)
finally:
await redis_mem2.aclose()
# Fire-and-forget embedding for long-term memory
try:
embed_and_store.delay({
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"user_id": user_id_str,
"role": "assistant",
"content": response_text,
"role": "user",
"content": text_content,
})
except Exception:
pass # Non-fatal — memory will rebuild over time
# -------------------------------------------------------------------
# e. Save assistant message and send final "done" to client
# -------------------------------------------------------------------
if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
assistant_msg = WebConversationMessage(
conversation_id=uuid.UUID(saved_conversation_id),
tenant_id=tenant_uuid,
role="assistant",
content=response_text,
)
session.add(assistant_msg)
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(saved_conversation_id)},
)
await session.commit()
if response_text:
embed_and_store.delay({
"tenant_id": tenant_id_str,
"agent_id": str(agent.id),
"user_id": user_id_str,
"role": "assistant",
"content": response_text,
})
except Exception:
logger.exception(
"Failed to save assistant message for conversation=%s", saved_conversation_id
pass # Non-fatal — memory will rebuild over time
# -------------------------------------------------------------------
# e. Save assistant message and send final "done" to client
# -------------------------------------------------------------------
if response_text:
rls_token2 = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
assistant_msg = WebConversationMessage(
conversation_id=uuid.UUID(saved_conversation_id),
tenant_id=tenant_uuid,
role="assistant",
content=response_text,
)
session.add(assistant_msg)
await session.execute(
text(
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
),
{"conv_id": str(saved_conversation_id)},
)
await session.commit()
except Exception:
logger.exception(
"Failed to save assistant message for conversation=%s", saved_conversation_id
)
finally:
current_tenant_id.reset(rls_token2)
# If user disconnected during streaming, send push notification
if ws_disconnected_during_stream or not is_user_connected(user_id_str):
agent_name = agent.name if hasattr(agent, "name") and agent.name else "Your AI Employee"
preview = response_text[:100] + ("..." if len(response_text) > 100 else "")
asyncio.create_task(
_send_push_notification(
user_id=user_id_str,
title=f"{agent_name} replied",
body=preview,
conversation_id=saved_conversation_id,
)
)
if ws_disconnected_during_stream:
break # Stop the message loop — WS is gone
# Signal stream completion to the client
try:
await websocket.send_json({
"type": "done",
"text": response_text,
"conversation_id": saved_conversation_id,
})
except Exception:
pass # Client already disconnected
else:
logger.warning(
"No response received for conversation=%s", saved_conversation_id,
)
finally:
current_tenant_id.reset(rls_token2)
try:
await websocket.send_json({
"type": "error",
"message": "I'm having trouble responding right now. Please try again.",
})
except Exception:
pass # Client already disconnected
# Signal stream completion to the client
try:
await websocket.send_json({
"type": "done",
"text": response_text,
"conversation_id": saved_conversation_id,
})
except Exception:
pass # Client already disconnected
else:
logger.warning(
"No response received for conversation=%s", saved_conversation_id,
)
try:
await websocket.send_json({
"type": "error",
"message": "I'm having trouble responding right now. Please try again.",
})
except Exception:
pass # Client already disconnected
finally:
# Always untrack this user when connection ends
_mark_disconnected(user_id_str, conversation_id)
@web_chat_router.websocket("/chat/ws/{conversation_id}")

View File

@@ -48,6 +48,7 @@ from shared.api import (
invitations_router,
llm_keys_router,
portal_router,
push_router,
templates_router,
usage_router,
webhook_router,
@@ -158,6 +159,11 @@ app.include_router(templates_router)
app.include_router(chat_router) # REST: /api/portal/chat/*
app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id}
# ---------------------------------------------------------------------------
# Phase 8 Push Notification router
# ---------------------------------------------------------------------------
app.include_router(push_router) # Push subscribe/unsubscribe/send
# ---------------------------------------------------------------------------
# Routes

View File

@@ -18,6 +18,7 @@ dependencies = [
"httpx>=0.28.0",
"redis>=5.0.0",
"boto3>=1.35.0",
"pywebpush>=2.0.0",
]
[tool.uv.sources]

Submodule packages/portal updated: acba978f2f...c35a982236

View File

@@ -10,6 +10,7 @@ from shared.api.chat import chat_router
from shared.api.invitations import invitations_router
from shared.api.llm_keys import llm_keys_router
from shared.api.portal import portal_router
from shared.api.push import push_router
from shared.api.templates import templates_router
from shared.api.usage import usage_router
@@ -23,4 +24,5 @@ __all__ = [
"invitations_router",
"templates_router",
"chat_router",
"push_router",
]

View File

@@ -0,0 +1,232 @@
"""
FastAPI push notification API — subscription management and send endpoint.
Provides Web Push subscription storage so the gateway can deliver
push notifications when an AI employee responds and the user's
WebSocket is not connected.
Endpoints:
POST /api/portal/push/subscribe — store browser push subscription
DELETE /api/portal/push/unsubscribe — remove subscription by endpoint
POST /api/portal/push/send — internal: send push to user (called by WS handler)
Authentication:
subscribe / unsubscribe: require portal user headers (X-Portal-User-Id)
send: internal endpoint — requires same portal headers but is called by
the gateway WebSocket handler when user is offline
Push delivery:
Uses pywebpush for VAPID-signed Web Push delivery.
Handles 410 Gone responses by deleting stale subscriptions.
"""
from __future__ import annotations
import json
import logging
import os
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from shared.api.rbac import PortalCaller, get_portal_caller
from shared.db import get_session
from shared.models.push import PushSubscription, PushSubscriptionCreate, PushSubscriptionOut, PushSendRequest
logger = logging.getLogger(__name__)
push_router = APIRouter(prefix="/api/portal/push", tags=["push"])
# ---------------------------------------------------------------------------
# VAPID config (read from environment at import time)
# ---------------------------------------------------------------------------
# Private signing key; when empty, push delivery is skipped (see _send_push).
VAPID_PRIVATE_KEY: str = os.environ.get("VAPID_PRIVATE_KEY", "")
# Public key counterpart; shared with the browser via NEXT_PUBLIC_ env naming.
VAPID_PUBLIC_KEY: str = os.environ.get("NEXT_PUBLIC_VAPID_PUBLIC_KEY", "")
# Contact address embedded in the VAPID "sub" claim sent to push services.
VAPID_CLAIMS_EMAIL: str = os.environ.get("VAPID_CLAIMS_EMAIL", "admin@konstruct.dev")
# ---------------------------------------------------------------------------
# Helper — send a single push notification via pywebpush
# ---------------------------------------------------------------------------
async def _send_push(subscription: PushSubscription, payload: dict[str, object]) -> bool:
    """
    Send a Web Push notification to a single subscription.

    pywebpush's webpush() performs a synchronous HTTP request, so it is
    offloaded to a worker thread — calling it inline would block the
    event loop for the duration of the push-service round trip.

    Returns True on success, False if the subscription is stale
    (410 Gone / 404 Not Found) and should be deleted by the caller.
    Raises on other errors so the caller can decide how to handle them.
    """
    if not VAPID_PRIVATE_KEY:
        logger.warning("VAPID_PRIVATE_KEY not set — skipping push notification")
        return True

    import asyncio

    from pywebpush import WebPushException, webpush  # type: ignore[import]

    subscription_info = {
        "endpoint": subscription.endpoint,
        "keys": {
            "p256dh": subscription.p256dh,
            "auth": subscription.auth,
        },
    }
    try:
        await asyncio.to_thread(
            webpush,
            subscription_info=subscription_info,
            data=json.dumps(payload),
            vapid_private_key=VAPID_PRIVATE_KEY,
            vapid_claims={
                "sub": f"mailto:{VAPID_CLAIMS_EMAIL}",
            },
        )
        return True
    except WebPushException as exc:
        # Prefer the structured status code over string matching; fall back
        # to the message for push services that omit a response object.
        status_code = getattr(exc.response, "status_code", None)
        message = str(exc)
        if (
            status_code in (404, 410)
            or "410" in message
            or "gone" in message.lower()
            or "expired" in message.lower()
        ):
            logger.info("Push subscription stale (410 Gone): %s", subscription.endpoint[:40])
            return False
        logger.error("Push delivery failed: %s", message)
        raise
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@push_router.post("/subscribe", status_code=status.HTTP_201_CREATED, response_model=PushSubscriptionOut)
async def subscribe(
    body: PushSubscriptionCreate,
    caller: PortalCaller = Depends(get_portal_caller),
    session: AsyncSession = Depends(get_session),
) -> PushSubscriptionOut:
    """
    Store a browser push subscription for the authenticated user.

    Uses INSERT ... ON CONFLICT (user_id, endpoint) DO UPDATE so
    re-subscribing the same browser updates the keys without creating
    a duplicate row. The conflict branch also refreshes updated_at
    explicitly: ORM-level onupdate hooks do not fire for core upserts,
    so without it the timestamp would stay stale on re-subscribe.

    Raises 422 if tenant_id is present but not a valid UUID (consistent
    with the user_id validation in /send).
    """
    from sqlalchemy import func

    try:
        tenant_uuid = uuid.UUID(body.tenant_id) if body.tenant_id else None
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid tenant_id"
        ) from exc

    stmt = (
        pg_insert(PushSubscription)
        .values(
            user_id=caller.user_id,
            tenant_id=tenant_uuid,
            endpoint=body.endpoint,
            p256dh=body.p256dh,
            auth=body.auth,
        )
        .on_conflict_do_update(
            constraint="uq_push_user_endpoint",
            set_={
                "p256dh": body.p256dh,
                "auth": body.auth,
                "tenant_id": tenant_uuid,
                "updated_at": func.now(),
            },
        )
        .returning(PushSubscription)
    )
    result = await session.execute(stmt)
    row = result.scalar_one()
    await session.commit()
    return PushSubscriptionOut(
        id=str(row.id),
        endpoint=row.endpoint,
        created_at=row.created_at,
    )
class UnsubscribeRequest(BaseModel):
    """Payload for DELETE /api/portal/push/unsubscribe."""

    # Push service URL identifying the browser subscription to remove.
    endpoint: str
@push_router.delete("/unsubscribe", status_code=status.HTTP_204_NO_CONTENT)
async def unsubscribe(
    body: UnsubscribeRequest,
    caller: PortalCaller = Depends(get_portal_caller),
    session: AsyncSession = Depends(get_session),
) -> None:
    """Remove a push subscription for the authenticated user.

    Deleting an endpoint that does not exist is a silent no-op; the
    response is 204 either way.
    """
    stmt = delete(PushSubscription).where(
        PushSubscription.user_id == caller.user_id,
        PushSubscription.endpoint == body.endpoint,
    )
    await session.execute(stmt)
    await session.commit()
@push_router.post("/send", status_code=status.HTTP_200_OK)
async def send_push(
    body: PushSendRequest,
    caller: PortalCaller = Depends(get_portal_caller),
    session: AsyncSession = Depends(get_session),
) -> dict[str, object]:
    """
    Internal endpoint — deliver a push notification to every subscription
    registered for a user.

    Called by the gateway WebSocket handler when the agent responds but
    the user's WebSocket is no longer connected. Subscriptions that report
    410 Gone are pruned. Returns delivered / stale / total counts.
    """
    try:
        target_user_id = uuid.UUID(body.user_id)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid user_id") from exc

    # All subscriptions for this user — one per browser endpoint.
    rows = await session.execute(
        select(PushSubscription).where(PushSubscription.user_id == target_user_id)
    )
    subs = rows.scalars().all()
    if not subs:
        return {"delivered": 0, "stale": 0, "total": 0}

    notification = {
        "title": body.title,
        "body": body.body,
        "data": {
            "conversationId": body.conversation_id,
        },
    }

    delivered_count = 0
    gone_endpoints: list[str] = []
    for subscription in subs:
        try:
            if await _send_push(subscription, notification):
                delivered_count += 1
            else:
                gone_endpoints.append(subscription.endpoint)
        except Exception as exc:
            logger.error("Push send error for user %s: %s", body.user_id, exc)

    # Prune subscriptions the push service reported as gone.
    if gone_endpoints:
        await session.execute(
            delete(PushSubscription).where(
                PushSubscription.user_id == target_user_id,
                PushSubscription.endpoint.in_(gone_endpoints),
            )
        )
        await session.commit()

    return {
        "delivered": delivered_count,
        "stale": len(gone_endpoints),
        "total": len(subs),
    }

View File

@@ -0,0 +1,122 @@
"""
Push subscription model for Web Push notifications.
Stores browser push subscriptions for portal users so the gateway can
send push notifications when an AI employee responds and the user's
WebSocket is not connected.
Push subscriptions are per-user, per-browser-endpoint. No RLS is applied
to this table — the API filters by user_id in the query (push subscriptions
are portal-user-scoped, not tenant-scoped).
"""
from __future__ import annotations
import uuid
from datetime import datetime
from pydantic import BaseModel
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from shared.models.tenant import Base
class PushSubscription(Base):
    """
    Browser push subscription for a portal user.

    endpoint: The push service URL provided by the browser.
    p256dh: ECDH public key for message encryption.
    auth: Auth secret for message encryption.

    Unique constraint on (user_id, endpoint) — one subscription per
    browser per user. Upsert on conflict avoids duplicates on re-subscribe.
    No RLS on this table; callers filter by user_id at the application layer.
    """

    __tablename__ = "push_subscriptions"
    __table_args__ = (
        # Named so the subscribe API can target it with ON CONFLICT.
        UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"),
    )

    # Server-generated UUID primary key (gen_random_uuid()).
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        server_default=func.gen_random_uuid(),
    )
    # Owning portal user; cascade delete removes subscriptions with the user.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("portal_users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Nullable: SET NULL keeps the user-scoped subscription if the tenant goes away.
    tenant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
        comment="Optional tenant scope for notification routing",
    )
    endpoint: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="Push service URL (browser-provided)",
    )
    p256dh: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="ECDH public key for payload encryption",
    )
    auth: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="Auth secret for payload encryption",
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    # NOTE: onupdate fires only for ORM flushes, not for core upserts.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    def __repr__(self) -> str:
        # Truncate the endpoint — push service URLs are long and opaque.
        return f"<PushSubscription user={self.user_id} endpoint={self.endpoint[:40]!r}>"
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class PushSubscriptionCreate(BaseModel):
    """Payload for POST /portal/push/subscribe."""

    # Push service URL from the browser's PushSubscription object.
    endpoint: str
    # ECDH public key (base64url) used to encrypt payloads.
    p256dh: str
    # Auth secret (base64url) paired with p256dh.
    auth: str
    # Optional tenant scope as a UUID string; validated by the endpoint.
    tenant_id: str | None = None
class PushSubscriptionOut(BaseModel):
    """Response body for subscription operations."""

    # Subscription row UUID, serialized as a string.
    id: str
    endpoint: str
    created_at: datetime

    # Allow construction directly from ORM rows.
    model_config = {"from_attributes": True}
class PushSendRequest(BaseModel):
    """Internal payload for POST /portal/push/send."""

    # Target portal user as a UUID string; validated by the endpoint.
    user_id: str
    # Notification title shown by the browser.
    title: str
    # Notification body text.
    body: str
    # Optional conversation to deep-link into when the notification is clicked.
    conversation_id: str | None = None