Files
konstruct/packages/shared/shared/api/usage.py
Adolfo Delorenzo 43b73aa6c5 feat(04-rbac-03): wire RBAC guards to all portal API endpoints + new endpoints
- Add require_platform_admin guard to GET/POST /tenants, PUT/DELETE /tenants/{id}
- Add require_tenant_member to GET /tenants/{id}, GET agents, GET agent/{id}
- Add require_tenant_admin to POST agents, PUT/DELETE agents
- Add require_tenant_admin to billing checkout and portal endpoints
- Add require_tenant_admin to channels slack/install and whatsapp/connect
- Add require_tenant_member to channels /{tid}/test
- Add require_tenant_admin to all llm_keys endpoints
- Add require_tenant_member to all usage GET endpoints
- Add POST /tenants/{tid}/agents/{aid}/test (require_tenant_member for operators)
- Add GET /tenants/{tid}/users with pending invitations (require_tenant_admin)
- Add GET /admin/users with tenant filter/role filter (require_platform_admin)
- Add POST /admin/impersonate with AuditEvent logging (require_platform_admin)
- Add POST /admin/stop-impersonation with AuditEvent logging (require_platform_admin)
2026-03-24 17:13:35 -06:00

447 lines
14 KiB
Python

"""
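Usage aggregation API endpoints for the Konstruct portal.

Endpoints:
    GET /api/portal/usage/{tenant_id}/summary        — per-agent token usage and cost
    GET /api/portal/usage/{tenant_id}/by-provider    — cost grouped by provider
    GET /api/portal/usage/{tenant_id}/message-volume — message count grouped by channel
    GET /api/portal/usage/{tenant_id}/budget-alerts  — budget threshold alerts per agent

All endpoints aggregate llm_call rows from the audit_events table, reading its
JSONB ``metadata`` column. They use the composite index
(tenant_id, action_type, created_at DESC) added in migration 005. The
budget-alerts endpoint additionally loads the tenant's agents to read their
budget limits.

Raw-SQL convention: in ``text()`` SQL, never follow a bind parameter directly
with ``::type``, because SQLAlchemy's parser does not handle ``:param::type``
correctly. Write ``CAST(:param AS type)`` instead (e.g.
``CAST(:param AS jsonb)``, which is required for asyncpg compatibility), or
bind a value of the matching Python type.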
"""
from __future__ import annotations

import uuid
from datetime import date, datetime, timedelta, timezone
from typing import Any, Literal

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from shared.api.rbac import PortalCaller, require_tenant_member
from shared.db import get_session
from shared.models.tenant import Agent, Tenant
# Router for the usage endpoints defined in this module. Every path is served
# under /api/portal/usage and grouped under the "usage" OpenAPI tag.
usage_router = APIRouter(
    tags=["usage"],
    prefix="/api/portal/usage",
)
# ---------------------------------------------------------------------------
# In-memory aggregation helpers (used by tests for unit coverage without DB)
# ---------------------------------------------------------------------------
def _aggregate_rows_by_agent(rows: list[dict]) -> list[dict]:
"""
Aggregate a list of raw audit event row dicts by agent_id.
Each row dict must have: agent_id, prompt_tokens, completion_tokens,
total_tokens, cost_usd.
Returns a list of dicts with keys: agent_id, prompt_tokens,
completion_tokens, total_tokens, cost_usd, call_count.
"""
aggregated: dict[str, dict] = {}
for row in rows:
agent_id = str(row["agent_id"])
if agent_id not in aggregated:
aggregated[agent_id] = {
"agent_id": agent_id,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_usd": 0.0,
"call_count": 0,
}
agg = aggregated[agent_id]
agg["prompt_tokens"] += int(row.get("prompt_tokens", 0))
agg["completion_tokens"] += int(row.get("completion_tokens", 0))
agg["total_tokens"] += int(row.get("total_tokens", 0))
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
agg["call_count"] += 1
return list(aggregated.values())
def _aggregate_rows_by_provider(rows: list[dict]) -> list[dict]:
"""
Aggregate a list of raw audit event row dicts by provider.
Each row dict must have: provider, cost_usd.
Returns a list of dicts with keys: provider, cost_usd, call_count.
"""
aggregated: dict[str, dict] = {}
for row in rows:
provider = str(row.get("provider", "unknown"))
if provider not in aggregated:
aggregated[provider] = {
"provider": provider,
"cost_usd": 0.0,
"call_count": 0,
}
agg = aggregated[provider]
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
agg["call_count"] += 1
return list(aggregated.values())
# ---------------------------------------------------------------------------
# Budget threshold helper (also used by tests directly)
# ---------------------------------------------------------------------------
def compute_budget_status(current_usd: float, budget_limit_usd: float | None) -> str:
"""
Determine budget alert status for a given usage vs. limit.
Returns:
"ok" — no limit set, or usage below 80% of limit
"warning" — usage is between 80% and 99% of limit (inclusive)
"exceeded" — usage is at or above 100% of limit
"""
if budget_limit_usd is None or budget_limit_usd <= 0:
return "ok"
ratio = current_usd / budget_limit_usd
if ratio >= 1.0:
return "exceeded"
elif ratio >= 0.8:
return "warning"
else:
return "ok"
# ---------------------------------------------------------------------------
# Pydantic response schemas
# ---------------------------------------------------------------------------
class AgentUsageSummary(BaseModel):
    """Token and cost totals for one agent over the requested window."""

    agent_id: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost_usd: float
    call_count: int


class UsageSummaryResponse(BaseModel):
    """Response body for GET /api/portal/usage/{tenant_id}/summary."""

    tenant_id: str
    start_date: str  # window start echoed back to the client
    end_date: str  # window end (inclusive) echoed back to the client
    agents: list[AgentUsageSummary]


class ProviderUsage(BaseModel):
    """Cost and call count for one LLM provider."""

    provider: str
    cost_usd: float
    call_count: int


class ProviderUsageResponse(BaseModel):
    """Response body for GET /api/portal/usage/{tenant_id}/by-provider."""

    tenant_id: str
    start_date: str
    end_date: str
    providers: list[ProviderUsage]


class ChannelVolume(BaseModel):
    """Message count for one channel."""

    channel: str
    message_count: int


class MessageVolumeResponse(BaseModel):
    """Response body for GET /api/portal/usage/{tenant_id}/message-volume."""

    tenant_id: str
    start_date: str
    end_date: str
    channels: list[ChannelVolume]


class BudgetAlert(BaseModel):
    """Current-month spend of one agent compared with its budget limit."""

    agent_id: str
    agent_name: str
    budget_limit_usd: float
    current_usd: float
    # Closed set produced by compute_budget_status(); a Literal makes the
    # allowed values explicit in validation and the OpenAPI schema.
    status: Literal["ok", "warning", "exceeded"]


class BudgetAlertsResponse(BaseModel):
    """Response body for GET /api/portal/usage/{tenant_id}/budget-alerts."""

    tenant_id: str
    alerts: list[BudgetAlert]
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _utc_today() -> date:
    """Return today's calendar date in UTC, independent of the server's zone."""
    return datetime.now(timezone.utc).date()


def _start_of_month() -> str:
    """Return the first day of the current UTC month as ``YYYY-MM-DD``."""
    return _utc_today().replace(day=1).isoformat()


def _today() -> str:
    """Return the current UTC date as ``YYYY-MM-DD``."""
    return _utc_today().isoformat()


def _parse_bound(
    value: str, param_name: str, *, shift: timedelta = timedelta(0)
) -> datetime:
    """
    Parse an ISO-8601 date or datetime query value into an aware datetime.

    A bare date (``YYYY-MM-DD``) means midnight. A value without an explicit
    UTC offset is interpreted as UTC. ``shift`` is added to the parsed value.

    Raises:
        HTTPException: 422 when ``value`` is not ISO-8601, or when the shifted
            result falls outside the range ``datetime`` can represent.
    """
    try:
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed + shift
    except (ValueError, OverflowError):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid {param_name}: expected an ISO-8601 date (YYYY-MM-DD)",
        ) from None


def _resolve_window(
    start_date: str | None, end_date: str | None
) -> tuple[str, str, datetime, datetime]:
    """
    Resolve the optional ``start_date``/``end_date`` query values.

    Defaults are the first day of the current UTC month and today (UTC).

    Returns:
        ``(start, end, start_ts, end_ts)``:
        - ``start`` and ``end`` are the strings echoed back in the response.
        - ``[start_ts, end_ts)`` is the half-open range used to filter
          ``created_at``. ``end_ts`` is one day past ``end``, so the whole
          end day is included.
    """
    start = start_date or _start_of_month()
    end = end_date or _today()
    start_ts = _parse_bound(start, "start_date")
    end_ts = _parse_bound(end, "end_date", shift=timedelta(days=1))
    return start, end, start_ts, end_ts


async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
    """Load the tenant, or raise HTTP 404 if no tenant row has this id."""
    result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = result.scalar_one_or_none()
    if tenant is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
    return tenant


async def _fetch_window_rows(
    session: AsyncSession,
    sql: Any,
    tenant_id: uuid.UUID,
    start_ts: datetime,
    end_ts: datetime,
) -> list[Any]:
    """
    Run a windowed audit_events query and return its rows as mappings.

    ``sql`` is a ``text()`` clause that uses the ``:tenant_id``, ``:start_ts``
    and ``:end_ts`` bind parameters.
    """
    result = await session.execute(
        sql,
        {"tenant_id": str(tenant_id), "start_ts": start_ts, "end_ts": end_ts},
    )
    return list(result.mappings().all())


# ---------------------------------------------------------------------------
# Endpoints
#
# Time bounds are bound as aware datetimes inside CAST(... AS timestamptz).
# Never write ``:param::timestamptz`` here, for two reasons:
#   - SQLAlchemy's text() parser does not handle a bind parameter that is
#     immediately followed by ``::``.
#   - asyncpg rejects ``str`` values for timestamp parameters.
#
# Each endpoint declares ``caller`` only to run the require_tenant_member
# RBAC dependency; its value is not used in the body.
# ---------------------------------------------------------------------------
@usage_router.get("/{tenant_id}/summary", response_model=UsageSummaryResponse)
async def get_usage_summary(
    tenant_id: uuid.UUID,
    start_date: str | None = Query(default=None),
    end_date: str | None = Query(default=None),
    caller: PortalCaller = Depends(require_tenant_member),
    session: AsyncSession = Depends(get_session),
) -> UsageSummaryResponse:
    """
    Per-agent token usage and cost aggregated from audit_events.

    Sums prompt_tokens, completion_tokens, total_tokens and cost_usd from
    audit_events.metadata (JSONB) for action_type='llm_call' rows.

    ``start_date`` and ``end_date`` are ISO-8601 dates; the end date is
    inclusive. They default to the current UTC month-to-date. Malformed
    values yield 422.
    """
    await _get_tenant_or_404(tenant_id, session)
    start, end, start_ts, end_ts = _resolve_window(start_date, end_date)

    # Filter shape matches the composite index (tenant_id, action_type, created_at DESC).
    sql = text("""
        SELECT
            agent_id::text,
            COALESCE(SUM((metadata->>'prompt_tokens')::numeric), 0)::bigint AS prompt_tokens,
            COALESCE(SUM((metadata->>'completion_tokens')::numeric), 0)::bigint AS completion_tokens,
            COALESCE(SUM((metadata->>'total_tokens')::numeric), 0)::bigint AS total_tokens,
            COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
            COUNT(*)::bigint AS call_count
        FROM audit_events
        WHERE
            tenant_id = :tenant_id
            AND action_type = 'llm_call'
            AND created_at >= CAST(:start_ts AS timestamptz)
            AND created_at < CAST(:end_ts AS timestamptz)
        GROUP BY agent_id
        ORDER BY cost_usd DESC
    """)
    rows = await _fetch_window_rows(session, sql, tenant_id, start_ts, end_ts)

    agents = [
        AgentUsageSummary(
            agent_id=str(row["agent_id"]) if row["agent_id"] else "",
            prompt_tokens=int(row["prompt_tokens"]),
            completion_tokens=int(row["completion_tokens"]),
            total_tokens=int(row["total_tokens"]),
            cost_usd=float(row["cost_usd"]),
            call_count=int(row["call_count"]),
        )
        for row in rows
    ]
    return UsageSummaryResponse(
        tenant_id=str(tenant_id),
        start_date=start,
        end_date=end,
        agents=agents,
    )


@usage_router.get("/{tenant_id}/by-provider", response_model=ProviderUsageResponse)
async def get_usage_by_provider(
    tenant_id: uuid.UUID,
    start_date: str | None = Query(default=None),
    end_date: str | None = Query(default=None),
    caller: PortalCaller = Depends(require_tenant_member),
    session: AsyncSession = Depends(get_session),
) -> ProviderUsageResponse:
    """
    Cost aggregated by LLM provider from audit_events.metadata.provider.

    Rows without a provider are grouped under "unknown".

    ``start_date`` and ``end_date`` are ISO-8601 dates; the end date is
    inclusive. They default to the current UTC month-to-date. Malformed
    values yield 422.
    """
    await _get_tenant_or_404(tenant_id, session)
    start, end, start_ts, end_ts = _resolve_window(start_date, end_date)

    # GROUP BY 1 (by position): in PostgreSQL a bare GROUP BY name resolves
    # to an input column first. So `GROUP BY provider` would stop meaning
    # the COALESCE alias if audit_events ever had a `provider` column.
    sql = text("""
        SELECT
            COALESCE(metadata->>'provider', 'unknown') AS provider,
            COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd,
            COUNT(*)::bigint AS call_count
        FROM audit_events
        WHERE
            tenant_id = :tenant_id
            AND action_type = 'llm_call'
            AND created_at >= CAST(:start_ts AS timestamptz)
            AND created_at < CAST(:end_ts AS timestamptz)
        GROUP BY 1
        ORDER BY cost_usd DESC
    """)
    rows = await _fetch_window_rows(session, sql, tenant_id, start_ts, end_ts)

    providers = [
        ProviderUsage(
            provider=row["provider"],
            cost_usd=float(row["cost_usd"]),
            call_count=int(row["call_count"]),
        )
        for row in rows
    ]
    return ProviderUsageResponse(
        tenant_id=str(tenant_id),
        start_date=start,
        end_date=end,
        providers=providers,
    )


@usage_router.get("/{tenant_id}/message-volume", response_model=MessageVolumeResponse)
async def get_message_volume(
    tenant_id: uuid.UUID,
    start_date: str | None = Query(default=None),
    end_date: str | None = Query(default=None),
    caller: PortalCaller = Depends(require_tenant_member),
    session: AsyncSession = Depends(get_session),
) -> MessageVolumeResponse:
    """
    Message count grouped by channel from audit_events.metadata.channel.

    Rows without a channel are grouped under "unknown".

    ``start_date`` and ``end_date`` are ISO-8601 dates; the end date is
    inclusive. They default to the current UTC month-to-date. Malformed
    values yield 422.
    """
    await _get_tenant_or_404(tenant_id, session)
    start, end, start_ts, end_ts = _resolve_window(start_date, end_date)

    # NOTE(review): this counts llm_call audit events. That equals the
    # number of messages only if each message triggers exactly one LLM
    # call — confirm.
    # GROUP BY 1 for the same input-vs-output name reason as by-provider.
    sql = text("""
        SELECT
            COALESCE(metadata->>'channel', 'unknown') AS channel,
            COUNT(*)::bigint AS message_count
        FROM audit_events
        WHERE
            tenant_id = :tenant_id
            AND action_type = 'llm_call'
            AND created_at >= CAST(:start_ts AS timestamptz)
            AND created_at < CAST(:end_ts AS timestamptz)
        GROUP BY 1
        ORDER BY message_count DESC
    """)
    rows = await _fetch_window_rows(session, sql, tenant_id, start_ts, end_ts)

    channels = [
        ChannelVolume(channel=row["channel"], message_count=int(row["message_count"]))
        for row in rows
    ]
    return MessageVolumeResponse(
        tenant_id=str(tenant_id),
        start_date=start,
        end_date=end,
        channels=channels,
    )


@usage_router.get("/{tenant_id}/budget-alerts", response_model=BudgetAlertsResponse)
async def get_budget_alerts(
    tenant_id: uuid.UUID,
    caller: PortalCaller = Depends(require_tenant_member),
    session: AsyncSession = Depends(get_session),
) -> BudgetAlertsResponse:
    """
    Budget threshold alerts for agents with budget_limit_usd set.

    Sums current-month (UTC) cost_usd from audit_events for each agent that
    has a budget limit configured. Each alert's status is "ok", "warning" or
    "exceeded", as computed by compute_budget_status().
    """
    await _get_tenant_or_404(tenant_id, session)

    # Only agents with a configured budget can produce an alert.
    result = await session.execute(
        select(Agent).where(
            Agent.tenant_id == tenant_id,
            Agent.budget_limit_usd.isnot(None),
        )
    )
    agents: list[Agent] = list(result.scalars().all())
    if not agents:
        return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=[])

    today = _utc_today()
    month_start = datetime(today.year, today.month, 1, tzinfo=timezone.utc)

    # Aggregate current-month cost per budgeted agent.
    sql = text("""
        SELECT
            agent_id::text,
            COALESCE(SUM((metadata->>'cost_usd')::numeric), 0)::numeric(12,6) AS cost_usd
        FROM audit_events
        WHERE
            tenant_id = :tenant_id
            AND action_type = 'llm_call'
            AND created_at >= CAST(:start_ts AS timestamptz)
            AND agent_id = ANY(:agent_ids)
        GROUP BY agent_id
    """)
    cost_result = await session.execute(
        sql,
        {
            "tenant_id": str(tenant_id),
            "start_ts": month_start,
            "agent_ids": [str(agent.id) for agent in agents],
        },
    )
    # Keys are agent_id::text, so they compare equal to str(agent.id) below.
    cost_by_agent: dict[str, float] = {
        row["agent_id"]: float(row["cost_usd"])
        for row in cost_result.mappings().all()
    }

    alerts: list[BudgetAlert] = []
    for agent in agents:
        # Agents with no llm_call rows this month have spent nothing.
        current = cost_by_agent.get(str(agent.id), 0.0)
        limit = float(agent.budget_limit_usd)  # type: ignore[arg-type]
        alerts.append(
            BudgetAlert(
                agent_id=str(agent.id),
                agent_name=agent.name,
                budget_limit_usd=limit,
                current_usd=current,
                status=compute_budget_status(current, limit),
            )
        )
    return BudgetAlertsResponse(tenant_id=str(tenant_id), alerts=alerts)