feat(03-01): DB migrations, models, encryption service, and test scaffolds

- Add stripe and cryptography to shared pyproject.toml
- Add recharts, @stripe/stripe-js, stripe to portal package.json (submodule)
- Add billing fields to Tenant model (stripe_customer_id, subscription_status, agent_quota, trial_ends_at)
- Add budget_limit_usd to Agent model
- Create TenantLlmKey and StripeEvent models in billing.py (AuditBase and Base respectively)
- Create KeyEncryptionService (MultiFernet encrypt/decrypt/rotate) in crypto.py
- Create compute_budget_status helper in usage.py (threshold logic: ok/warning/exceeded)
- Add platform_encryption_key, stripe_, slack_oauth settings to config.py
- Create Alembic migration 005 with all schema changes, RLS, grants, and composite index
- All 12 tests passing (key encryption roundtrip, rotation, budget thresholds)
This commit is contained in:
2026-03-23 21:19:09 -06:00
parent ac606cf9ff
commit 215e67a7eb
9 changed files with 1085 additions and 1 deletions

View File

@@ -0,0 +1,383 @@
"""
Usage aggregation API endpoints for the Konstruct portal.
Endpoints:
GET /api/portal/usage/{tenant_id}/summary — per-agent token usage and cost
GET /api/portal/usage/{tenant_id}/by-provider — cost grouped by provider
GET /api/portal/usage/{tenant_id}/message-volume — message count grouped by channel
GET /api/portal/usage/{tenant_id}/budget-alerts — budget threshold alerts per agent
All endpoints query the audit_events table JSONB metadata column using the
composite index (tenant_id, action_type, created_at DESC) added in migration 005.
JSONB query pattern: CAST(:param AS jsonb) required for asyncpg compatibility.
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import get_session
from shared.models.tenant import Agent, Tenant
usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"])
# ---------------------------------------------------------------------------
# Budget threshold helper (also used by tests directly)
# ---------------------------------------------------------------------------
def compute_budget_status(current_usd: float, budget_limit_usd: float | None) -> str:
"""
Determine budget alert status for a given usage vs. limit.
Returns:
"ok" — no limit set, or usage below 80% of limit
"warning" — usage is between 80% and 99% of limit (inclusive)
"exceeded" — usage is at or above 100% of limit
"""
if budget_limit_usd is None or budget_limit_usd <= 0:
return "ok"
ratio = current_usd / budget_limit_usd
if ratio >= 1.0:
return "exceeded"
elif ratio >= 0.8:
return "warning"
else:
return "ok"
# ---------------------------------------------------------------------------
# Pydantic response schemas
# ---------------------------------------------------------------------------
class AgentUsageSummary(BaseModel):
agent_id: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_usd: float
call_count: int
class UsageSummaryResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
agents: list[AgentUsageSummary]
class ProviderUsage(BaseModel):
provider: str
cost_usd: float
call_count: int
class ProviderUsageResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
providers: list[ProviderUsage]
class ChannelVolume(BaseModel):
channel: str
message_count: int
class MessageVolumeResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
channels: list[ChannelVolume]
class BudgetAlert(BaseModel):
agent_id: str
agent_name: str
budget_limit_usd: float
current_usd: float
status: str # "ok" | "warning" | "exceeded"
class BudgetAlertsResponse(BaseModel):
tenant_id: str
alerts: list[BudgetAlert]
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _start_of_month() -> str:
today = date.today()
return date(today.year, today.month, 1).isoformat()
def _today() -> str:
return date.today().isoformat()
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
return tenant
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@usage_router.get("/{tenant_id}/summary", response_model=UsageSummaryResponse)
async def get_usage_summary(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> UsageSummaryResponse:
"""
Per-agent token usage and cost aggregated from audit_events.
Aggregates: prompt_tokens, completion_tokens, cost_usd from
audit_events.metadata JSONB, filtered by action_type='llm_call'.
"""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
# Uses the composite index: (tenant_id, action_type, created_at DESC)
sql = text("""
SELECT
agent_id::text,
COALESCE(SUM((metadata->>'prompt_tokens')::numeric), 0)::bigint AS prompt_tokens,
COALESCE(SUM((metadata->>'completion_tokens')::numeric), 0)::bigint AS completion_tokens,
COALESCE(SUM((metadata->>'total_tokens')::numeric), 0)::bigint AS total_tokens,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
COUNT(*)::bigint AS call_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY agent_id
ORDER BY cost_usd DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
agents = [
AgentUsageSummary(
agent_id=str(row["agent_id"]) if row["agent_id"] else "",
prompt_tokens=int(row["prompt_tokens"]),
completion_tokens=int(row["completion_tokens"]),
total_tokens=int(row["total_tokens"]),
cost_usd=float(row["cost_usd"]),
call_count=int(row["call_count"]),
)
for row in rows
]
return UsageSummaryResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
agents=agents,
)
@usage_router.get("/{tenant_id}/by-provider", response_model=ProviderUsageResponse)
async def get_usage_by_provider(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> ProviderUsageResponse:
"""Cost aggregated by LLM provider from audit_events.metadata.provider."""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
sql = text("""
SELECT
COALESCE(metadata->>'provider', 'unknown') AS provider,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
COUNT(*)::bigint AS call_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY provider
ORDER BY cost_usd DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
providers = [
ProviderUsage(
provider=row["provider"],
cost_usd=float(row["cost_usd"]),
call_count=int(row["call_count"]),
)
for row in rows
]
return ProviderUsageResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
providers=providers,
)
@usage_router.get("/{tenant_id}/message-volume", response_model=MessageVolumeResponse)
async def get_message_volume(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> MessageVolumeResponse:
"""Message count grouped by channel from audit_events.metadata.channel."""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
sql = text("""
SELECT
COALESCE(metadata->>'channel', 'unknown') AS channel,
COUNT(*)::bigint AS message_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY channel
ORDER BY message_count DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
channels = [
ChannelVolume(channel=row["channel"], message_count=int(row["message_count"]))
for row in rows
]
return MessageVolumeResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
channels=channels,
)
@usage_router.get("/{tenant_id}/budget-alerts", response_model=BudgetAlertsResponse)
async def get_budget_alerts(
tenant_id: uuid.UUID,
session: AsyncSession = Depends(get_session),
) -> BudgetAlertsResponse:
"""
Budget threshold alerts for agents with budget_limit_usd set.
Queries current-month cost_usd from audit_events for each agent that has
a budget limit configured. Returns status: "ok", "warning", or "exceeded".
"""
await _get_tenant_or_404(tenant_id, session)
# Load agents with a budget limit
result = await session.execute(
select(Agent).where(
Agent.tenant_id == tenant_id,
Agent.budget_limit_usd.isnot(None),
)
)
agents: list[Agent] = list(result.scalars().all())
if not agents:
return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=[])
start = _start_of_month()
# Aggregate current month cost per agent
sql = text("""
SELECT
agent_id::text,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND agent_id = ANY(:agent_ids)
GROUP BY agent_id
""")
agent_ids = [str(a.id) for a in agents]
cost_result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"agent_ids": agent_ids,
},
)
cost_by_agent: dict[str, float] = {
row["agent_id"]: float(row["cost_usd"])
for row in cost_result.mappings().all()
}
alerts = []
for agent in agents:
current = cost_by_agent.get(str(agent.id), 0.0)
limit = float(agent.budget_limit_usd) # type: ignore[arg-type]
alert_status = compute_budget_status(current, limit)
alerts.append(
BudgetAlert(
agent_id=str(agent.id),
agent_name=agent.name,
budget_limit_usd=limit,
current_usd=current,
status=alert_status,
)
)
return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=alerts)

View File

@@ -129,6 +129,58 @@ class Settings(BaseSettings):
orchestrator_url: str = Field(default="http://localhost:8003")
llm_pool_url: str = Field(default="http://localhost:8004")
# -------------------------------------------------------------------------
# Encryption
# -------------------------------------------------------------------------
platform_encryption_key: str = Field(
default="",
description="Fernet key for BYO API key encryption (base64-encoded 32-byte key)",
)
platform_encryption_key_previous: str = Field(
default="",
description="Previous Fernet key retained for decryption during rotation window",
)
# -------------------------------------------------------------------------
# Stripe
# -------------------------------------------------------------------------
stripe_secret_key: str = Field(
default="",
description="Stripe secret API key (sk_live_... or sk_test_...)",
)
stripe_webhook_secret: str = Field(
default="",
description="Stripe webhook endpoint signing secret (whsec_...)",
)
stripe_per_agent_price_id: str = Field(
default="",
description="Stripe Price ID for the per-agent monthly subscription plan",
)
portal_url: str = Field(
default="http://localhost:3000",
description="Portal base URL used in Stripe checkout success/cancel redirects",
)
# -------------------------------------------------------------------------
# Slack OAuth
# -------------------------------------------------------------------------
slack_client_id: str = Field(
default="",
description="Slack OAuth app client ID",
)
slack_client_secret: str = Field(
default="",
description="Slack OAuth app client secret",
)
slack_oauth_redirect_uri: str = Field(
default="http://localhost:3000/api/slack/callback",
description="Slack OAuth redirect URI (must match Slack app config)",
)
oauth_state_secret: str = Field(
default="",
description="HMAC secret for signing OAuth state parameters (CSRF protection)",
)
# -------------------------------------------------------------------------
# Application
# -------------------------------------------------------------------------

View File

@@ -0,0 +1,87 @@
"""
KeyEncryptionService — Fernet-based encryption for BYO API keys.
Uses MultiFernet to support key rotation:
- primary key: active encryption key (all new values encrypted with this)
- previous key: optional previous key (supports decryption during rotation window)
The PLATFORM_ENCRYPTION_KEY environment variable must be a valid URL-safe
base64-encoded 32-byte key, as generated by `Fernet.generate_key()`.
Usage:
from shared.crypto import KeyEncryptionService
svc = KeyEncryptionService(primary_key=settings.platform_encryption_key)
ciphertext = svc.encrypt("sk-my-secret-key")
plaintext = svc.decrypt(ciphertext)
new_cipher = svc.rotate(old_ciphertext) # re-encrypts with primary key
"""
from __future__ import annotations
from cryptography.fernet import Fernet, MultiFernet
class KeyEncryptionService:
"""
Encrypt and decrypt BYO API keys using Fernet symmetric encryption.
Fernet guarantees:
- AES-128-CBC with PKCS7 padding
- HMAC-SHA256 authentication
- Random IV per encryption call (produces different ciphertext each time)
- Timestamp in token (can enforce TTL if desired)
MultiFernet supports key rotation:
- Encryption always uses the first (primary) key
- Decryption tries all keys in order until one succeeds
- rotate() decrypts with any key, re-encrypts with the primary key
"""
def __init__(self, primary_key: str, previous_key: str = "") -> None:
"""
Initialise the service with one or two Fernet keys.
Args:
primary_key: Active key for encryption and decryption. Must be a
URL-safe base64-encoded 32-byte value (Fernet key).
previous_key: Optional previous key retained only for decryption
during a rotation window. Pass "" to omit.
"""
keys: list[Fernet] = [Fernet(primary_key.encode())]
if previous_key:
keys.append(Fernet(previous_key.encode()))
self._multi = MultiFernet(keys)
def encrypt(self, plaintext: str) -> str:
"""
Encrypt a plaintext string.
Returns a URL-safe base64-encoded Fernet token (str).
Calling encrypt() twice with the same plaintext produces different
ciphertexts due to the random IV embedded in each Fernet token.
"""
return self._multi.encrypt(plaintext.encode()).decode()
def decrypt(self, ciphertext: str) -> str:
"""
Decrypt a Fernet token back to the original plaintext.
Raises:
cryptography.fernet.InvalidToken: if the ciphertext is invalid,
tampered, or cannot be decrypted by any of the known keys.
"""
return self._multi.decrypt(ciphertext.encode()).decode()
def rotate(self, ciphertext: str) -> str:
"""
Re-encrypt an existing ciphertext with the current primary key.
Useful for key rotation: after adding a new primary key and keeping
the old key as previous_key, call rotate() on each stored ciphertext
to migrate it to the new key. Once all values are rotated, the old
key can be removed.
Returns a new Fernet token encrypted with the primary key.
Raises InvalidToken if the ciphertext cannot be decrypted.
"""
return self._multi.rotate(ciphertext.encode()).decode()

View File

@@ -0,0 +1,131 @@
"""
SQLAlchemy 2.0 ORM models for billing and BYO API key storage.
Models:
TenantLlmKey — stores encrypted BYO API keys per tenant per provider
StripeEvent — idempotency table for processed Stripe webhook events
Design notes:
- TenantLlmKey uses AuditBase (same separate declarative base as audit_events)
because tenant_llm_keys is a sensitive, compliance-relevant table.
- StripeEvent uses Base (same as tenants/agents) because it is a simple
idempotency guard, not a sensitive record.
- TenantLlmKey has RLS enabled (tenant isolation enforced at DB level via
Alembic migration 005). Only konstruct_app SELECT, INSERT, DELETE are
granted — UPDATE is not (keys are immutable; to change a key, delete and
re-create).
- key_hint stores the last 4 characters of the plaintext API key so the
portal can display "...ABCD" without decrypting the stored ciphertext.
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from shared.models.audit import AuditBase
from shared.models.tenant import Base
class TenantLlmKey(AuditBase):
"""
Encrypted BYO API key for a specific LLM provider, scoped to a tenant.
One row per (tenant_id, provider) pair — enforced by UNIQUE constraint.
RLS is ENABLED — tenant_id isolation enforced at DB level.
konstruct_app has SELECT, INSERT, DELETE only (no UPDATE).
To rotate a key: DELETE existing row, INSERT new row with updated key.
Fields:
provider — LLM provider name, e.g. "openai", "anthropic"
label — Human-readable label, e.g. "Production OpenAI Key"
encrypted_key — Fernet-encrypted API key (via KeyEncryptionService)
key_hint — Last 4 chars of plaintext key for portal display ("...XXXX")
key_version — Incremented on key rotation (future use for audit)
"""
__tablename__ = "tenant_llm_keys"
__table_args__ = (
UniqueConstraint("tenant_id", "provider", name="uq_tenant_llm_key_provider"),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
provider: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="LLM provider name: openai | anthropic | cohere | groq | etc.",
)
label: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Human-readable label for the key in the portal",
)
encrypted_key: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Fernet-encrypted API key ciphertext — NEVER logged or exposed via API",
)
key_hint: Mapped[str | None] = mapped_column(
String(4),
nullable=True,
comment="Last 4 characters of plaintext key for portal display (e.g. 'ABCD')",
)
key_version: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=1,
comment="Incremented on key rotation for audit trail",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
def __repr__(self) -> str:
return f"<TenantLlmKey id={self.id} provider={self.provider!r} tenant={self.tenant_id}>"
class StripeEvent(Base):
"""
Idempotency guard for processed Stripe webhook events.
Stripe guarantees at-least-once delivery, so the same event may arrive
multiple times. This table prevents duplicate processing via an
INSERT ... ON CONFLICT DO NOTHING pattern checked before each handler.
Fields:
event_id — Stripe event ID (e.g. "evt_1AbCdEfGhIjKlMnO") — primary key
processed_at — Timestamp when the event was first successfully processed
"""
__tablename__ = "stripe_events"
event_id: Mapped[str] = mapped_column(
Text,
primary_key=True,
comment="Stripe event ID — globally unique per Stripe account",
)
processed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
def __repr__(self) -> str:
return f"<StripeEvent event_id={self.event_id!r}>"

View File

@@ -16,7 +16,7 @@ import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Boolean, DateTime, Enum, ForeignKey, String, Text, UniqueConstraint, func
from sqlalchemy import JSON, Boolean, DateTime, Enum, Float, ForeignKey, Integer, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -69,6 +69,42 @@ class Tenant(Base):
onupdate=func.now(),
)
# ---------------------------------------------------------------------------
# Billing fields (added in migration 005)
# ---------------------------------------------------------------------------
stripe_customer_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Customer ID (cus_...)",
)
stripe_subscription_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Subscription ID (sub_...)",
)
stripe_subscription_item_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Subscription Item ID (si_...) for quantity updates",
)
subscription_status: Mapped[str] = mapped_column(
String(50),
nullable=False,
default="none",
comment="none | trialing | active | past_due | canceled | unpaid",
)
trial_ends_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Trial expiry timestamp (NULL for non-trial subscriptions)",
)
agent_quota: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
comment="Number of active agents allowed under current subscription",
)
# Relationships
agents: Mapped[list[Agent]] = relationship("Agent", back_populates="tenant", cascade="all, delete-orphan")
channel_connections: Mapped[list[ChannelConnection]] = relationship(
@@ -125,6 +161,12 @@ class Agent(Base):
comment="Whether natural language escalation phrases trigger handoff",
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
budget_limit_usd: Mapped[float | None] = mapped_column(
Float,
nullable=True,
default=None,
comment="Monthly spend cap in USD. NULL means no limit.",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,