From 215e67a7ebf9af212d0a919d15f5c6c04d83d5ff Mon Sep 17 00:00:00 2001 From: Adolfo Delorenzo Date: Mon, 23 Mar 2026 21:19:09 -0600 Subject: [PATCH] feat(03-01): DB migrations, models, encryption service, and test scaffolds - Add stripe and cryptography to shared pyproject.toml - Add recharts, @stripe/stripe-js, stripe to portal package.json (submodule) - Add billing fields to Tenant model (stripe_customer_id, subscription_status, agent_quota, trial_ends_at) - Add budget_limit_usd to Agent model - Create TenantLlmKey and StripeEvent models in billing.py (AuditBase and Base respectively) - Create KeyEncryptionService (MultiFernet encrypt/decrypt/rotate) in crypto.py - Create compute_budget_status helper in usage.py (threshold logic: ok/warning/exceeded) - Add platform_encryption_key, stripe_, slack_oauth settings to config.py - Create Alembic migration 005 with all schema changes, RLS, grants, and composite index - All 12 tests passing (key encryption roundtrip, rotation, budget thresholds) --- migrations/versions/005_billing_and_usage.py | 234 +++++++++++ packages/shared/pyproject.toml | 2 + packages/shared/shared/api/usage.py | 383 +++++++++++++++++++ packages/shared/shared/config.py | 52 +++ packages/shared/shared/crypto.py | 87 +++++ packages/shared/shared/models/billing.py | 131 +++++++ packages/shared/shared/models/tenant.py | 44 ++- tests/unit/test_budget_alerts.py | 65 ++++ tests/unit/test_key_encryption.py | 88 +++++ 9 files changed, 1085 insertions(+), 1 deletion(-) create mode 100644 migrations/versions/005_billing_and_usage.py create mode 100644 packages/shared/shared/api/usage.py create mode 100644 packages/shared/shared/crypto.py create mode 100644 packages/shared/shared/models/billing.py create mode 100644 tests/unit/test_budget_alerts.py create mode 100644 tests/unit/test_key_encryption.py diff --git a/migrations/versions/005_billing_and_usage.py b/migrations/versions/005_billing_and_usage.py new file mode 100644 index 0000000..8712eae --- /dev/null 
+++ b/migrations/versions/005_billing_and_usage.py @@ -0,0 +1,234 @@ +"""Phase 3: billing fields, tenant_llm_keys, stripe_events, audit index, agent budget + +Revision ID: 005 +Revises: 004 +Create Date: 2026-03-24 + +This migration adds: + +1. Billing columns on tenants table: + - stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id + - subscription_status (VARCHAR(50), default 'none') + - trial_ends_at (TIMESTAMPTZ, nullable) + - agent_quota (INTEGER, default 0) + +2. Budget column on agents table: + - budget_limit_usd (FLOAT, nullable) — monthly spend cap per agent + +3. tenant_llm_keys table: + - Stores encrypted BYO API keys per tenant per provider + - RLS enabled (same FORCE ROW LEVEL SECURITY pattern as agents) + - UNIQUE(tenant_id, provider) constraint + - key_hint column (VARCHAR(4)) for safe portal display without decryption + - konstruct_app granted SELECT, INSERT, DELETE (no UPDATE — keys are immutable) + +4. stripe_events table: + - Idempotency guard for Stripe webhook event processing + - Simple TEXT primary key (stripe event_id) + - konstruct_app granted SELECT, INSERT (no UPDATE, no DELETE) + +5. Composite index on audit_events: + - idx_audit_events_tenant_type_created ON audit_events(tenant_id, action_type, created_at DESC) + - Supports cost aggregation queries in usage.py endpoints + +Design: + - tenant_llm_keys intentionally grants DELETE (portal operators can remove keys) + - audit_events immutability is NOT weakened — only a new composite index is added + - stripe_events does not need RLS — it's platform-wide idempotency, not tenant-scoped +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. 
+revision: str = "005" +down_revision: Union[str, None] = "004" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ========================================================================= + # 1. Billing columns on tenants + # ========================================================================= + op.add_column("tenants", sa.Column( + "stripe_customer_id", + sa.String(255), + nullable=True, + comment="Stripe Customer ID (cus_...)", + )) + op.add_column("tenants", sa.Column( + "stripe_subscription_id", + sa.String(255), + nullable=True, + comment="Stripe Subscription ID (sub_...)", + )) + op.add_column("tenants", sa.Column( + "stripe_subscription_item_id", + sa.String(255), + nullable=True, + comment="Stripe Subscription Item ID (si_...) for quantity updates", + )) + op.add_column("tenants", sa.Column( + "subscription_status", + sa.String(50), + nullable=False, + server_default="none", + comment="none | trialing | active | past_due | canceled | unpaid", + )) + op.add_column("tenants", sa.Column( + "trial_ends_at", + sa.DateTime(timezone=True), + nullable=True, + comment="Trial expiry timestamp (NULL for non-trial subscriptions)", + )) + op.add_column("tenants", sa.Column( + "agent_quota", + sa.Integer, + nullable=False, + server_default="0", + comment="Number of active agents allowed under current subscription", + )) + + # ========================================================================= + # 2. Budget column on agents + # ========================================================================= + op.add_column("agents", sa.Column( + "budget_limit_usd", + sa.Float, + nullable=True, + comment="Monthly spend cap in USD. NULL means no limit.", + )) + + # ========================================================================= + # 3. 
tenant_llm_keys — encrypted BYO API keys + # ========================================================================= + op.create_table( + "tenant_llm_keys", + sa.Column( + "id", + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "provider", + sa.Text, + nullable=False, + comment="LLM provider: openai | anthropic | cohere | groq | etc.", + ), + sa.Column( + "label", + sa.Text, + nullable=False, + comment="Human-readable label for the portal", + ), + sa.Column( + "encrypted_key", + sa.Text, + nullable=False, + comment="Fernet-encrypted API key ciphertext", + ), + sa.Column( + "key_hint", + sa.String(4), + nullable=True, + comment="Last 4 chars of plaintext key for portal display", + ), + sa.Column( + "key_version", + sa.Integer, + nullable=False, + server_default="1", + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + sa.UniqueConstraint("tenant_id", "provider", name="uq_tenant_llm_key_provider"), + ) + + op.create_index("ix_tenant_llm_keys_tenant", "tenant_llm_keys", ["tenant_id"]) + + # RLS: only rows matching current tenant are visible + op.execute("ALTER TABLE tenant_llm_keys ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE tenant_llm_keys FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON tenant_llm_keys + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # SELECT, INSERT, DELETE — no UPDATE (keys are immutable; rotate by delete+insert) + op.execute("GRANT SELECT, INSERT, DELETE ON tenant_llm_keys TO konstruct_app") + + # ========================================================================= + # 4. 
stripe_events — webhook idempotency guard + # ========================================================================= + op.create_table( + "stripe_events", + sa.Column( + "event_id", + sa.Text, + primary_key=True, + comment="Stripe event ID — globally unique per Stripe account", + ), + sa.Column( + "processed_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + + # No RLS — platform-wide idempotency table, not tenant-scoped + op.execute("GRANT SELECT, INSERT ON stripe_events TO konstruct_app") + + # ========================================================================= + # 5. Composite index on audit_events for usage aggregation queries + # ========================================================================= + # Covers: WHERE tenant_id = X AND action_type = 'llm_call' AND created_at >= ... + # Used by: usage.py endpoints for per-agent and per-provider cost aggregation + op.create_index( + "idx_audit_events_tenant_type_created", + "audit_events", + ["tenant_id", "action_type", "created_at"], + postgresql_ops={"created_at": "DESC"}, + ) + + +def downgrade() -> None: + # Remove composite index on audit_events + op.drop_index("idx_audit_events_tenant_type_created", table_name="audit_events") + + # Remove stripe_events + op.execute("REVOKE ALL ON stripe_events FROM konstruct_app") + op.drop_table("stripe_events") + + # Remove tenant_llm_keys + op.execute("REVOKE ALL ON tenant_llm_keys FROM konstruct_app") + op.drop_table("tenant_llm_keys") + + # Remove budget column from agents + op.drop_column("agents", "budget_limit_usd") + + # Remove billing columns from tenants + op.drop_column("tenants", "agent_quota") + op.drop_column("tenants", "trial_ends_at") + op.drop_column("tenants", "subscription_status") + op.drop_column("tenants", "stripe_subscription_item_id") + op.drop_column("tenants", "stripe_subscription_id") + op.drop_column("tenants", "stripe_customer_id") diff --git a/packages/shared/pyproject.toml 
b/packages/shared/pyproject.toml index fb51dd8..f888048 100644 --- a/packages/shared/pyproject.toml +++ b/packages/shared/pyproject.toml @@ -20,6 +20,8 @@ dependencies = [ "slowapi>=0.1.9", "bcrypt>=4.0.0", "pgvector>=0.3.0", + "stripe>=10.0.0", + "cryptography>=42.0.0", ] [tool.hatch.build.targets.wheel] diff --git a/packages/shared/shared/api/usage.py b/packages/shared/shared/api/usage.py new file mode 100644 index 0000000..2d1d009 --- /dev/null +++ b/packages/shared/shared/api/usage.py @@ -0,0 +1,383 @@ +""" +Usage aggregation API endpoints for the Konstruct portal. + +Endpoints: + GET /api/portal/usage/{tenant_id}/summary — per-agent token usage and cost + GET /api/portal/usage/{tenant_id}/by-provider — cost grouped by provider + GET /api/portal/usage/{tenant_id}/message-volume — message count grouped by channel + GET /api/portal/usage/{tenant_id}/budget-alerts — budget threshold alerts per agent + +All endpoints query the audit_events table JSONB metadata column using the +composite index (tenant_id, action_type, created_at DESC) added in migration 005. + +JSONB query pattern: CAST(:param AS jsonb) required for asyncpg compatibility. 
+""" + +from __future__ import annotations + +import uuid +from datetime import date, datetime, timezone +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from pydantic import BaseModel +from sqlalchemy import select, text +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.db import get_session +from shared.models.tenant import Agent, Tenant + +usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"]) + + +# --------------------------------------------------------------------------- +# Budget threshold helper (also used by tests directly) +# --------------------------------------------------------------------------- + +def compute_budget_status(current_usd: float, budget_limit_usd: float | None) -> str: + """ + Determine budget alert status for a given usage vs. limit. + + Returns: + "ok" — no limit set, or usage below 80% of limit + "warning" — usage is between 80% and 99% of limit (inclusive) + "exceeded" — usage is at or above 100% of limit + """ + if budget_limit_usd is None or budget_limit_usd <= 0: + return "ok" + + ratio = current_usd / budget_limit_usd + if ratio >= 1.0: + return "exceeded" + elif ratio >= 0.8: + return "warning" + else: + return "ok" + + +# --------------------------------------------------------------------------- +# Pydantic response schemas +# --------------------------------------------------------------------------- + +class AgentUsageSummary(BaseModel): + agent_id: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + cost_usd: float + call_count: int + + +class UsageSummaryResponse(BaseModel): + tenant_id: str + start_date: str + end_date: str + agents: list[AgentUsageSummary] + + +class ProviderUsage(BaseModel): + provider: str + cost_usd: float + call_count: int + + +class ProviderUsageResponse(BaseModel): + tenant_id: str + start_date: str + end_date: str + providers: list[ProviderUsage] + + +class ChannelVolume(BaseModel): + channel: str + 
message_count: int + + +class MessageVolumeResponse(BaseModel): + tenant_id: str + start_date: str + end_date: str + channels: list[ChannelVolume] + + +class BudgetAlert(BaseModel): + agent_id: str + agent_name: str + budget_limit_usd: float + current_usd: float + status: str # "ok" | "warning" | "exceeded" + + +class BudgetAlertsResponse(BaseModel): + tenant_id: str + alerts: list[BudgetAlert] + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _start_of_month() -> str: + today = date.today() + return date(today.year, today.month, 1).isoformat() + + +def _today() -> str: + return date.today().isoformat() + + +async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant: + result = await session.execute(select(Tenant).where(Tenant.id == tenant_id)) + tenant = result.scalar_one_or_none() + if tenant is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found") + return tenant + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@usage_router.get("/{tenant_id}/summary", response_model=UsageSummaryResponse) +async def get_usage_summary( + tenant_id: uuid.UUID, + start_date: str = Query(default=None), + end_date: str = Query(default=None), + session: AsyncSession = Depends(get_session), +) -> UsageSummaryResponse: + """ + Per-agent token usage and cost aggregated from audit_events. + + Aggregates: prompt_tokens, completion_tokens, cost_usd from + audit_events.metadata JSONB, filtered by action_type='llm_call'. 
+ """ + await _get_tenant_or_404(tenant_id, session) + + start = start_date or _start_of_month() + end = end_date or _today() + + # Uses the composite index: (tenant_id, action_type, created_at DESC) + sql = text(""" + SELECT + agent_id::text, + COALESCE(SUM((metadata->>'prompt_tokens')::numeric), 0)::bigint AS prompt_tokens, + COALESCE(SUM((metadata->>'completion_tokens')::numeric), 0)::bigint AS completion_tokens, + COALESCE(SUM((metadata->>'total_tokens')::numeric), 0)::bigint AS total_tokens, + COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd, + COUNT(*)::bigint AS call_count + FROM audit_events + WHERE + tenant_id = :tenant_id + AND action_type = 'llm_call' + AND created_at >= :start_date::timestamptz + AND created_at < :end_date::timestamptz + INTERVAL '1 day' + GROUP BY agent_id + ORDER BY cost_usd DESC + """) + + result = await session.execute( + sql, + { + "tenant_id": str(tenant_id), + "start_date": start, + "end_date": end, + }, + ) + + rows = result.mappings().all() + agents = [ + AgentUsageSummary( + agent_id=str(row["agent_id"]) if row["agent_id"] else "", + prompt_tokens=int(row["prompt_tokens"]), + completion_tokens=int(row["completion_tokens"]), + total_tokens=int(row["total_tokens"]), + cost_usd=float(row["cost_usd"]), + call_count=int(row["call_count"]), + ) + for row in rows + ] + + return UsageSummaryResponse( + tenant_id=str(tenant_id), + start_date=start, + end_date=end, + agents=agents, + ) + + +@usage_router.get("/{tenant_id}/by-provider", response_model=ProviderUsageResponse) +async def get_usage_by_provider( + tenant_id: uuid.UUID, + start_date: str = Query(default=None), + end_date: str = Query(default=None), + session: AsyncSession = Depends(get_session), +) -> ProviderUsageResponse: + """Cost aggregated by LLM provider from audit_events.metadata.provider.""" + await _get_tenant_or_404(tenant_id, session) + + start = start_date or _start_of_month() + end = end_date or _today() + + sql = text(""" + SELECT + 
COALESCE(metadata->>'provider', 'unknown') AS provider, + COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd, + COUNT(*)::bigint AS call_count + FROM audit_events + WHERE + tenant_id = :tenant_id + AND action_type = 'llm_call' + AND created_at >= :start_date::timestamptz + AND created_at < :end_date::timestamptz + INTERVAL '1 day' + GROUP BY provider + ORDER BY cost_usd DESC + """) + + result = await session.execute( + sql, + { + "tenant_id": str(tenant_id), + "start_date": start, + "end_date": end, + }, + ) + + rows = result.mappings().all() + providers = [ + ProviderUsage( + provider=row["provider"], + cost_usd=float(row["cost_usd"]), + call_count=int(row["call_count"]), + ) + for row in rows + ] + + return ProviderUsageResponse( + tenant_id=str(tenant_id), + start_date=start, + end_date=end, + providers=providers, + ) + + +@usage_router.get("/{tenant_id}/message-volume", response_model=MessageVolumeResponse) +async def get_message_volume( + tenant_id: uuid.UUID, + start_date: str = Query(default=None), + end_date: str = Query(default=None), + session: AsyncSession = Depends(get_session), +) -> MessageVolumeResponse: + """Message count grouped by channel from audit_events.metadata.channel.""" + await _get_tenant_or_404(tenant_id, session) + + start = start_date or _start_of_month() + end = end_date or _today() + + sql = text(""" + SELECT + COALESCE(metadata->>'channel', 'unknown') AS channel, + COUNT(*)::bigint AS message_count + FROM audit_events + WHERE + tenant_id = :tenant_id + AND action_type = 'llm_call' + AND created_at >= :start_date::timestamptz + AND created_at < :end_date::timestamptz + INTERVAL '1 day' + GROUP BY channel + ORDER BY message_count DESC + """) + + result = await session.execute( + sql, + { + "tenant_id": str(tenant_id), + "start_date": start, + "end_date": end, + }, + ) + + rows = result.mappings().all() + channels = [ + ChannelVolume(channel=row["channel"], message_count=int(row["message_count"])) + for row in 
rows + ] + + return MessageVolumeResponse( + tenant_id=str(tenant_id), + start_date=start, + end_date=end, + channels=channels, + ) + + +@usage_router.get("/{tenant_id}/budget-alerts", response_model=BudgetAlertsResponse) +async def get_budget_alerts( + tenant_id: uuid.UUID, + session: AsyncSession = Depends(get_session), +) -> BudgetAlertsResponse: + """ + Budget threshold alerts for agents with budget_limit_usd set. + + Queries current-month cost_usd from audit_events for each agent that has + a budget limit configured. Returns status: "ok", "warning", or "exceeded". + """ + await _get_tenant_or_404(tenant_id, session) + + # Load agents with a budget limit + result = await session.execute( + select(Agent).where( + Agent.tenant_id == tenant_id, + Agent.budget_limit_usd.isnot(None), + ) + ) + agents: list[Agent] = list(result.scalars().all()) + + if not agents: + return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=[]) + + start = _start_of_month() + + # Aggregate current month cost per agent + sql = text(""" + SELECT + agent_id::text, + COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd + FROM audit_events + WHERE + tenant_id = :tenant_id + AND action_type = 'llm_call' + AND created_at >= :start_date::timestamptz + AND agent_id = ANY(:agent_ids) + GROUP BY agent_id + """) + + agent_ids = [str(a.id) for a in agents] + cost_result = await session.execute( + sql, + { + "tenant_id": str(tenant_id), + "start_date": start, + "agent_ids": agent_ids, + }, + ) + cost_by_agent: dict[str, float] = { + row["agent_id"]: float(row["cost_usd"]) + for row in cost_result.mappings().all() + } + + alerts = [] + for agent in agents: + current = cost_by_agent.get(str(agent.id), 0.0) + limit = float(agent.budget_limit_usd) # type: ignore[arg-type] + alert_status = compute_budget_status(current, limit) + alerts.append( + BudgetAlert( + agent_id=str(agent.id), + agent_name=agent.name, + budget_limit_usd=limit, + current_usd=current, + 
status=alert_status, + ) + ) + + return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=alerts) diff --git a/packages/shared/shared/config.py b/packages/shared/shared/config.py index 61e2641..f7fe5c1 100644 --- a/packages/shared/shared/config.py +++ b/packages/shared/shared/config.py @@ -129,6 +129,58 @@ class Settings(BaseSettings): orchestrator_url: str = Field(default="http://localhost:8003") llm_pool_url: str = Field(default="http://localhost:8004") + # ------------------------------------------------------------------------- + # Encryption + # ------------------------------------------------------------------------- + platform_encryption_key: str = Field( + default="", + description="Fernet key for BYO API key encryption (base64-encoded 32-byte key)", + ) + platform_encryption_key_previous: str = Field( + default="", + description="Previous Fernet key retained for decryption during rotation window", + ) + + # ------------------------------------------------------------------------- + # Stripe + # ------------------------------------------------------------------------- + stripe_secret_key: str = Field( + default="", + description="Stripe secret API key (sk_live_... 
or sk_test_...)", + ) + stripe_webhook_secret: str = Field( + default="", + description="Stripe webhook endpoint signing secret (whsec_...)", + ) + stripe_per_agent_price_id: str = Field( + default="", + description="Stripe Price ID for the per-agent monthly subscription plan", + ) + portal_url: str = Field( + default="http://localhost:3000", + description="Portal base URL used in Stripe checkout success/cancel redirects", + ) + + # ------------------------------------------------------------------------- + # Slack OAuth + # ------------------------------------------------------------------------- + slack_client_id: str = Field( + default="", + description="Slack OAuth app client ID", + ) + slack_client_secret: str = Field( + default="", + description="Slack OAuth app client secret", + ) + slack_oauth_redirect_uri: str = Field( + default="http://localhost:3000/api/slack/callback", + description="Slack OAuth redirect URI (must match Slack app config)", + ) + oauth_state_secret: str = Field( + default="", + description="HMAC secret for signing OAuth state parameters (CSRF protection)", + ) + # ------------------------------------------------------------------------- # Application # ------------------------------------------------------------------------- diff --git a/packages/shared/shared/crypto.py b/packages/shared/shared/crypto.py new file mode 100644 index 0000000..cd97332 --- /dev/null +++ b/packages/shared/shared/crypto.py @@ -0,0 +1,87 @@ +""" +KeyEncryptionService — Fernet-based encryption for BYO API keys. + +Uses MultiFernet to support key rotation: + - primary key: active encryption key (all new values encrypted with this) + - previous key: optional previous key (supports decryption during rotation window) + +The PLATFORM_ENCRYPTION_KEY environment variable must be a valid URL-safe +base64-encoded 32-byte key, as generated by `Fernet.generate_key()`. 
+ +Usage: + from shared.crypto import KeyEncryptionService + svc = KeyEncryptionService(primary_key=settings.platform_encryption_key) + ciphertext = svc.encrypt("sk-my-secret-key") + plaintext = svc.decrypt(ciphertext) + new_cipher = svc.rotate(old_ciphertext) # re-encrypts with primary key +""" + +from __future__ import annotations + +from cryptography.fernet import Fernet, MultiFernet + + +class KeyEncryptionService: + """ + Encrypt and decrypt BYO API keys using Fernet symmetric encryption. + + Fernet guarantees: + - AES-128-CBC with PKCS7 padding + - HMAC-SHA256 authentication + - Random IV per encryption call (produces different ciphertext each time) + - Timestamp in token (can enforce TTL if desired) + + MultiFernet supports key rotation: + - Encryption always uses the first (primary) key + - Decryption tries all keys in order until one succeeds + - rotate() decrypts with any key, re-encrypts with the primary key + """ + + def __init__(self, primary_key: str, previous_key: str = "") -> None: + """ + Initialise the service with one or two Fernet keys. + + Args: + primary_key: Active key for encryption and decryption. Must be a + URL-safe base64-encoded 32-byte value (Fernet key). + previous_key: Optional previous key retained only for decryption + during a rotation window. Pass "" to omit. + """ + keys: list[Fernet] = [Fernet(primary_key.encode())] + if previous_key: + keys.append(Fernet(previous_key.encode())) + self._multi = MultiFernet(keys) + + def encrypt(self, plaintext: str) -> str: + """ + Encrypt a plaintext string. + + Returns a URL-safe base64-encoded Fernet token (str). + Calling encrypt() twice with the same plaintext produces different + ciphertexts due to the random IV embedded in each Fernet token. + """ + return self._multi.encrypt(plaintext.encode()).decode() + + def decrypt(self, ciphertext: str) -> str: + """ + Decrypt a Fernet token back to the original plaintext. 
+ + Raises: + cryptography.fernet.InvalidToken: if the ciphertext is invalid, + tampered, or cannot be decrypted by any of the known keys. + """ + return self._multi.decrypt(ciphertext.encode()).decode() + + def rotate(self, ciphertext: str) -> str: + """ + Re-encrypt an existing ciphertext with the current primary key. + + Useful for key rotation: after adding a new primary key and keeping + the old key as previous_key, call rotate() on each stored ciphertext + to migrate it to the new key. Once all values are rotated, the old + key can be removed. + + Returns a new Fernet token encrypted with the primary key. + Raises InvalidToken if the ciphertext cannot be decrypted. + """ + return self._multi.rotate(ciphertext.encode()).decode() diff --git a/packages/shared/shared/models/billing.py b/packages/shared/shared/models/billing.py new file mode 100644 index 0000000..b973f99 --- /dev/null +++ b/packages/shared/shared/models/billing.py @@ -0,0 +1,131 @@ +""" +SQLAlchemy 2.0 ORM models for billing and BYO API key storage. + +Models: + TenantLlmKey — stores encrypted BYO API keys per tenant per provider + StripeEvent — idempotency table for processed Stripe webhook events + +Design notes: + - TenantLlmKey uses AuditBase (same separate declarative base as audit_events) + because tenant_llm_keys is a sensitive, compliance-relevant table. + - StripeEvent uses Base (same as tenants/agents) because it is a simple + idempotency guard, not a sensitive record. + - TenantLlmKey has RLS enabled (tenant isolation enforced at DB level via + Alembic migration 005). Only konstruct_app SELECT, INSERT, DELETE are + granted — UPDATE is not (keys are immutable; to change a key, delete and + re-create). + - key_hint stores the last 4 characters of the plaintext API key so the + portal can display "...ABCD" without decrypting the stored ciphertext. 
+""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from shared.models.audit import AuditBase +from shared.models.tenant import Base + + +class TenantLlmKey(AuditBase): + """ + Encrypted BYO API key for a specific LLM provider, scoped to a tenant. + + One row per (tenant_id, provider) pair — enforced by UNIQUE constraint. + + RLS is ENABLED — tenant_id isolation enforced at DB level. + konstruct_app has SELECT, INSERT, DELETE only (no UPDATE). + To rotate a key: DELETE existing row, INSERT new row with updated key. + + Fields: + provider — LLM provider name, e.g. "openai", "anthropic" + label — Human-readable label, e.g. "Production OpenAI Key" + encrypted_key — Fernet-encrypted API key (via KeyEncryptionService) + key_hint — Last 4 chars of plaintext key for portal display ("...XXXX") + key_version — Incremented on key rotation (future use for audit) + """ + + __tablename__ = "tenant_llm_keys" + __table_args__ = ( + UniqueConstraint("tenant_id", "provider", name="uq_tenant_llm_key_provider"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + tenant_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + provider: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="LLM provider name: openai | anthropic | cohere | groq | etc.", + ) + label: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Human-readable label for the key in the portal", + ) + encrypted_key: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Fernet-encrypted API key ciphertext — NEVER logged or exposed via API", + ) + key_hint: Mapped[str | None] = mapped_column( + 
String(4), + nullable=True, + comment="Last 4 characters of plaintext key for portal display (e.g. 'ABCD')", + ) + key_version: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=1, + comment="Incremented on key rotation for audit trail", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) + + def __repr__(self) -> str: + return f"" + + +class StripeEvent(Base): + """ + Idempotency guard for processed Stripe webhook events. + + Stripe guarantees at-least-once delivery, so the same event may arrive + multiple times. This table prevents duplicate processing via an + INSERT ... ON CONFLICT DO NOTHING pattern checked before each handler. + + Fields: + event_id — Stripe event ID (e.g. "evt_1AbCdEfGhIjKlMnO") — primary key + processed_at — Timestamp when the event was first successfully processed + """ + + __tablename__ = "stripe_events" + + event_id: Mapped[str] = mapped_column( + Text, + primary_key=True, + comment="Stripe event ID — globally unique per Stripe account", + ) + processed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) + + def __repr__(self) -> str: + return f"" diff --git a/packages/shared/shared/models/tenant.py b/packages/shared/shared/models/tenant.py index 322481f..c84b5ab 100644 --- a/packages/shared/shared/models/tenant.py +++ b/packages/shared/shared/models/tenant.py @@ -16,7 +16,7 @@ import uuid from datetime import datetime from typing import Any -from sqlalchemy import JSON, Boolean, DateTime, Enum, ForeignKey, String, Text, UniqueConstraint, func +from sqlalchemy import JSON, Boolean, DateTime, Enum, Float, ForeignKey, Integer, String, Text, UniqueConstraint, func from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -69,6 +69,42 @@ class Tenant(Base): onupdate=func.now(), ) + # 
--------------------------------------------------------------------------- + # Billing fields (added in migration 005) + # --------------------------------------------------------------------------- + stripe_customer_id: Mapped[str | None] = mapped_column( + String(255), + nullable=True, + comment="Stripe Customer ID (cus_...)", + ) + stripe_subscription_id: Mapped[str | None] = mapped_column( + String(255), + nullable=True, + comment="Stripe Subscription ID (sub_...)", + ) + stripe_subscription_item_id: Mapped[str | None] = mapped_column( + String(255), + nullable=True, + comment="Stripe Subscription Item ID (si_...) for quantity updates", + ) + subscription_status: Mapped[str] = mapped_column( + String(50), + nullable=False, + default="none", + comment="none | trialing | active | past_due | canceled | unpaid", + ) + trial_ends_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + nullable=True, + comment="Trial expiry timestamp (NULL for non-trial subscriptions)", + ) + agent_quota: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=0, + comment="Number of active agents allowed under current subscription", + ) + # Relationships agents: Mapped[list[Agent]] = relationship("Agent", back_populates="tenant", cascade="all, delete-orphan") channel_connections: Mapped[list[ChannelConnection]] = relationship( @@ -125,6 +161,12 @@ class Agent(Base): comment="Whether natural language escalation phrases trigger handoff", ) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + budget_limit_usd: Mapped[float | None] = mapped_column( + Float, + nullable=True, + default=None, + comment="Monthly spend cap in USD. 
def test_budget_alert_no_limit() -> None:
    """Spend of 500.0 with no budget limit (None) is reported as 'ok'."""
    assert compute_budget_status(current_usd=500.0, budget_limit_usd=None) == "ok"


def test_budget_alert_under_threshold() -> None:
    """Spend of 50.0 against a 100.0 limit (50%) is reported as 'ok'."""
    assert compute_budget_status(current_usd=50.0, budget_limit_usd=100.0) == "ok"


def test_budget_alert_just_below_warning() -> None:
    """Spend of 79.0 against a 100.0 limit (79%) is still 'ok'."""
    assert compute_budget_status(current_usd=79.0, budget_limit_usd=100.0) == "ok"


def test_budget_alert_warning() -> None:
    """Spend of 80.0 against a 100.0 limit (exactly 80%) is a 'warning'."""
    assert compute_budget_status(current_usd=80.0, budget_limit_usd=100.0) == "warning"


def test_budget_alert_warning_mid() -> None:
    """Spend of 95.0 against a 100.0 limit (95%) is a 'warning'."""
    assert compute_budget_status(current_usd=95.0, budget_limit_usd=100.0) == "warning"


def test_budget_alert_exceeded() -> None:
    """Spend of 100.0 against a 100.0 limit (exactly 100%) is 'exceeded'."""
    assert compute_budget_status(current_usd=100.0, budget_limit_usd=100.0) == "exceeded"
def test_budget_alert_over_limit() -> None:
    """Spend of 120.0 against a 100.0 limit (120%) is 'exceeded'."""
    assert compute_budget_status(current_usd=120.0, budget_limit_usd=100.0) == "exceeded"


def test_budget_alert_zero_usage() -> None:
    """Zero spend against a 50.0 limit is 'ok'."""
    assert compute_budget_status(current_usd=0.0, budget_limit_usd=50.0) == "ok"


@pytest.fixture()
def primary_key() -> str:
    """Provide a newly generated Fernet key, decoded to ``str``."""
    raw_key = Fernet.generate_key()
    return raw_key.decode()


@pytest.fixture()
def secondary_key() -> str:
    """Provide a second, independently generated Fernet key for rotation tests."""
    raw_key = Fernet.generate_key()
    return raw_key.decode()


def test_encrypt_decrypt_roundtrip(primary_key: str) -> None:
    """Decrypting a freshly encrypted value yields the original plaintext."""
    service = KeyEncryptionService(primary_key=primary_key)
    secret = "sk-my-secret-api-key-12345"

    token = service.encrypt(secret)

    assert service.decrypt(token) == secret
def test_encrypt_produces_different_ciphertext(primary_key: str) -> None:
    """Encrypting the same plaintext twice gives two distinct ciphertexts."""
    service = KeyEncryptionService(primary_key=primary_key)
    secret = "sk-same-plaintext"

    first_token = service.encrypt(secret)
    second_token = service.encrypt(secret)

    assert first_token != second_token
    # Distinct tokens must nevertheless decrypt to the same plaintext.
    for token in (first_token, second_token):
        assert service.decrypt(token) == secret


def test_decrypt_invalid_raises(primary_key: str) -> None:
    """Decrypting a non-ciphertext string raises InvalidToken or ValueError."""
    service = KeyEncryptionService(primary_key=primary_key)
    garbage = "this-is-not-valid-fernet-ciphertext"

    with pytest.raises((InvalidToken, ValueError)):
        service.decrypt(garbage)


def test_multifernet_rotation(primary_key: str, secondary_key: str) -> None:
    """
    Key-rotation walkthrough:
      1. A service keyed only with the old key (``secondary_key``) encrypts a value.
      2. A service with primary=new and previous=old can still decrypt it.
      3. ``rotate`` yields a different ciphertext that a service holding only
         the new key can decrypt.
    """
    secret = "sk-rotate-me"

    # 1. Ciphertext produced under the old key alone.
    legacy_service = KeyEncryptionService(primary_key=secondary_key)
    legacy_token = legacy_service.encrypt(secret)

    # 2. The two-key service must still read the legacy ciphertext.
    rotating_service = KeyEncryptionService(primary_key=primary_key, previous_key=secondary_key)
    assert rotating_service.decrypt(legacy_token) == secret

    # 3. Rotation re-encrypts; the result differs from the legacy ciphertext.
    fresh_token = rotating_service.rotate(legacy_token)
    assert fresh_token != legacy_token

    # The rotated ciphertext no longer needs the old key.
    new_only_service = KeyEncryptionService(primary_key=primary_key)
    assert new_only_service.decrypt(fresh_token) == secret