diff --git a/.env.example b/.env.example index 1e86ca6..1690eb1 100644 --- a/.env.example +++ b/.env.example @@ -61,3 +61,11 @@ DEBUG=false # Tenant rate limits (requests per minute defaults) DEFAULT_RATE_LIMIT_RPM=60 + +# ----------------------------------------------------------------------------- +# Web Push Notifications (VAPID keys) +# Generate with: cd packages/portal && npx web-push generate-vapid-keys +# ----------------------------------------------------------------------------- +NEXT_PUBLIC_VAPID_PUBLIC_KEY=your-vapid-public-key +VAPID_PRIVATE_KEY=your-vapid-private-key +VAPID_CLAIMS_EMAIL=admin@yourdomain.com diff --git a/migrations/versions/012_push_subscriptions.py b/migrations/versions/012_push_subscriptions.py new file mode 100644 index 0000000..f1f87b8 --- /dev/null +++ b/migrations/versions/012_push_subscriptions.py @@ -0,0 +1,91 @@ +"""Push subscriptions table for Web Push notifications + +Revision ID: 012 +Revises: 011 +Create Date: 2026-03-26 + +Creates the push_subscriptions table so the gateway can store browser +push subscriptions and deliver Web Push notifications when an AI employee +responds and the user's WebSocket is not connected. + +No RLS policy is applied — the API filters by user_id at the application +layer (push subscriptions are portal-user-scoped, not tenant-scoped). 
+""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "012" +down_revision: Union[str, None] = "011" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "push_subscriptions", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + server_default=sa.text("gen_random_uuid()"), + nullable=False, + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("portal_users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tenant_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="SET NULL"), + nullable=True, + comment="Optional tenant scope for notification routing", + ), + sa.Column( + "endpoint", + sa.Text, + nullable=False, + comment="Push service URL (browser-provided)", + ), + sa.Column( + "p256dh", + sa.Text, + nullable=False, + comment="ECDH public key for payload encryption", + ), + sa.Column( + "auth", + sa.Text, + nullable=False, + comment="Auth secret for payload encryption", + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"), + ) + op.create_index("ix_push_subscriptions_user_id", "push_subscriptions", ["user_id"]) + op.create_index("ix_push_subscriptions_tenant_id", "push_subscriptions", ["tenant_id"]) + + +def downgrade() -> None: + op.drop_index("ix_push_subscriptions_tenant_id", table_name="push_subscriptions") + op.drop_index("ix_push_subscriptions_user_id", table_name="push_subscriptions") + op.drop_table("push_subscriptions") diff --git 
a/packages/gateway/gateway/channels/web.py b/packages/gateway/gateway/channels/web.py index fda505e..2f71556 100644 --- a/packages/gateway/gateway/channels/web.py +++ b/packages/gateway/gateway/channels/web.py @@ -27,6 +27,11 @@ Design notes: - DB access uses configure_rls_hook + current_tenant_id context var per project pattern - WebSocket is a long-lived connection; each message/response cycle is synchronous within the connection but non-blocking for other connections + +Push notifications: + - Connected users are tracked in _connected_users (in-memory dict) + - When a WebSocket send for a streamed "chunk" raises (client disconnected mid-stream), + a push notification is fired after the response is complete so the user sees it on their device. """ from __future__ import annotations @@ -40,7 +45,7 @@ from typing import Any import redis.asyncio as aioredis from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from sqlalchemy import select, text +from sqlalchemy import delete, select, text from orchestrator.agents.builder import build_messages_with_memory, build_system_prompt from orchestrator.agents.runner import run_agent_streaming @@ -64,6 +69,89 @@ web_chat_router = APIRouter(tags=["web-chat"]) # Timeout for waiting for an agent response via Redis pub-sub (seconds) _RESPONSE_TIMEOUT_SECONDS = 180 +# --------------------------------------------------------------------------- +# Connected user tracking — used to decide whether to send push notifications +# --------------------------------------------------------------------------- +# Maps user_id -> set of conversation_ids with active WebSocket connections. +# When a user disconnects, their entry is removed. If the agent response +# finishes after disconnect, a push notification is sent. 
+_connected_users: dict[str, set[str]] = {} + + +def _mark_connected(user_id: str, conversation_id: str) -> None: + """Record that user_id has an active WebSocket for conversation_id.""" + if user_id not in _connected_users: + _connected_users[user_id] = set() + _connected_users[user_id].add(conversation_id) + + +def _mark_disconnected(user_id: str, conversation_id: str) -> None: + """Remove the active WebSocket record for user_id + conversation_id.""" + if user_id in _connected_users: + _connected_users[user_id].discard(conversation_id) + if not _connected_users[user_id]: + del _connected_users[user_id] + + +def is_user_connected(user_id: str) -> bool: + """Return True if the user has any active WebSocket connection.""" + return user_id in _connected_users and bool(_connected_users[user_id]) + + +async def _send_push_notification( + user_id: str, + title: str, + body: str, + conversation_id: str | None = None, +) -> None: + """ + Fire-and-forget push notification delivery. + + Queries push_subscriptions for the user and calls pywebpush directly. + Deletes stale (410 Gone) subscriptions automatically. + Silently ignores errors — push is best-effort. 
+ """ + from shared.models.push import PushSubscription + from shared.api.push import _send_push + + try: + user_uuid = uuid.UUID(user_id) + payload = { + "title": title, + "body": body, + "data": {"conversationId": conversation_id}, + } + + async with async_session_factory() as session: + result = await session.execute( + select(PushSubscription).where(PushSubscription.user_id == user_uuid) + ) + subscriptions = result.scalars().all() + + if not subscriptions: + return + + stale_endpoints: list[str] = [] + for sub in subscriptions: + try: + ok = await _send_push(sub, payload) + if not ok: + stale_endpoints.append(sub.endpoint) + except Exception as exc: + logger.warning("Push delivery failed for user=%s: %s", user_id, exc) + + if stale_endpoints: + await session.execute( + delete(PushSubscription).where( + PushSubscription.user_id == user_uuid, + PushSubscription.endpoint.in_(stale_endpoints), + ) + ) + await session.commit() + + except Exception as exc: + logger.warning("Push notification send error for user=%s: %s", user_id, exc) + def normalize_web_event(event: dict[str, Any]) -> KonstructMessage: """ @@ -164,223 +252,248 @@ async def _handle_websocket_connection( user_id_str, user_role, tenant_id_str, conversation_id, ) + # Track this user as connected (for push notification gating) + _mark_connected(user_id_str, conversation_id) + # ------------------------------------------------------------------------- # Step 2: Message loop # ------------------------------------------------------------------------- - while True: - try: - msg_data = await websocket.receive_json() - except (WebSocketDisconnect, Exception): - break + try: + while True: + try: + msg_data = await websocket.receive_json() + except (WebSocketDisconnect, Exception): + break - if msg_data.get("type") != "message": - continue + if msg_data.get("type") != "message": + continue - text_content: str = msg_data.get("text", "") or "" - agent_id_str: str = msg_data.get("agentId", "") or "" - 
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id - display_name: str = msg_data.get("displayName", "Portal User") + text_content: str = msg_data.get("text", "") or "" + agent_id_str: str = msg_data.get("agentId", "") or "" + msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id + display_name: str = msg_data.get("displayName", "Portal User") - # ------------------------------------------------------------------- - # a. Send typing indicator IMMEDIATELY — before any DB or Celery work - # ------------------------------------------------------------------- - await websocket.send_json({"type": "typing"}) + # ------------------------------------------------------------------- + # a. Send typing indicator IMMEDIATELY — before any DB or Celery work + # ------------------------------------------------------------------- + await websocket.send_json({"type": "typing"}) - # ------------------------------------------------------------------- - # b. Save user message to web_conversation_messages - # ------------------------------------------------------------------- - configure_rls_hook(engine) - rls_token = current_tenant_id.set(tenant_uuid) - saved_conversation_id = msg_conversation_id + # ------------------------------------------------------------------- + # b. 
Save user message to web_conversation_messages + # ------------------------------------------------------------------- + configure_rls_hook(engine) + rls_token = current_tenant_id.set(tenant_uuid) + saved_conversation_id = msg_conversation_id - try: - async with async_session_factory() as session: - # Look up the conversation to get tenant-scoped context - conv_stmt = select(WebConversation).where( - WebConversation.id == uuid.UUID(msg_conversation_id) + try: + async with async_session_factory() as session: + # Look up the conversation to get tenant-scoped context + conv_stmt = select(WebConversation).where( + WebConversation.id == uuid.UUID(msg_conversation_id) + ) + conv_result = await session.execute(conv_stmt) + conversation = conv_result.scalar_one_or_none() + + if conversation is not None: + # Save user message + user_msg = WebConversationMessage( + conversation_id=uuid.UUID(msg_conversation_id), + tenant_id=tenant_uuid, + role="user", + content=text_content, + ) + session.add(user_msg) + + # Update conversation timestamp + await session.execute( + text( + "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id" + ), + {"conv_id": str(msg_conversation_id)}, + ) + await session.commit() + saved_conversation_id = msg_conversation_id + except Exception: + logger.exception( + "Failed to save user message for conversation=%s", msg_conversation_id ) - conv_result = await session.execute(conv_stmt) - conversation = conv_result.scalar_one_or_none() + finally: + current_tenant_id.reset(rls_token) - if conversation is not None: - # Save user message - user_msg = WebConversationMessage( - conversation_id=uuid.UUID(msg_conversation_id), - tenant_id=tenant_uuid, - role="user", - content=text_content, - ) - session.add(user_msg) - - # Update conversation timestamp - await session.execute( - text( - "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id" - ), - {"conv_id": str(msg_conversation_id)}, - ) - await session.commit() - 
saved_conversation_id = msg_conversation_id - except Exception: - logger.exception( - "Failed to save user message for conversation=%s", msg_conversation_id - ) - finally: - current_tenant_id.reset(rls_token) - - # ------------------------------------------------------------------- - # c. Build KonstructMessage and stream LLM response DIRECTLY - # - # Bypasses Celery entirely for web chat — calls the LLM pool's - # streaming endpoint from the WebSocket handler. This eliminates - # ~5-10s of Celery queue + Redis pub-sub round-trip overhead. - # Slack/WhatsApp still use Celery (async webhook pattern). - # ------------------------------------------------------------------- - event = { - "text": text_content, - "tenant_id": tenant_id_str, - "agent_id": agent_id_str, - "user_id": user_id_str, - "display_name": display_name, - "conversation_id": saved_conversation_id, - } - normalized_msg = normalize_web_event(event) - - # Load agent for this tenant - agent: Agent | None = None - rls_token3 = current_tenant_id.set(tenant_uuid) - try: - async with async_session_factory() as session: - from sqlalchemy import select as sa_select - agent_stmt = sa_select(Agent).where( - Agent.tenant_id == tenant_uuid, - Agent.is_active == True, - ).limit(1) - agent_result = await session.execute(agent_stmt) - agent = agent_result.scalar_one_or_none() - finally: - current_tenant_id.reset(rls_token3) - - if agent is None: - await websocket.send_json({ - "type": "done", - "text": "No active AI employee is configured for this workspace.", - "conversation_id": saved_conversation_id, - }) - continue - - # Build memory-enriched messages (Redis sliding window only — fast) - redis_mem = aioredis.from_url(settings.redis_url) - try: - recent_messages = await get_recent_messages( - redis_mem, tenant_id_str, str(agent.id), user_id_str - ) - finally: - await redis_mem.aclose() - - enriched_messages = build_messages_with_memory( - agent=agent, - current_message=text_content, - 
recent_messages=recent_messages, - relevant_context=[], - channel="web", - ) - - # Stream LLM response directly to WebSocket — no Celery, no pub-sub - response_text = "" - try: - async for token in run_agent_streaming( - msg=normalized_msg, - agent=agent, - messages=enriched_messages, - ): - response_text += token - try: - await websocket.send_json({"type": "chunk", "text": token}) - except Exception: - break # Client disconnected - except Exception: - logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id) - if not response_text: - response_text = "I encountered an error processing your message. Please try again." - - # Save to Redis sliding window (fire-and-forget, non-blocking) - redis_mem2 = aioredis.from_url(settings.redis_url) - try: - await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content) - if response_text: - await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text) - finally: - await redis_mem2.aclose() - - # Fire-and-forget embedding for long-term memory - try: - embed_and_store.delay({ + # ------------------------------------------------------------------- + # c. Build KonstructMessage and stream LLM response DIRECTLY + # + # Bypasses Celery entirely for web chat — calls the LLM pool's + # streaming endpoint from the WebSocket handler. This eliminates + # ~5-10s of Celery queue + Redis pub-sub round-trip overhead. + # Slack/WhatsApp still use Celery (async webhook pattern). 
+ # ------------------------------------------------------------------- + event = { + "text": text_content, "tenant_id": tenant_id_str, - "agent_id": str(agent.id), + "agent_id": agent_id_str, "user_id": user_id_str, - "role": "user", - "content": text_content, - }) - if response_text: + "display_name": display_name, + "conversation_id": saved_conversation_id, + } + normalized_msg = normalize_web_event(event) + + # Load agent for this tenant + agent: Agent | None = None + rls_token3 = current_tenant_id.set(tenant_uuid) + try: + async with async_session_factory() as session: + from sqlalchemy import select as sa_select + agent_stmt = sa_select(Agent).where( + Agent.tenant_id == tenant_uuid, + Agent.is_active == True, + ).limit(1) + agent_result = await session.execute(agent_stmt) + agent = agent_result.scalar_one_or_none() + finally: + current_tenant_id.reset(rls_token3) + + if agent is None: + await websocket.send_json({ + "type": "done", + "text": "No active AI employee is configured for this workspace.", + "conversation_id": saved_conversation_id, + }) + continue + + # Build memory-enriched messages (Redis sliding window only — fast) + redis_mem = aioredis.from_url(settings.redis_url) + try: + recent_messages = await get_recent_messages( + redis_mem, tenant_id_str, str(agent.id), user_id_str + ) + finally: + await redis_mem.aclose() + + enriched_messages = build_messages_with_memory( + agent=agent, + current_message=text_content, + recent_messages=recent_messages, + relevant_context=[], + channel="web", + ) + + # Stream LLM response directly to WebSocket — no Celery, no pub-sub + response_text = "" + ws_disconnected_during_stream = False + try: + async for token in run_agent_streaming( + msg=normalized_msg, + agent=agent, + messages=enriched_messages, + ): + response_text += token + try: + await websocket.send_json({"type": "chunk", "text": token}) + except Exception: + ws_disconnected_during_stream = True + break # Client disconnected + except Exception: + 
logger.exception("Direct streaming failed for conversation=%s", saved_conversation_id) + if not response_text: + response_text = "I encountered an error processing your message. Please try again." + + # Save to Redis sliding window (fire-and-forget, non-blocking) + redis_mem2 = aioredis.from_url(settings.redis_url) + try: + await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "user", text_content) + if response_text: + await append_message(redis_mem2, tenant_id_str, str(agent.id), user_id_str, "assistant", response_text) + finally: + await redis_mem2.aclose() + + # Fire-and-forget embedding for long-term memory + try: embed_and_store.delay({ "tenant_id": tenant_id_str, "agent_id": str(agent.id), "user_id": user_id_str, - "role": "assistant", - "content": response_text, + "role": "user", + "content": text_content, }) - except Exception: - pass # Non-fatal — memory will rebuild over time - - # ------------------------------------------------------------------- - # e. 
Save assistant message and send final "done" to client - # ------------------------------------------------------------------- - if response_text: - rls_token2 = current_tenant_id.set(tenant_uuid) - try: - async with async_session_factory() as session: - assistant_msg = WebConversationMessage( - conversation_id=uuid.UUID(saved_conversation_id), - tenant_id=tenant_uuid, - role="assistant", - content=response_text, - ) - session.add(assistant_msg) - await session.execute( - text( - "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id" - ), - {"conv_id": str(saved_conversation_id)}, - ) - await session.commit() + if response_text: + embed_and_store.delay({ + "tenant_id": tenant_id_str, + "agent_id": str(agent.id), + "user_id": user_id_str, + "role": "assistant", + "content": response_text, + }) except Exception: - logger.exception( - "Failed to save assistant message for conversation=%s", saved_conversation_id + pass # Non-fatal — memory will rebuild over time + + # ------------------------------------------------------------------- + # e. 
Save assistant message and send final "done" to client + # ------------------------------------------------------------------- + if response_text: + rls_token2 = current_tenant_id.set(tenant_uuid) + try: + async with async_session_factory() as session: + assistant_msg = WebConversationMessage( + conversation_id=uuid.UUID(saved_conversation_id), + tenant_id=tenant_uuid, + role="assistant", + content=response_text, + ) + session.add(assistant_msg) + await session.execute( + text( + "UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id" + ), + {"conv_id": str(saved_conversation_id)}, + ) + await session.commit() + except Exception: + logger.exception( + "Failed to save assistant message for conversation=%s", saved_conversation_id + ) + finally: + current_tenant_id.reset(rls_token2) + + # If user disconnected during streaming, send push notification + if ws_disconnected_during_stream or not is_user_connected(user_id_str): + agent_name = agent.name if hasattr(agent, "name") and agent.name else "Your AI Employee" + preview = response_text[:100] + ("..." if len(response_text) > 100 else "") + asyncio.create_task( + _send_push_notification( + user_id=user_id_str, + title=f"{agent_name} replied", + body=preview, + conversation_id=saved_conversation_id, + ) + ) + if ws_disconnected_during_stream: + break # Stop the message loop — WS is gone + + # Signal stream completion to the client + try: + await websocket.send_json({ + "type": "done", + "text": response_text, + "conversation_id": saved_conversation_id, + }) + except Exception: + pass # Client already disconnected + else: + logger.warning( + "No response received for conversation=%s", saved_conversation_id, ) - finally: - current_tenant_id.reset(rls_token2) + try: + await websocket.send_json({ + "type": "error", + "message": "I'm having trouble responding right now. 
Please try again.", + }) + except Exception: + pass # Client already disconnected - # Signal stream completion to the client - try: - await websocket.send_json({ - "type": "done", - "text": response_text, - "conversation_id": saved_conversation_id, - }) - except Exception: - pass # Client already disconnected - else: - logger.warning( - "No response received for conversation=%s", saved_conversation_id, - ) - try: - await websocket.send_json({ - "type": "error", - "message": "I'm having trouble responding right now. Please try again.", - }) - except Exception: - pass # Client already disconnected + finally: + # Always untrack this user when connection ends + _mark_disconnected(user_id_str, conversation_id) @web_chat_router.websocket("/chat/ws/{conversation_id}") diff --git a/packages/gateway/gateway/main.py b/packages/gateway/gateway/main.py index ead84ce..5813d6f 100644 --- a/packages/gateway/gateway/main.py +++ b/packages/gateway/gateway/main.py @@ -48,6 +48,7 @@ from shared.api import ( invitations_router, llm_keys_router, portal_router, + push_router, templates_router, usage_router, webhook_router, @@ -158,6 +159,11 @@ app.include_router(templates_router) app.include_router(chat_router) # REST: /api/portal/chat/* app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id} +# --------------------------------------------------------------------------- +# Phase 8 Push Notification router +# --------------------------------------------------------------------------- +app.include_router(push_router) # Push subscribe/unsubscribe/send + # --------------------------------------------------------------------------- # Routes diff --git a/packages/gateway/pyproject.toml b/packages/gateway/pyproject.toml index c7ca3b8..b883842 100644 --- a/packages/gateway/pyproject.toml +++ b/packages/gateway/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "httpx>=0.28.0", "redis>=5.0.0", "boto3>=1.35.0", + "pywebpush>=2.0.0", ] [tool.uv.sources] diff --git 
a/packages/portal b/packages/portal index acba978..c35a982 160000 --- a/packages/portal +++ b/packages/portal @@ -1 +1 @@ -Subproject commit acba978f2f041d4b1437d9e866600e5fab59e438 +Subproject commit c35a9822366b94adb967bdabf03e0a73a993b153 diff --git a/packages/shared/shared/api/__init__.py b/packages/shared/shared/api/__init__.py index b50d189..853bb5c 100644 --- a/packages/shared/shared/api/__init__.py +++ b/packages/shared/shared/api/__init__.py @@ -10,6 +10,7 @@ from shared.api.chat import chat_router from shared.api.invitations import invitations_router from shared.api.llm_keys import llm_keys_router from shared.api.portal import portal_router +from shared.api.push import push_router from shared.api.templates import templates_router from shared.api.usage import usage_router @@ -23,4 +24,5 @@ __all__ = [ "invitations_router", "templates_router", "chat_router", + "push_router", ] diff --git a/packages/shared/shared/api/push.py b/packages/shared/shared/api/push.py new file mode 100644 index 0000000..367a179 --- /dev/null +++ b/packages/shared/shared/api/push.py @@ -0,0 +1,232 @@ +""" +FastAPI push notification API — subscription management and send endpoint. + +Provides Web Push subscription storage so the gateway can deliver +push notifications when an AI employee responds and the user's +WebSocket is not connected. + +Endpoints: + POST /api/portal/push/subscribe — store browser push subscription + DELETE /api/portal/push/unsubscribe — remove subscription by endpoint + POST /api/portal/push/send — send push to all of a user's subscriptions over HTTP + +Authentication: + subscribe / unsubscribe: require portal user headers (X-Portal-User-Id) + send: requires the same portal headers; the gateway WebSocket handler does not + call this endpoint — it imports the _send_push helper and delivers directly + +Push delivery: + Uses pywebpush for VAPID-signed Web Push delivery. + Handles 410 Gone responses by deleting stale subscriptions. 
+""" + +from __future__ import annotations + +import json +import logging +import os +import uuid + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.api.rbac import PortalCaller, get_portal_caller +from shared.db import get_session +from shared.models.push import PushSubscription, PushSubscriptionCreate, PushSubscriptionOut, PushSendRequest + +logger = logging.getLogger(__name__) + +push_router = APIRouter(prefix="/api/portal/push", tags=["push"]) + +# --------------------------------------------------------------------------- +# VAPID config (read from environment at import time) +# --------------------------------------------------------------------------- + +VAPID_PRIVATE_KEY: str = os.environ.get("VAPID_PRIVATE_KEY", "") +VAPID_PUBLIC_KEY: str = os.environ.get("NEXT_PUBLIC_VAPID_PUBLIC_KEY", "") +VAPID_CLAIMS_EMAIL: str = os.environ.get("VAPID_CLAIMS_EMAIL", "admin@konstruct.dev") + + +# --------------------------------------------------------------------------- +# Helper — send a single push notification via pywebpush +# --------------------------------------------------------------------------- + + +async def _send_push(subscription: PushSubscription, payload: dict[str, object]) -> bool: + """ + Send a Web Push notification to a single subscription. + + Returns True on success, False if the subscription is stale (410 Gone). + Raises on other errors so the caller can decide how to handle them. 
+ """ + if not VAPID_PRIVATE_KEY: + logger.warning("VAPID_PRIVATE_KEY not set — skipping push notification") + return True + + try: + from pywebpush import WebPusher, webpush, WebPushException # type: ignore[import] + + subscription_info = { + "endpoint": subscription.endpoint, + "keys": { + "p256dh": subscription.p256dh, + "auth": subscription.auth, + }, + } + + webpush( + subscription_info=subscription_info, + data=json.dumps(payload), + vapid_private_key=VAPID_PRIVATE_KEY, + vapid_claims={ + "sub": f"mailto:{VAPID_CLAIMS_EMAIL}", + }, + ) + return True + + except Exception as exc: + # Check for 410 Gone — subscription is no longer valid + exc_str = str(exc) + if "410" in exc_str or "Gone" in exc_str or "expired" in exc_str.lower(): + logger.info("Push subscription stale (410 Gone): %s", subscription.endpoint[:40]) + return False + logger.error("Push delivery failed: %s", exc_str) + raise + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@push_router.post("/subscribe", status_code=status.HTTP_201_CREATED, response_model=PushSubscriptionOut) +async def subscribe( + body: PushSubscriptionCreate, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> PushSubscriptionOut: + """ + Store a browser push subscription for the authenticated user. + + Uses INSERT ... ON CONFLICT (user_id, endpoint) DO UPDATE so + re-subscribing the same browser updates the keys without creating + a duplicate row. 
+ """ + stmt = ( + pg_insert(PushSubscription) + .values( + user_id=caller.user_id, + tenant_id=uuid.UUID(body.tenant_id) if body.tenant_id else None, + endpoint=body.endpoint, + p256dh=body.p256dh, + auth=body.auth, + ) + .on_conflict_do_update( + constraint="uq_push_user_endpoint", + set_={ + "p256dh": body.p256dh, + "auth": body.auth, + "tenant_id": uuid.UUID(body.tenant_id) if body.tenant_id else None, + }, + ) + .returning(PushSubscription) + ) + result = await session.execute(stmt) + row = result.scalar_one() + await session.commit() + + return PushSubscriptionOut( + id=str(row.id), + endpoint=row.endpoint, + created_at=row.created_at, + ) + + +class UnsubscribeRequest(BaseModel): + endpoint: str + + +@push_router.delete("/unsubscribe", status_code=status.HTTP_204_NO_CONTENT) +async def unsubscribe( + body: UnsubscribeRequest, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> None: + """Remove a push subscription for the authenticated user.""" + await session.execute( + delete(PushSubscription).where( + PushSubscription.user_id == caller.user_id, + PushSubscription.endpoint == body.endpoint, + ) + ) + await session.commit() + + +@push_router.post("/send", status_code=status.HTTP_200_OK) +async def send_push( + body: PushSendRequest, + caller: PortalCaller = Depends(get_portal_caller), + session: AsyncSession = Depends(get_session), +) -> dict[str, object]: + """ + Internal endpoint — send a push notification to all subscriptions for a user. + + Called by the gateway WebSocket handler when the agent responds but + the user's WebSocket is no longer connected. + + Handles 410 Gone by deleting stale subscriptions. + Returns counts of delivered and stale subscriptions. 
+ """ + try: + target_user_id = uuid.UUID(body.user_id) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid user_id") from exc + + # Fetch all subscriptions for this user + result = await session.execute( + select(PushSubscription).where(PushSubscription.user_id == target_user_id) + ) + subscriptions = result.scalars().all() + + if not subscriptions: + return {"delivered": 0, "stale": 0, "total": 0} + + payload = { + "title": body.title, + "body": body.body, + "data": { + "conversationId": body.conversation_id, + }, + } + + delivered = 0 + stale_endpoints: list[str] = [] + + for sub in subscriptions: + try: + ok = await _send_push(sub, payload) + if ok: + delivered += 1 + else: + stale_endpoints.append(sub.endpoint) + except Exception as exc: + logger.error("Push send error for user %s: %s", body.user_id, exc) + + # Delete stale subscriptions + if stale_endpoints: + await session.execute( + delete(PushSubscription).where( + PushSubscription.user_id == target_user_id, + PushSubscription.endpoint.in_(stale_endpoints), + ) + ) + await session.commit() + + return { + "delivered": delivered, + "stale": len(stale_endpoints), + "total": len(subscriptions), + } diff --git a/packages/shared/shared/models/push.py b/packages/shared/shared/models/push.py new file mode 100644 index 0000000..6c9b787 --- /dev/null +++ b/packages/shared/shared/models/push.py @@ -0,0 +1,122 @@ +""" +Push subscription model for Web Push notifications. + +Stores browser push subscriptions for portal users so the gateway can +send push notifications when an AI employee responds and the user's +WebSocket is not connected. + +Push subscriptions are per-user, per-browser-endpoint. No RLS is applied +to this table — the API filters by user_id in the query (push subscriptions +are portal-user-scoped, not tenant-scoped). 
+""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from pydantic import BaseModel +from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from shared.models.tenant import Base + + +class PushSubscription(Base): + """ + Browser push subscription for a portal user. + + endpoint: The push service URL provided by the browser. + p256dh: ECDH public key for message encryption. + auth: Auth secret for message encryption. + + Unique constraint on (user_id, endpoint) — one subscription per + browser per user. Upsert on conflict avoids duplicates on re-subscribe. + """ + + __tablename__ = "push_subscriptions" + __table_args__ = ( + UniqueConstraint("user_id", "endpoint", name="uq_push_user_endpoint"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + server_default=func.gen_random_uuid(), + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("portal_users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("tenants.id", ondelete="SET NULL"), + nullable=True, + index=True, + comment="Optional tenant scope for notification routing", + ) + endpoint: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Push service URL (browser-provided)", + ) + p256dh: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="ECDH public key for payload encryption", + ) + auth: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Auth secret for payload encryption", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + 
onupdate=func.now(), + ) + + def __repr__(self) -> str: + return f"" + + +# --------------------------------------------------------------------------- +# Pydantic schemas +# --------------------------------------------------------------------------- + + +class PushSubscriptionCreate(BaseModel): + """Payload for POST /api/portal/push/subscribe.""" + + endpoint: str + p256dh: str + auth: str + tenant_id: str | None = None + + +class PushSubscriptionOut(BaseModel): + """Response body for subscription operations.""" + + id: str + endpoint: str + created_at: datetime + + model_config = {"from_attributes": True} + + +class PushSendRequest(BaseModel): + """Internal payload for POST /api/portal/push/send.""" + + user_id: str + title: str + body: str + conversation_id: str | None = None