feat(08-03): push notification backend — DB model, migration, API router, VAPID setup

- Add PushSubscription ORM model with unique(user_id, endpoint) constraint
- Add Alembic migration 012 for push_subscriptions table
- Add push router (subscribe, unsubscribe, send) in shared/api/push.py
- Mount push router in gateway/main.py
- Add pywebpush to gateway dependencies for server-side VAPID delivery
- Wire push trigger into WebSocket handler (fires when client disconnects mid-stream)
- Add VAPID keys to .env / .env.example
- Add push/install i18n keys in en/es/pt message files
This commit is contained in:
2026-03-25 21:26:51 -06:00
parent 5c30651754
commit 7d3a393758
9 changed files with 774 additions and 199 deletions

View File

@@ -61,3 +61,11 @@ DEBUG=false
# Tenant rate limits (requests per minute defaults)
DEFAULT_RATE_LIMIT_RPM=60
# -----------------------------------------------------------------------------
# Web Push Notifications (VAPID keys)
# Generate with: cd packages/portal && npx web-push generate-vapid-keys
# -----------------------------------------------------------------------------
NEXT_PUBLIC_VAPID_PUBLIC_KEY=your-vapid-public-key
VAPID_PRIVATE_KEY=your-vapid-private-key
VAPID_CLAIMS_EMAIL=admin@yourdomain.com

View File

@@ -0,0 +1,91 @@
"""Push subscriptions table for Web Push notifications
Revision ID: 012
Revises: 011
Create Date: 2026-03-26
Creates the push_subscriptions table so the gateway can store browser
push subscriptions and deliver Web Push notifications when an AI employee
responds and the user's WebSocket is not connected.
No RLS policy is applied — the API filters by user_id at the application
layer (push subscriptions are portal-user-scoped, not tenant-scoped).
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "012"
down_revision: Union[str, None] = "011"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"push_subscriptions",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("portal_users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"tenant_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("tenants.id", ondelete="SET NULL"),
nullable=True,
comment="Optional tenant scope for notification routing",
),
sa.Column(
"endpoint",
sa.Text,
nullable=False,
comment="Push service URL (browser-provided)",
),
sa.Column(
"p256dh",
sa.Text,
nullable=False,
comment="ECDH public key for payload encryption",
),
sa.Column(
"auth",
sa.Text,
nullable=False,
comment="Auth secret for payload encryption",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"),
)
op.create_index("ix_push_subscriptions_user_id", "push_subscriptions", ["user_id"])
op.create_index("ix_push_subscriptions_tenant_id", "push_subscriptions", ["tenant_id"])
def downgrade() -> None:
op.drop_index("ix_push_subscriptions_tenant_id", table_name="push_subscriptions")
op.drop_index("ix_push_subscriptions_user_id", table_name="push_subscriptions")
op.drop_table("push_subscriptions")

View File

@@ -27,6 +27,11 @@ Design notes:
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
- WebSocket is a long-lived connection; each message/response cycle is synchronous
within the connection but non-blocking for other connections
Push notifications:
- Connected users are tracked in _connected_users (in-memory dict)
- When the WebSocket send for "done" raises (client disconnected mid-stream),
a push notification is fired so the user sees the response on their device.
"""
from __future__ import annotations
@@ -40,7 +45,7 @@ from typing import Any
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select, text
from sqlalchemy import delete, select, text
from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt
from orchestrator.agents.runner import run_agent_streaming
@@ -64,6 +69,89 @@ web_chat_router = APIRouter(tags=["web-chat"])
# Timeout for waiting for an agent response via Redis pub-sub (seconds)
_RESPONSE_TIMEOUT_SECONDS = 180
# ---------------------------------------------------------------------------
# Connected user tracking — used to decide whether to send push notifications
# ---------------------------------------------------------------------------
# Maps user_id -> set of conversation_ids with active WebSocket connections.
# When a user disconnects, their entry is removed. If the agent response
# finishes after disconnect, a push notification is sent.
_connected_users: dict[str, set[str]] = {}
def _mark_connected(user_id: str, conversation_id: str) -> None:
"""Record that user_id has an active WebSocket for conversation_id."""
if user_id not in _connected_users:
_connected_users[user_id] = set()
_connected_users[user_id].add(conversation_id)
def _mark_disconnected(user_id: str, conversation_id: str) -> None:
"""Remove the active WebSocket record for user_id + conversation_id."""
if user_id in _connected_users:
_connected_users[user_id].discard(conversation_id)
if not _connected_users[user_id]:
del _connected_users[user_id]
def is_user_connected(user_id: str) -> bool:
"""Return True if the user has any active WebSocket connection."""
return user_id in _connected_users and bool(_connected_users[user_id])
async def _send_push_notification(
user_id: str,
title: str,
body: str,
conversation_id: str | None = None,
) -> None:
"""
Fire-and-forget push notification delivery.
Queries push_subscriptions for the user and calls pywebpush directly.
Deletes stale (410 Gone) subscriptions automatically.
Silently ignores errors — push is best-effort.
"""
from shared.models.push import PushSubscription
from shared.api.push import _send_push
try:
user_uuid = uuid.UUID(user_id)
payload = {
"title": title,
"body": body,
"data": {"conversationId": conversation_id},
}
async with async_session_factory() as session:
result = await session.execute(
select(PushSubscription).where(PushSubscription.user_id == user_uuid)
)
subscriptions = result.scalars().all()
if not subscriptions:
return
stale_endpoints: list[str] = []
for sub in subscriptions:
try:
ok = await _send_push(sub, payload)
if not ok:
stale_endpoints.append(sub.endpoint)
except Exception as exc:
logger.warning("Push delivery failed for user=%s: %s", user_id, exc)
if stale_endpoints:
await session.execute(
delete(PushSubscription).where(
PushSubscription.user_id == user_uuid,
PushSubscription.endpoint.in_(stale_endpoints),
)
)
await session.commit()
except Exception as exc:
logger.warning("Push notification send error for user=%s: %s", user_id, exc)
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
"""
@@ -164,9 +252,13 @@ async def _handle_websocket_connection(
user_id_str, user_role, tenant_id_str, conversation_id,
)
# Track this user as connected (for push notification gating)
_mark_connected(user_id_str, conversation_id)
# -------------------------------------------------------------------------
# Step 2: Message loop
# -------------------------------------------------------------------------
try:
while True:
try:
msg_data = await websocket.receive_json()
@@ -288,6 +380,7 @@ async def _handle_websocket_connection(
# Stream LLM response directly to WebSocket — no Celery, no pub-sub
response_text = ""
ws_disconnected_during_stream = False
try:
async for token in run_agent_streaming(
msg=normalized_msg,
@@ -298,6 +391,7 @@ async def _handle_websocket_connection(
try:
await websocket.send_json({"type": "chunk", "text": token})
except Exception:
ws_disconnected_during_stream = True
break # Client disconnected
except Exception:
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id)
@@ -361,6 +455,21 @@ async def _handle_websocket_connection(
finally:
current_tenant_id.reset(rls_token2)
# If user disconnected during streaming, send push notification
if ws_disconnected_during_stream or not is_user_connected(user_id_str):
agent_name = agent.name if hasattr(agent, "name") and agent.name else "Your AI Employee"
preview = response_text[:100] + ("..." if len(response_text) > 100 else "")
asyncio.create_task(
_send_push_notification(
user_id=user_id_str,
title=f"{agent_name} replied",
body=preview,
conversation_id=saved_conversation_id,
)
)
if ws_disconnected_during_stream:
break # Stop the message loop — WS is gone
# Signal stream completion to the client
try:
await websocket.send_json({
@@ -382,6 +491,10 @@ async def _handle_websocket_connection(
except Exception:
pass # Client already disconnected
finally:
# Always untrack this user when connection ends
_mark_disconnected(user_id_str, conversation_id)
@web_chat_router.websocket("/chat/ws/{conversation_id}")
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:

View File

@@ -48,6 +48,7 @@ from shared.api import (
invitations_router,
llm_keys_router,
portal_router,
push_router,
templates_router,
usage_router,
webhook_router,
@@ -158,6 +159,11 @@ app.include_router(templates_router)
app.include_router(chat_router) # REST: /api/portal/chat/*
app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id}
# ---------------------------------------------------------------------------
# Phase 8 Push Notification router
# ---------------------------------------------------------------------------
app.include_router(push_router) # Push subscribe/unsubscribe/send
# ---------------------------------------------------------------------------
# Routes

View File

@@ -18,6 +18,7 @@ dependencies = [
"httpx>=0.28.0",
"redis>=5.0.0",
"boto3>=1.35.0",
"pywebpush>=2.0.0",
]
[tool.uv.sources]

Submodule packages/portal updated: acba978f2f...c35a982236

View File

@@ -10,6 +10,7 @@ from shared.api.chat import chat_router
from shared.api.invitations import invitations_router
from shared.api.llm_keys import llm_keys_router
from shared.api.portal import portal_router
from shared.api.push import push_router
from shared.api.templates import templates_router
from shared.api.usage import usage_router
@@ -23,4 +24,5 @@ __all__ = [
"invitations_router",
"templates_router",
"chat_router",
"push_router",
]

View File

@@ -0,0 +1,232 @@
"""
FastAPI push notification API — subscription management and send endpoint.
Provides Web Push subscription storage so the gateway can deliver
push notifications when an AI employee responds and the user's
WebSocket is not connected.
Endpoints:
POST /api/portal/push/subscribe — store browser push subscription
DELETE /api/portal/push/unsubscribe — remove subscription by endpoint
POST /api/portal/push/send — internal: send push to user (called by WS handler)
Authentication:
subscribe / unsubscribe: require portal user headers (X-Portal-User-Id)
send: internal endpoint — requires same portal headers but is called by
the gateway WebSocket handler when user is offline
Push delivery:
Uses pywebpush for VAPID-signed Web Push delivery.
Handles 410 Gone responses by deleting stale subscriptions.
"""
from __future__ import annotations
import json
import logging
import os
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from shared.api.rbac import PortalCaller, get_portal_caller
from shared.db import get_session
from shared.models.push import PushSubscription, PushSubscriptionCreate, PushSubscriptionOut, PushSendRequest
logger = logging.getLogger(__name__)
push_router = APIRouter(prefix="/api/portal/push", tags=["push"])
# ---------------------------------------------------------------------------
# VAPID config (read from environment at import time)
# ---------------------------------------------------------------------------
VAPID_PRIVATE_KEY: str = os.environ.get("VAPID_PRIVATE_KEY", "")
VAPID_PUBLIC_KEY: str = os.environ.get("NEXT_PUBLIC_VAPID_PUBLIC_KEY", "")
VAPID_CLAIMS_EMAIL: str = os.environ.get("VAPID_CLAIMS_EMAIL", "admin@konstruct.dev")
# ---------------------------------------------------------------------------
# Helper — send a single push notification via pywebpush
# ---------------------------------------------------------------------------
async def _send_push(subscription: PushSubscription, payload: dict[str, object]) -> bool:
"""
Send a Web Push notification to a single subscription.
Returns True on success, False if the subscription is stale (410 Gone).
Raises on other errors so the caller can decide how to handle them.
"""
if not VAPID_PRIVATE_KEY:
logger.warning("VAPID_PRIVATE_KEY not set — skipping push notification")
return True
try:
from pywebpush import WebPusher, webpush, WebPushException # type: ignore[import]
subscription_info = {
"endpoint": subscription.endpoint,
"keys": {
"p256dh": subscription.p256dh,
"auth": subscription.auth,
},
}
webpush(
subscription_info=subscription_info,
data=json.dumps(payload),
vapid_private_key=VAPID_PRIVATE_KEY,
vapid_claims={
"sub": f"mailto:{VAPID_CLAIMS_EMAIL}",
},
)
return True
except Exception as exc:
# Check for 410 Gone — subscription is no longer valid
exc_str = str(exc)
if "410" in exc_str or "Gone" in exc_str or "expired" in exc_str.lower():
logger.info("Push subscription stale (410 Gone): %s", subscription.endpoint[:40])
return False
logger.error("Push delivery failed: %s", exc_str)
raise
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@push_router.post("/subscribe", status_code=status.HTTP_201_CREATED, response_model=PushSubscriptionOut)
async def subscribe(
body: PushSubscriptionCreate,
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> PushSubscriptionOut:
"""
Store a browser push subscription for the authenticated user.
Uses INSERT ... ON CONFLICT (user_id, endpoint) DO UPDATE so
re-subscribing the same browser updates the keys without creating
a duplicate row.
"""
stmt = (
pg_insert(PushSubscription)
.values(
user_id=caller.user_id,
tenant_id=uuid.UUID(body.tenant_id) if body.tenant_id else None,
endpoint=body.endpoint,
p256dh=body.p256dh,
auth=body.auth,
)
.on_conflict_do_update(
constraint="uq_push_user_endpoint",
set_={
"p256dh": body.p256dh,
"auth": body.auth,
"tenant_id": uuid.UUID(body.tenant_id) if body.tenant_id else None,
},
)
.returning(PushSubscription)
)
result = await session.execute(stmt)
row = result.scalar_one()
await session.commit()
return PushSubscriptionOut(
id=str(row.id),
endpoint=row.endpoint,
created_at=row.created_at,
)
class UnsubscribeRequest(BaseModel):
endpoint: str
@push_router.delete("/unsubscribe", status_code=status.HTTP_204_NO_CONTENT)
async def unsubscribe(
body: UnsubscribeRequest,
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> None:
"""Remove a push subscription for the authenticated user."""
await session.execute(
delete(PushSubscription).where(
PushSubscription.user_id == caller.user_id,
PushSubscription.endpoint == body.endpoint,
)
)
await session.commit()
@push_router.post("/send", status_code=status.HTTP_200_OK)
async def send_push(
body: PushSendRequest,
caller: PortalCaller = Depends(get_portal_caller),
session: AsyncSession = Depends(get_session),
) -> dict[str, object]:
"""
Internal endpoint — send a push notification to all subscriptions for a user.
Called by the gateway WebSocket handler when the agent responds but
the user's WebSocket is no longer connected.
Handles 410 Gone by deleting stale subscriptions.
Returns counts of delivered and stale subscriptions.
"""
try:
target_user_id = uuid.UUID(body.user_id)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid user_id") from exc
# Fetch all subscriptions for this user
result = await session.execute(
select(PushSubscription).where(PushSubscription.user_id == target_user_id)
)
subscriptions = result.scalars().all()
if not subscriptions:
return {"delivered": 0, "stale": 0, "total": 0}
payload = {
"title": body.title,
"body": body.body,
"data": {
"conversationId": body.conversation_id,
},
}
delivered = 0
stale_endpoints: list[str] = []
for sub in subscriptions:
try:
ok = await _send_push(sub, payload)
if ok:
delivered += 1
else:
stale_endpoints.append(sub.endpoint)
except Exception as exc:
logger.error("Push send error for user %s: %s", body.user_id, exc)
# Delete stale subscriptions
if stale_endpoints:
await session.execute(
delete(PushSubscription).where(
PushSubscription.user_id == target_user_id,
PushSubscription.endpoint.in_(stale_endpoints),
)
)
await session.commit()
return {
"delivered": delivered,
"stale": len(stale_endpoints),
"total": len(subscriptions),
}

View File

@@ -0,0 +1,122 @@
"""
Push subscription model for Web Push notifications.
Stores browser push subscriptions for portal users so the gateway can
send push notifications when an AI employee responds and the user's
WebSocket is not connected.
Push subscriptions are per-user, per-browser-endpoint. No RLS is applied
to this table — the API filters by user_id in the query (push subscriptions
are portal-user-scoped, not tenant-scoped).
"""
from __future__ import annotations
import uuid
from datetime import datetime
from pydantic import BaseModel
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from shared.models.tenant import Base
class PushSubscription(Base):
"""
Browser push subscription for a portal user.
endpoint: The push service URL provided by the browser.
p256dh: ECDH public key for message encryption.
auth: Auth secret for message encryption.
Unique constraint on (user_id, endpoint) — one subscription per
browser per user. Upsert on conflict avoids duplicates on re-subscribe.
"""
__tablename__ = "push_subscriptions"
__table_args__ = (
UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
server_default=func.gen_random_uuid(),
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("portal_users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
tenant_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="Optional tenant scope for notification routing",
)
endpoint: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Push service URL (browser-provided)",
)
p256dh: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="ECDH public key for payload encryption",
)
auth: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Auth secret for payload encryption",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)
def __repr__(self) -> str:
return f"<PushSubscription user={self.user_id} endpoint={self.endpoint[:40]!r}>"
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class PushSubscriptionCreate(BaseModel):
"""Payload for POST /portal/push/subscribe."""
endpoint: str
p256dh: str
auth: str
tenant_id: str | None = None
class PushSubscriptionOut(BaseModel):
"""Response body for subscription operations."""
id: str
endpoint: str
created_at: datetime
model_config = {"from_attributes": True}
class PushSendRequest(BaseModel):
"""Internal payload for POST /portal/push/send."""
user_id: str
title: str
body: str
conversation_id: str | None = None