feat(03-01): DB migrations, models, encryption service, and test scaffolds

- Add stripe and cryptography to shared pyproject.toml
- Add recharts, @stripe/stripe-js, stripe to portal package.json (submodule)
- Add billing fields to Tenant model (stripe_customer_id, subscription_status, agent_quota, trial_ends_at)
- Add budget_limit_usd to Agent model
- Create TenantLlmKey and StripeEvent models in billing.py (AuditBase and Base respectively)
- Create KeyEncryptionService (MultiFernet encrypt/decrypt/rotate) in crypto.py
- Create compute_budget_status helper in usage.py (threshold logic: ok/warning/exceeded)
- Add platform_encryption_key, stripe_, slack_oauth settings to config.py
- Create Alembic migration 005 with all schema changes, RLS, grants, and composite index
- All 12 tests passing (key encryption roundtrip, rotation, budget thresholds)
This commit is contained in:
2026-03-23 21:19:09 -06:00
parent ac606cf9ff
commit 215e67a7eb
9 changed files with 1085 additions and 1 deletions

View File

@@ -0,0 +1,234 @@
"""Phase 3: billing fields, tenant_llm_keys, stripe_events, audit index, agent budget
Revision ID: 005
Revises: 004
Create Date: 2026-03-24
This migration adds:
1. Billing columns on tenants table:
- stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id
- subscription_status (TEXT, default 'none')
- trial_ends_at (TIMESTAMPTZ, nullable)
- agent_quota (INTEGER, default 0)
2. Budget column on agents table:
- budget_limit_usd (FLOAT, nullable) — monthly spend cap per agent
3. tenant_llm_keys table:
- Stores encrypted BYO API keys per tenant per provider
- RLS enabled (same FORCE ROW LEVEL SECURITY pattern as agents)
- UNIQUE(tenant_id, provider) constraint
- key_hint column (VARCHAR(4)) for safe portal display without decryption
- konstruct_app granted SELECT, INSERT, DELETE (no UPDATE — keys are immutable)
4. stripe_events table:
- Idempotency guard for Stripe webhook event processing
- Simple TEXT primary key (stripe event_id)
- konstruct_app granted SELECT, INSERT (no UPDATE, no DELETE)
5. Composite index on audit_events:
- idx_audit_events_tenant_type_created ON audit_events(tenant_id, action_type, created_at DESC)
- Supports cost aggregation queries in usage.py endpoints
Design:
- tenant_llm_keys intentionally grants DELETE (portal operators can remove keys)
- audit_events immutability is NOT weakened — only a new covering index is added
- stripe_events does not need RLS — it's platform-wide idempotency, not tenant-scoped
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "005"
down_revision: Union[str, None] = "004"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# =========================================================================
# 1. Billing columns on tenants
# =========================================================================
op.add_column("tenants", sa.Column(
"stripe_customer_id",
sa.String(255),
nullable=True,
comment="Stripe Customer ID (cus_...)",
))
op.add_column("tenants", sa.Column(
"stripe_subscription_id",
sa.String(255),
nullable=True,
comment="Stripe Subscription ID (sub_...)",
))
op.add_column("tenants", sa.Column(
"stripe_subscription_item_id",
sa.String(255),
nullable=True,
comment="Stripe Subscription Item ID (si_...) for quantity updates",
))
op.add_column("tenants", sa.Column(
"subscription_status",
sa.String(50),
nullable=False,
server_default="none",
comment="none | trialing | active | past_due | canceled | unpaid",
))
op.add_column("tenants", sa.Column(
"trial_ends_at",
sa.DateTime(timezone=True),
nullable=True,
comment="Trial expiry timestamp (NULL for non-trial subscriptions)",
))
op.add_column("tenants", sa.Column(
"agent_quota",
sa.Integer,
nullable=False,
server_default="0",
comment="Number of active agents allowed under current subscription",
))
# =========================================================================
# 2. Budget column on agents
# =========================================================================
op.add_column("agents", sa.Column(
"budget_limit_usd",
sa.Float,
nullable=True,
comment="Monthly spend cap in USD. NULL means no limit.",
))
# =========================================================================
# 3. tenant_llm_keys — encrypted BYO API keys
# =========================================================================
op.create_table(
"tenant_llm_keys",
sa.Column(
"id",
UUID(as_uuid=True),
primary_key=True,
server_default=sa.text("gen_random_uuid()"),
),
sa.Column(
"tenant_id",
UUID(as_uuid=True),
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"provider",
sa.Text,
nullable=False,
comment="LLM provider: openai | anthropic | cohere | groq | etc.",
),
sa.Column(
"label",
sa.Text,
nullable=False,
comment="Human-readable label for the portal",
),
sa.Column(
"encrypted_key",
sa.Text,
nullable=False,
comment="Fernet-encrypted API key ciphertext",
),
sa.Column(
"key_hint",
sa.String(4),
nullable=True,
comment="Last 4 chars of plaintext key for portal display",
),
sa.Column(
"key_version",
sa.Integer,
nullable=False,
server_default="1",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("NOW()"),
),
sa.UniqueConstraint("tenant_id", "provider", name="uq_tenant_llm_key_provider"),
)
op.create_index("ix_tenant_llm_keys_tenant", "tenant_llm_keys", ["tenant_id"])
# RLS: only rows matching current tenant are visible
op.execute("ALTER TABLE tenant_llm_keys ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE tenant_llm_keys FORCE ROW LEVEL SECURITY")
op.execute("""
CREATE POLICY tenant_isolation ON tenant_llm_keys
USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
""")
# SELECT, INSERT, DELETE — no UPDATE (keys are immutable; rotate by delete+insert)
op.execute("GRANT SELECT, INSERT, DELETE ON tenant_llm_keys TO konstruct_app")
# =========================================================================
# 4. stripe_events — webhook idempotency guard
# =========================================================================
op.create_table(
"stripe_events",
sa.Column(
"event_id",
sa.Text,
primary_key=True,
comment="Stripe event ID — globally unique per Stripe account",
),
sa.Column(
"processed_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("NOW()"),
),
)
# No RLS — platform-wide idempotency table, not tenant-scoped
op.execute("GRANT SELECT, INSERT ON stripe_events TO konstruct_app")
# =========================================================================
# 5. Composite index on audit_events for usage aggregation queries
# =========================================================================
# Covers: WHERE tenant_id = X AND action_type = 'llm_call' AND created_at >= ...
# Used by: usage.py endpoints for per-agent and per-provider cost aggregation
op.create_index(
"idx_audit_events_tenant_type_created",
"audit_events",
["tenant_id", "action_type", "created_at"],
postgresql_ops={"created_at": "DESC"},
)
def downgrade() -> None:
# Remove composite index on audit_events
op.drop_index("idx_audit_events_tenant_type_created", table_name="audit_events")
# Remove stripe_events
op.execute("REVOKE ALL ON stripe_events FROM konstruct_app")
op.drop_table("stripe_events")
# Remove tenant_llm_keys
op.execute("REVOKE ALL ON tenant_llm_keys FROM konstruct_app")
op.drop_table("tenant_llm_keys")
# Remove budget column from agents
op.drop_column("agents", "budget_limit_usd")
# Remove billing columns from tenants
op.drop_column("tenants", "agent_quota")
op.drop_column("tenants", "trial_ends_at")
op.drop_column("tenants", "subscription_status")
op.drop_column("tenants", "stripe_subscription_item_id")
op.drop_column("tenants", "stripe_subscription_id")
op.drop_column("tenants", "stripe_customer_id")

View File

@@ -20,6 +20,8 @@ dependencies = [
"slowapi>=0.1.9", "slowapi>=0.1.9",
"bcrypt>=4.0.0", "bcrypt>=4.0.0",
"pgvector>=0.3.0", "pgvector>=0.3.0",
"stripe>=10.0.0",
"cryptography>=42.0.0",
] ]
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]

View File

@@ -0,0 +1,383 @@
"""
Usage aggregation API endpoints for the Konstruct portal.
Endpoints:
GET /api/portal/usage/{tenant_id}/summary — per-agent token usage and cost
GET /api/portal/usage/{tenant_id}/by-provider — cost grouped by provider
GET /api/portal/usage/{tenant_id}/message-volume — message count grouped by channel
GET /api/portal/usage/{tenant_id}/budget-alerts — budget threshold alerts per agent
All endpoints query the audit_events table JSONB metadata column using the
composite index (tenant_id, action_type, created_at DESC) added in migration 005.
JSONB query pattern: CAST(:param AS jsonb) required for asyncpg compatibility.
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import get_session
from shared.models.tenant import Agent, Tenant
usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"])
# ---------------------------------------------------------------------------
# Budget threshold helper (also used by tests directly)
# ---------------------------------------------------------------------------
def compute_budget_status(current_usd: float, budget_limit_usd: float | None) -> str:
"""
Determine budget alert status for a given usage vs. limit.
Returns:
"ok" — no limit set, or usage below 80% of limit
"warning" — usage is between 80% and 99% of limit (inclusive)
"exceeded" — usage is at or above 100% of limit
"""
if budget_limit_usd is None or budget_limit_usd <= 0:
return "ok"
ratio = current_usd / budget_limit_usd
if ratio >= 1.0:
return "exceeded"
elif ratio >= 0.8:
return "warning"
else:
return "ok"
# ---------------------------------------------------------------------------
# Pydantic response schemas
# ---------------------------------------------------------------------------
class AgentUsageSummary(BaseModel):
agent_id: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost_usd: float
call_count: int
class UsageSummaryResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
agents: list[AgentUsageSummary]
class ProviderUsage(BaseModel):
provider: str
cost_usd: float
call_count: int
class ProviderUsageResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
providers: list[ProviderUsage]
class ChannelVolume(BaseModel):
channel: str
message_count: int
class MessageVolumeResponse(BaseModel):
tenant_id: str
start_date: str
end_date: str
channels: list[ChannelVolume]
class BudgetAlert(BaseModel):
agent_id: str
agent_name: str
budget_limit_usd: float
current_usd: float
status: str # "ok" | "warning" | "exceeded"
class BudgetAlertsResponse(BaseModel):
tenant_id: str
alerts: list[BudgetAlert]
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _start_of_month() -> str:
today = date.today()
return date(today.year, today.month, 1).isoformat()
def _today() -> str:
return date.today().isoformat()
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
return tenant
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@usage_router.get("/{tenant_id}/summary", response_model=UsageSummaryResponse)
async def get_usage_summary(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> UsageSummaryResponse:
"""
Per-agent token usage and cost aggregated from audit_events.
Aggregates: prompt_tokens, completion_tokens, cost_usd from
audit_events.metadata JSONB, filtered by action_type='llm_call'.
"""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
# Uses the composite index: (tenant_id, action_type, created_at DESC)
sql = text("""
SELECT
agent_id::text,
COALESCE(SUM((metadata->>'prompt_tokens')::numeric), 0)::bigint AS prompt_tokens,
COALESCE(SUM((metadata->>'completion_tokens')::numeric), 0)::bigint AS completion_tokens,
COALESCE(SUM((metadata->>'total_tokens')::numeric), 0)::bigint AS total_tokens,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
COUNT(*)::bigint AS call_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY agent_id
ORDER BY cost_usd DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
agents = [
AgentUsageSummary(
agent_id=str(row["agent_id"]) if row["agent_id"] else "",
prompt_tokens=int(row["prompt_tokens"]),
completion_tokens=int(row["completion_tokens"]),
total_tokens=int(row["total_tokens"]),
cost_usd=float(row["cost_usd"]),
call_count=int(row["call_count"]),
)
for row in rows
]
return UsageSummaryResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
agents=agents,
)
@usage_router.get("/{tenant_id}/by-provider", response_model=ProviderUsageResponse)
async def get_usage_by_provider(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> ProviderUsageResponse:
"""Cost aggregated by LLM provider from audit_events.metadata.provider."""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
sql = text("""
SELECT
COALESCE(metadata->>'provider', 'unknown') AS provider,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
COUNT(*)::bigint AS call_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY provider
ORDER BY cost_usd DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
providers = [
ProviderUsage(
provider=row["provider"],
cost_usd=float(row["cost_usd"]),
call_count=int(row["call_count"]),
)
for row in rows
]
return ProviderUsageResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
providers=providers,
)
@usage_router.get("/{tenant_id}/message-volume", response_model=MessageVolumeResponse)
async def get_message_volume(
tenant_id: uuid.UUID,
start_date: str = Query(default=None),
end_date: str = Query(default=None),
session: AsyncSession = Depends(get_session),
) -> MessageVolumeResponse:
"""Message count grouped by channel from audit_events.metadata.channel."""
await _get_tenant_or_404(tenant_id, session)
start = start_date or _start_of_month()
end = end_date or _today()
sql = text("""
SELECT
COALESCE(metadata->>'channel', 'unknown') AS channel,
COUNT(*)::bigint AS message_count
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND created_at < :end_date::timestamptz + INTERVAL '1 day'
GROUP BY channel
ORDER BY message_count DESC
""")
result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"end_date": end,
},
)
rows = result.mappings().all()
channels = [
ChannelVolume(channel=row["channel"], message_count=int(row["message_count"]))
for row in rows
]
return MessageVolumeResponse(
tenant_id=str(tenant_id),
start_date=start,
end_date=end,
channels=channels,
)
@usage_router.get("/{tenant_id}/budget-alerts", response_model=BudgetAlertsResponse)
async def get_budget_alerts(
tenant_id: uuid.UUID,
session: AsyncSession = Depends(get_session),
) -> BudgetAlertsResponse:
"""
Budget threshold alerts for agents with budget_limit_usd set.
Queries current-month cost_usd from audit_events for each agent that has
a budget limit configured. Returns status: "ok", "warning", or "exceeded".
"""
await _get_tenant_or_404(tenant_id, session)
# Load agents with a budget limit
result = await session.execute(
select(Agent).where(
Agent.tenant_id == tenant_id,
Agent.budget_limit_usd.isnot(None),
)
)
agents: list[Agent] = list(result.scalars().all())
if not agents:
return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=[])
start = _start_of_month()
# Aggregate current month cost per agent
sql = text("""
SELECT
agent_id::text,
COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd
FROM audit_events
WHERE
tenant_id = :tenant_id
AND action_type = 'llm_call'
AND created_at >= :start_date::timestamptz
AND agent_id = ANY(:agent_ids)
GROUP BY agent_id
""")
agent_ids = [str(a.id) for a in agents]
cost_result = await session.execute(
sql,
{
"tenant_id": str(tenant_id),
"start_date": start,
"agent_ids": agent_ids,
},
)
cost_by_agent: dict[str, float] = {
row["agent_id"]: float(row["cost_usd"])
for row in cost_result.mappings().all()
}
alerts = []
for agent in agents:
current = cost_by_agent.get(str(agent.id), 0.0)
limit = float(agent.budget_limit_usd) # type: ignore[arg-type]
alert_status = compute_budget_status(current, limit)
alerts.append(
BudgetAlert(
agent_id=str(agent.id),
agent_name=agent.name,
budget_limit_usd=limit,
current_usd=current,
status=alert_status,
)
)
return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=alerts)

View File

@@ -129,6 +129,58 @@ class Settings(BaseSettings):
orchestrator_url: str = Field(default="http://localhost:8003") orchestrator_url: str = Field(default="http://localhost:8003")
llm_pool_url: str = Field(default="http://localhost:8004") llm_pool_url: str = Field(default="http://localhost:8004")
# -------------------------------------------------------------------------
# Encryption
# -------------------------------------------------------------------------
platform_encryption_key: str = Field(
default="",
description="Fernet key for BYO API key encryption (base64-encoded 32-byte key)",
)
platform_encryption_key_previous: str = Field(
default="",
description="Previous Fernet key retained for decryption during rotation window",
)
# -------------------------------------------------------------------------
# Stripe
# -------------------------------------------------------------------------
stripe_secret_key: str = Field(
default="",
description="Stripe secret API key (sk_live_... or sk_test_...)",
)
stripe_webhook_secret: str = Field(
default="",
description="Stripe webhook endpoint signing secret (whsec_...)",
)
stripe_per_agent_price_id: str = Field(
default="",
description="Stripe Price ID for the per-agent monthly subscription plan",
)
portal_url: str = Field(
default="http://localhost:3000",
description="Portal base URL used in Stripe checkout success/cancel redirects",
)
# -------------------------------------------------------------------------
# Slack OAuth
# -------------------------------------------------------------------------
slack_client_id: str = Field(
default="",
description="Slack OAuth app client ID",
)
slack_client_secret: str = Field(
default="",
description="Slack OAuth app client secret",
)
slack_oauth_redirect_uri: str = Field(
default="http://localhost:3000/api/slack/callback",
description="Slack OAuth redirect URI (must match Slack app config)",
)
oauth_state_secret: str = Field(
default="",
description="HMAC secret for signing OAuth state parameters (CSRF protection)",
)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Application # Application
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------

View File

@@ -0,0 +1,87 @@
"""
KeyEncryptionService — Fernet-based encryption for BYO API keys.
Uses MultiFernet to support key rotation:
- primary key: active encryption key (all new values encrypted with this)
- previous key: optional previous key (supports decryption during rotation window)
The PLATFORM_ENCRYPTION_KEY environment variable must be a valid URL-safe
base64-encoded 32-byte key, as generated by `Fernet.generate_key()`.
Usage:
from shared.crypto import KeyEncryptionService
svc = KeyEncryptionService(primary_key=settings.platform_encryption_key)
ciphertext = svc.encrypt("sk-my-secret-key")
plaintext = svc.decrypt(ciphertext)
new_cipher = svc.rotate(old_ciphertext) # re-encrypts with primary key
"""
from __future__ import annotations
from cryptography.fernet import Fernet, MultiFernet
class KeyEncryptionService:
"""
Encrypt and decrypt BYO API keys using Fernet symmetric encryption.
Fernet guarantees:
- AES-128-CBC with PKCS7 padding
- HMAC-SHA256 authentication
- Random IV per encryption call (produces different ciphertext each time)
- Timestamp in token (can enforce TTL if desired)
MultiFernet supports key rotation:
- Encryption always uses the first (primary) key
- Decryption tries all keys in order until one succeeds
- rotate() decrypts with any key, re-encrypts with the primary key
"""
def __init__(self, primary_key: str, previous_key: str = "") -> None:
"""
Initialise the service with one or two Fernet keys.
Args:
primary_key: Active key for encryption and decryption. Must be a
URL-safe base64-encoded 32-byte value (Fernet key).
previous_key: Optional previous key retained only for decryption
during a rotation window. Pass "" to omit.
"""
keys: list[Fernet] = [Fernet(primary_key.encode())]
if previous_key:
keys.append(Fernet(previous_key.encode()))
self._multi = MultiFernet(keys)
def encrypt(self, plaintext: str) -> str:
"""
Encrypt a plaintext string.
Returns a URL-safe base64-encoded Fernet token (str).
Calling encrypt() twice with the same plaintext produces different
ciphertexts due to the random IV embedded in each Fernet token.
"""
return self._multi.encrypt(plaintext.encode()).decode()
def decrypt(self, ciphertext: str) -> str:
"""
Decrypt a Fernet token back to the original plaintext.
Raises:
cryptography.fernet.InvalidToken: if the ciphertext is invalid,
tampered, or cannot be decrypted by any of the known keys.
"""
return self._multi.decrypt(ciphertext.encode()).decode()
def rotate(self, ciphertext: str) -> str:
"""
Re-encrypt an existing ciphertext with the current primary key.
Useful for key rotation: after adding a new primary key and keeping
the old key as previous_key, call rotate() on each stored ciphertext
to migrate it to the new key. Once all values are rotated, the old
key can be removed.
Returns a new Fernet token encrypted with the primary key.
Raises InvalidToken if the ciphertext cannot be decrypted.
"""
return self._multi.rotate(ciphertext.encode()).decode()

View File

@@ -0,0 +1,131 @@
"""
SQLAlchemy 2.0 ORM models for billing and BYO API key storage.
Models:
TenantLlmKey — stores encrypted BYO API keys per tenant per provider
StripeEvent — idempotency table for processed Stripe webhook events
Design notes:
- TenantLlmKey uses AuditBase (same separate declarative base as audit_events)
because tenant_llm_keys is a sensitive, compliance-relevant table.
- StripeEvent uses Base (same as tenants/agents) because it is a simple
idempotency guard, not a sensitive record.
- TenantLlmKey has RLS enabled (tenant isolation enforced at DB level via
Alembic migration 005). Only konstruct_app SELECT, INSERT, DELETE are
granted — UPDATE is not (keys are immutable; to change a key, delete and
re-create).
- key_hint stores the last 4 characters of the plaintext API key so the
portal can display "...ABCD" without decrypting the stored ciphertext.
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from shared.models.audit import AuditBase
from shared.models.tenant import Base
class TenantLlmKey(AuditBase):
"""
Encrypted BYO API key for a specific LLM provider, scoped to a tenant.
One row per (tenant_id, provider) pair — enforced by UNIQUE constraint.
RLS is ENABLED — tenant_id isolation enforced at DB level.
konstruct_app has SELECT, INSERT, DELETE only (no UPDATE).
To rotate a key: DELETE existing row, INSERT new row with updated key.
Fields:
provider — LLM provider name, e.g. "openai", "anthropic"
label — Human-readable label, e.g. "Production OpenAI Key"
encrypted_key — Fernet-encrypted API key (via KeyEncryptionService)
key_hint — Last 4 chars of plaintext key for portal display ("...XXXX")
key_version — Incremented on key rotation (future use for audit)
"""
__tablename__ = "tenant_llm_keys"
__table_args__ = (
UniqueConstraint("tenant_id", "provider", name="uq_tenant_llm_key_provider"),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
tenant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
provider: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="LLM provider name: openai | anthropic | cohere | groq | etc.",
)
label: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Human-readable label for the key in the portal",
)
encrypted_key: Mapped[str] = mapped_column(
Text,
nullable=False,
comment="Fernet-encrypted API key ciphertext — NEVER logged or exposed via API",
)
key_hint: Mapped[str | None] = mapped_column(
String(4),
nullable=True,
comment="Last 4 characters of plaintext key for portal display (e.g. 'ABCD')",
)
key_version: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=1,
comment="Incremented on key rotation for audit trail",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
def __repr__(self) -> str:
return f"<TenantLlmKey id={self.id} provider={self.provider!r} tenant={self.tenant_id}>"
class StripeEvent(Base):
"""
Idempotency guard for processed Stripe webhook events.
Stripe guarantees at-least-once delivery, so the same event may arrive
multiple times. This table prevents duplicate processing via an
INSERT ... ON CONFLICT DO NOTHING pattern checked before each handler.
Fields:
event_id — Stripe event ID (e.g. "evt_1AbCdEfGhIjKlMnO") — primary key
processed_at — Timestamp when the event was first successfully processed
"""
__tablename__ = "stripe_events"
event_id: Mapped[str] = mapped_column(
Text,
primary_key=True,
comment="Stripe event ID — globally unique per Stripe account",
)
processed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
def __repr__(self) -> str:
return f"<StripeEvent event_id={self.event_id!r}>"

View File

@@ -16,7 +16,7 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from sqlalchemy import JSON, Boolean, DateTime, Enum, ForeignKey, String, Text, UniqueConstraint, func from sqlalchemy import JSON, Boolean, DateTime, Enum, Float, ForeignKey, Integer, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -69,6 +69,42 @@ class Tenant(Base):
onupdate=func.now(), onupdate=func.now(),
) )
# ---------------------------------------------------------------------------
# Billing fields (added in migration 005)
# ---------------------------------------------------------------------------
stripe_customer_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Customer ID (cus_...)",
)
stripe_subscription_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Subscription ID (sub_...)",
)
stripe_subscription_item_id: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Stripe Subscription Item ID (si_...) for quantity updates",
)
subscription_status: Mapped[str] = mapped_column(
String(50),
nullable=False,
default="none",
comment="none | trialing | active | past_due | canceled | unpaid",
)
trial_ends_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Trial expiry timestamp (NULL for non-trial subscriptions)",
)
agent_quota: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
comment="Number of active agents allowed under current subscription",
)
# Relationships # Relationships
agents: Mapped[list[Agent]] = relationship("Agent", back_populates="tenant", cascade="all, delete-orphan") agents: Mapped[list[Agent]] = relationship("Agent", back_populates="tenant", cascade="all, delete-orphan")
channel_connections: Mapped[list[ChannelConnection]] = relationship( channel_connections: Mapped[list[ChannelConnection]] = relationship(
@@ -125,6 +161,12 @@ class Agent(Base):
comment="Whether natural language escalation phrases trigger handoff", comment="Whether natural language escalation phrases trigger handoff",
) )
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
budget_limit_usd: Mapped[float | None] = mapped_column(
Float,
nullable=True,
default=None,
comment="Monthly spend cap in USD. NULL means no limit.",
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
nullable=False, nullable=False,

View File

@@ -0,0 +1,65 @@
"""
Unit tests for budget alert threshold logic.
Tests thresholds:
- No budget limit (None) → status "ok", no alert
- Usage at 50% → status "ok"
- Usage at exactly 80% → status "warning"
- Usage at 95% → status "warning"
- Usage at exactly 100% → status "exceeded"
- Usage at 120% → status "exceeded"
"""
from __future__ import annotations
import pytest
from shared.api.usage import compute_budget_status
def test_budget_alert_no_limit() -> None:
"""Agent with no budget limit (None) → status 'ok', no alert."""
status = compute_budget_status(current_usd=500.0, budget_limit_usd=None)
assert status == "ok"
def test_budget_alert_under_threshold() -> None:
"""Usage at 50% of limit → status 'ok'."""
status = compute_budget_status(current_usd=50.0, budget_limit_usd=100.0)
assert status == "ok"
def test_budget_alert_just_below_warning() -> None:
"""Usage at 79% → still 'ok' (below 80% threshold)."""
status = compute_budget_status(current_usd=79.0, budget_limit_usd=100.0)
assert status == "ok"
def test_budget_alert_warning() -> None:
"""Usage at exactly 80% → status 'warning'."""
status = compute_budget_status(current_usd=80.0, budget_limit_usd=100.0)
assert status == "warning"
def test_budget_alert_warning_mid() -> None:
"""Usage at 95% → status 'warning'."""
status = compute_budget_status(current_usd=95.0, budget_limit_usd=100.0)
assert status == "warning"
def test_budget_alert_exceeded() -> None:
"""Usage at exactly 100% → status 'exceeded'."""
status = compute_budget_status(current_usd=100.0, budget_limit_usd=100.0)
assert status == "exceeded"
def test_budget_alert_over_limit() -> None:
"""Usage at 120% → status 'exceeded'."""
status = compute_budget_status(current_usd=120.0, budget_limit_usd=100.0)
assert status == "exceeded"
def test_budget_alert_zero_usage() -> None:
"""Zero usage with a limit → status 'ok'."""
status = compute_budget_status(current_usd=0.0, budget_limit_usd=50.0)
assert status == "ok"

View File

@@ -0,0 +1,88 @@
"""
Unit tests for KeyEncryptionService (Fernet-based encryption of BYO API keys).
Tests:
- encrypt/decrypt roundtrip
- different ciphertexts produced from same plaintext (Fernet random IV)
- invalid ciphertext raises InvalidToken
- MultiFernet rotation produces new ciphertext decryptable by current key
"""
from __future__ import annotations
import pytest
from cryptography.fernet import Fernet, InvalidToken
from shared.crypto import KeyEncryptionService
@pytest.fixture()
def primary_key() -> str:
"""Generate a fresh Fernet key for each test."""
return Fernet.generate_key().decode()
@pytest.fixture()
def secondary_key() -> str:
"""Second Fernet key for rotation tests."""
return Fernet.generate_key().decode()
def test_encrypt_decrypt_roundtrip(primary_key: str) -> None:
"""Encrypt then decrypt returns the original plaintext."""
svc = KeyEncryptionService(primary_key=primary_key)
plaintext = "sk-my-secret-api-key-12345"
ciphertext = svc.encrypt(plaintext)
result = svc.decrypt(ciphertext)
assert result == plaintext
def test_encrypt_produces_different_ciphertext(primary_key: str) -> None:
"""Same plaintext encrypted twice produces different ciphertexts (Fernet random IV)."""
svc = KeyEncryptionService(primary_key=primary_key)
plaintext = "sk-same-plaintext"
ct1 = svc.encrypt(plaintext)
ct2 = svc.encrypt(plaintext)
assert ct1 != ct2
# Both must still decrypt to the same value
assert svc.decrypt(ct1) == plaintext
assert svc.decrypt(ct2) == plaintext
def test_decrypt_invalid_raises(primary_key: str) -> None:
"""Decrypting garbage raises InvalidToken (or ValueError wrapping it)."""
svc = KeyEncryptionService(primary_key=primary_key)
with pytest.raises((InvalidToken, ValueError)):
svc.decrypt("this-is-not-valid-fernet-ciphertext")
def test_multifernet_rotation(primary_key: str, secondary_key: str) -> None:
"""
Rotation scenario:
1. Encrypt with 'old' key (secondary_key as primary, no previous)
2. Create new service with primary_key=new and previous=old
3. rotate(old_ciphertext) produces a new ciphertext decryptable by primary_key
"""
# Step 1: encrypt with the old (secondary) key
old_svc = KeyEncryptionService(primary_key=secondary_key)
plaintext = "sk-rotate-me"
old_ciphertext = old_svc.encrypt(plaintext)
# Step 2: new service knows both keys — primary=new, previous=old
new_svc = KeyEncryptionService(primary_key=primary_key, previous_key=secondary_key)
# Verify old ciphertext is still decryptable via the new service (previous key fallback)
assert new_svc.decrypt(old_ciphertext) == plaintext
# Step 3: rotate — re-encrypt with the primary key
rotated_ciphertext = new_svc.rotate(old_ciphertext)
assert rotated_ciphertext != old_ciphertext
# Rotated ciphertext must be decryptable by a service with only the new primary key
only_new_svc = KeyEncryptionService(primary_key=primary_key)
assert only_new_svc.decrypt(rotated_ciphertext) == plaintext