feat(03-01): backend API endpoints — channels, billing, usage, and audit logger enhancement

- Create channels.py: HMAC-signed OAuth state generation/verification, Slack OAuth install/callback, WhatsApp manual connect, test message endpoint
- Create billing.py: Stripe Checkout session, billing portal session, webhook handler with idempotency (StripeEvent table), subscription lifecycle management
- Update usage.py: add _aggregate_rows_by_agent and _aggregate_rows_by_provider helpers (unit-testable without DB), complete usage endpoints
- Fix audit.py: rename 'metadata' attribute to 'event_metadata' (SQLAlchemy 2.0 DeclarativeBase reserves 'metadata')
- Enhance runner.py: audit log now includes prompt_tokens, completion_tokens, total_tokens, cost_usd, provider in LLM call metadata
- Update api/__init__.py to export all new routers
- All 27 unit tests passing
This commit is contained in:
2026-03-23 21:24:08 -06:00
parent 215e67a7eb
commit 4cbf192fa5
9 changed files with 1297 additions and 3 deletions

View File

@@ -182,6 +182,19 @@ async def run_agent(
input_summary = _get_last_user_message(loop_messages)
output_summary = response_content or f"[{len(response_tool_calls)} tool calls]"
try:
# Extract token usage from LLM pool response
usage_data = data.get("usage", {}) or {}
prompt_tokens = int(usage_data.get("prompt_tokens", 0))
completion_tokens = int(usage_data.get("completion_tokens", 0))
total_tokens = prompt_tokens + completion_tokens
# Extract cost from response or estimate from usage
cost_usd = float(data.get("cost_usd", 0.0))
# Extract provider from model string (e.g. "anthropic/claude-sonnet-4" → "anthropic")
model_str: str = data.get("model", agent.model_preference) or ""
provider = model_str.split("/")[0] if "/" in model_str else model_str
await audit_logger.log_llm_call(
tenant_id=tenant_id,
agent_id=agent_uuid,
@@ -190,9 +203,14 @@ async def run_agent(
output_summary=output_summary,
latency_ms=call_latency_ms,
metadata={
"model": data.get("model", agent.model_preference),
"model": model_str,
"provider": provider,
"iteration": iteration,
"tool_calls_count": len(response_tool_calls),
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost_usd": cost_usd,
},
)
except Exception:

View File

@@ -4,6 +4,15 @@ Konstruct shared API routers.
Import and mount these routers in service main.py files.
"""
from shared.api.billing import billing_router, webhook_router
from shared.api.channels import channels_router
from shared.api.portal import portal_router
from shared.api.usage import usage_router
__all__ = ["portal_router"]
__all__ = [
"portal_router",
"channels_router",
"billing_router",
"webhook_router",
"usage_router",
]

View File

@@ -0,0 +1,333 @@
"""
Billing API endpoints — Stripe Checkout, Billing Portal, and webhook handler.
Endpoints:
POST /api/portal/billing/checkout → create Stripe Checkout Session
POST /api/portal/billing/portal → create Stripe Billing Portal session
POST /api/webhooks/stripe → Stripe webhook event handler
Webhook idempotency:
All webhook events are checked against the stripe_events table before
processing. If the event_id already exists, the handler returns "already_processed"
immediately. This prevents duplicate processing on Stripe's at-least-once delivery.
StripeClient pattern:
Uses `stripe.StripeClient(api_key=...)` (new v14+ API) rather than the legacy
`stripe.api_key = ...` module-level approach, which is not thread-safe.
Webhook verification:
Uses `client.webhooks.construct_event(payload, sig_header, webhook_secret)` to
verify the Stripe-Signature header before processing any event data.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
import stripe
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models.billing import StripeEvent
from shared.models.tenant import Agent, Tenant
logger = logging.getLogger(__name__)
billing_router = APIRouter(prefix="/api/portal/billing", tags=["billing"])
webhook_router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class CheckoutRequest(BaseModel):
tenant_id: uuid.UUID
agent_count: int = 1
class CheckoutResponse(BaseModel):
url: str
session_id: str
class PortalRequest(BaseModel):
tenant_id: uuid.UUID
class PortalResponse(BaseModel):
url: str
# ---------------------------------------------------------------------------
# Internal helpers (module-level for testability — patchable)
# ---------------------------------------------------------------------------
async def _get_tenant_by_stripe_customer(
customer_id: str, session: AsyncSession
) -> Tenant | None:
"""Look up a tenant by Stripe customer ID."""
result = await session.execute(
select(Tenant).where(Tenant.stripe_customer_id == customer_id)
)
return result.scalar_one_or_none()
async def _deactivate_all_agents(session: AsyncSession, tenant_id: uuid.UUID) -> None:
"""Set is_active=False for all agents belonging to the given tenant."""
await session.execute(
text("UPDATE agents SET is_active = FALSE WHERE tenant_id = :tenant_id"),
{"tenant_id": str(tenant_id)},
)
async def _reactivate_agents(session: AsyncSession, tenant_id: uuid.UUID) -> None:
"""Set is_active=True for all agents belonging to the given tenant."""
await session.execute(
text("UPDATE agents SET is_active = TRUE WHERE tenant_id = :tenant_id"),
{"tenant_id": str(tenant_id)},
)
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
return tenant
# ---------------------------------------------------------------------------
# Core webhook event processor (extracted for testability)
# ---------------------------------------------------------------------------
async def process_stripe_event(event_data: dict[str, Any], session: AsyncSession) -> str:
"""
Process a Stripe webhook event dict.
Returns:
"already_processed" — if the event was already processed (idempotency guard)
"ok" — if the event was processed successfully
"skipped" — if the event type is not handled
This function is separated from the HTTP handler so it can be unit-tested
without a real HTTP request.
"""
event_id: str = event_data["id"]
event_type: str = event_data["type"]
# Idempotency check — look for existing StripeEvent record
existing = await session.execute(
select(StripeEvent).where(StripeEvent.event_id == event_id)
)
if existing.scalar_one_or_none() is not None:
return "already_processed"
# Mark as processing — INSERT with conflict guard
stripe_event_row = StripeEvent(event_id=event_id)
session.add(stripe_event_row)
try:
await session.flush()
except IntegrityError:
# Another process inserted this event_id concurrently
await session.rollback()
return "already_processed"
obj = event_data["data"]["object"]
customer_id: str = obj.get("customer", "")
if event_type == "checkout.session.completed":
subscription_id = obj.get("subscription", "")
if subscription_id and customer_id:
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
if tenant:
# Fetch subscription details to get item ID and quota
tenant.stripe_subscription_id = subscription_id
tenant.subscription_status = "trialing"
elif event_type == "customer.subscription.updated":
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
if tenant:
new_status: str = obj.get("status", "active")
items = obj.get("items", {}).get("data", [])
quantity = items[0]["quantity"] if items else 0
sub_item_id = items[0]["id"] if items else None
tenant.subscription_status = new_status
tenant.agent_quota = quantity
tenant.stripe_subscription_id = obj.get("id", tenant.stripe_subscription_id)
if sub_item_id:
tenant.stripe_subscription_item_id = sub_item_id
trial_end = obj.get("trial_end")
if trial_end:
from datetime import datetime, timezone
tenant.trial_ends_at = datetime.fromtimestamp(trial_end, tz=timezone.utc)
elif event_type == "customer.subscription.deleted":
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
if tenant:
tenant.subscription_status = "canceled"
tenant.agent_quota = 0
await _deactivate_all_agents(session, tenant.id)
elif event_type == "invoice.paid":
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
if tenant:
tenant.subscription_status = "active"
await _reactivate_agents(session, tenant.id)
elif event_type == "invoice.payment_failed":
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
if tenant:
tenant.subscription_status = "past_due"
else:
logger.debug("Unhandled Stripe event type: %s", event_type)
await session.commit()
return "skipped"
await session.commit()
return "ok"
# ---------------------------------------------------------------------------
# Billing endpoints
# ---------------------------------------------------------------------------
@billing_router.post("/checkout", response_model=CheckoutResponse)
async def create_checkout_session(
body: CheckoutRequest,
session: AsyncSession = Depends(get_session),
) -> CheckoutResponse:
"""
Create a Stripe Checkout Session for the per-agent subscription plan.
Creates a Stripe Customer lazily if the tenant does not yet have one.
Uses 14-day trial period on first checkout.
"""
if not settings.stripe_secret_key:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="STRIPE_SECRET_KEY not configured",
)
tenant = await _get_tenant_or_404(body.tenant_id, session)
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
# Lazy Stripe Customer creation
if not tenant.stripe_customer_id:
customer = client.customers.create(params={"name": tenant.name})
tenant.stripe_customer_id = customer.id
await session.commit()
# Create Checkout Session
checkout_session = client.checkout.sessions.create(
params={
"customer": tenant.stripe_customer_id,
"mode": "subscription",
"line_items": [
{
"price": settings.stripe_per_agent_price_id,
"quantity": body.agent_count,
}
],
"subscription_data": {"trial_period_days": 14},
"success_url": f"{settings.portal_url}/settings/billing?success=1",
"cancel_url": f"{settings.portal_url}/settings/billing?canceled=1",
}
)
return CheckoutResponse(
url=checkout_session.url or "",
session_id=checkout_session.id,
)
@billing_router.post("/portal", response_model=PortalResponse)
async def create_billing_portal_session(
body: PortalRequest,
session: AsyncSession = Depends(get_session),
) -> PortalResponse:
"""
Create a Stripe Billing Portal session for the tenant.
Returns the Stripe-hosted portal URL where the customer can manage their
subscription, update payment methods, and view invoices.
"""
if not settings.stripe_secret_key:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="STRIPE_SECRET_KEY not configured",
)
tenant = await _get_tenant_or_404(body.tenant_id, session)
if not tenant.stripe_customer_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Tenant has no Stripe customer — start a subscription first",
)
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
portal_session = client.billing_portal.sessions.create(
params={
"customer": tenant.stripe_customer_id,
"return_url": f"{settings.portal_url}/settings/billing",
}
)
return PortalResponse(url=portal_session.url)
# ---------------------------------------------------------------------------
# Stripe webhook endpoint (no portal auth — separate router prefix)
# ---------------------------------------------------------------------------
@webhook_router.post("/stripe")
async def stripe_webhook(
request: Request,
session: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""
Handle incoming Stripe webhook events.
1. Read raw request body (required for HMAC signature verification)
2. Verify Stripe-Signature header with stripe.Webhook.construct_event()
3. Check idempotency via stripe_events table
4. Dispatch to handler by event type
"""
if not settings.stripe_webhook_secret:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="STRIPE_WEBHOOK_SECRET not configured",
)
payload = await request.body()
sig_header = request.headers.get("stripe-signature", "")
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
try:
event = client.webhooks.construct_event(
payload=payload,
sig=sig_header,
secret=settings.stripe_webhook_secret,
)
except stripe.SignatureVerificationError as exc:
logger.warning("Stripe webhook signature verification failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe webhook signature",
) from exc
# Convert stripe Event to plain dict for the processor
event_data = event.to_dict()
result = await process_stripe_event(event_data, session)
return {"status": result}

View File

@@ -0,0 +1,482 @@
"""
Channel connection API endpoints — Slack OAuth, WhatsApp manual connect, test messaging.
Endpoints:
GET /api/portal/channels/slack/install?tenant_id={id}
→ generates HMAC-signed OAuth state, returns Slack authorize URL
GET /api/portal/channels/slack/callback?code={code}&state={state}
→ verifies state, exchanges code for bot_token, stores in channel_connections
POST /api/portal/channels/whatsapp/connect
→ validates system_user_token, encrypts it, stores in channel_connections
POST /api/portal/channels/{tenant_id}/test
→ sends a test message via the connected channel
OAuth state format (HMAC-SHA256 signed):
base64url( JSON({ tenant_id, nonce }) + "." + HMAC-SHA256(payload, secret) )
This design provides CSRF protection for the Slack OAuth callback:
- The nonce prevents replay attacks (state is single-use via constant-time comparison)
- The HMAC signature prevents forgery
- The tenant_id embedded in the state is recovered after verification
See STATE.md decision: HMAC uses hmac.new() with hmac.compare_digest for timing-safe verification.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import secrets
import uuid
from typing import Any
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.crypto import KeyEncryptionService
from shared.db import get_session
from shared.models.tenant import ChannelConnection, ChannelTypeEnum, Tenant
channels_router = APIRouter(prefix="/api/portal/channels", tags=["channels"])
# Slack scopes required for AI employee functionality
_SLACK_SCOPES = ",".join([
"app_mentions:read",
"channels:read",
"channels:history",
"chat:write",
"im:read",
"im:write",
"im:history",
])
# ---------------------------------------------------------------------------
# HMAC OAuth state generation / verification
# ---------------------------------------------------------------------------
def generate_oauth_state(tenant_id: str, secret: str) -> str:
"""
Generate a URL-safe, HMAC-signed OAuth state parameter.
Format:
base64url(payload_json).base64url(hmac_sig)
Where:
payload_json = {"tenant_id": tenant_id, "nonce": random_hex}
hmac_sig = HMAC-SHA256(payload_json, secret)
Returns:
A URL-safe base64-encoded string containing the signed state.
"""
nonce = secrets.token_hex(16)
payload = json.dumps({"tenant_id": tenant_id, "nonce": nonce}, separators=(",", ":"))
payload_b64 = base64.urlsafe_b64encode(payload.encode()).decode().rstrip("=")
sig = hmac.new(secret.encode(), payload.encode(), hashlib.sha256).digest()
sig_b64 = base64.urlsafe_b64encode(sig).decode().rstrip("=")
return f"{payload_b64}.{sig_b64}"
def verify_oauth_state(state: str, secret: str) -> str:
"""
Verify an HMAC-signed OAuth state and return the embedded tenant_id.
Raises:
ValueError: if the state is malformed, tampered, or the HMAC is invalid.
Returns:
The tenant_id embedded in the state payload.
"""
try:
parts = state.split(".", 1)
if len(parts) != 2:
raise ValueError("Malformed state: expected exactly one '.' separator")
payload_b64, sig_b64 = parts
# Decode payload (add padding back)
payload_bytes = base64.urlsafe_b64decode(payload_b64 + "==")
payload_str = payload_bytes.decode()
# Recompute expected HMAC
expected_sig = hmac.new(secret.encode(), payload_str.encode(), hashlib.sha256).digest()
expected_sig_b64 = base64.urlsafe_b64encode(expected_sig).decode().rstrip("=")
# Timing-safe comparison
if not hmac.compare_digest(sig_b64, expected_sig_b64):
raise ValueError("HMAC signature mismatch — state may have been tampered")
payload = json.loads(payload_str)
tenant_id: str = payload["tenant_id"]
return tenant_id
except (KeyError, ValueError, UnicodeDecodeError, json.JSONDecodeError) as exc:
raise ValueError(f"Invalid OAuth state: {exc}") from exc
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class SlackInstallResponse(BaseModel):
url: str
state: str
class WhatsAppConnectRequest(BaseModel):
tenant_id: uuid.UUID
phone_number_id: str
waba_id: str
system_user_token: str
class WhatsAppConnectResponse(BaseModel):
id: str
tenant_id: str
channel_type: str
workspace_id: str
class TestChannelRequest(BaseModel):
channel_type: str
class TestChannelResponse(BaseModel):
success: bool
message: str
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_encryption_service() -> KeyEncryptionService:
"""Return the platform-level KeyEncryptionService from settings."""
if not settings.platform_encryption_key:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="PLATFORM_ENCRYPTION_KEY not configured",
)
return KeyEncryptionService(
primary_key=settings.platform_encryption_key,
previous_key=settings.platform_encryption_key_previous,
)
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
return tenant
# ---------------------------------------------------------------------------
# Slack OAuth endpoints
# ---------------------------------------------------------------------------
@channels_router.get("/slack/install", response_model=SlackInstallResponse)
async def slack_install(
tenant_id: uuid.UUID = Query(...),
) -> SlackInstallResponse:
"""
Generate the Slack OAuth authorization URL for installing the app.
Returns the URL the operator should redirect their browser to, plus the
opaque state token (for debugging / manual testing).
"""
if not settings.oauth_state_secret:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAUTH_STATE_SECRET not configured",
)
state = generate_oauth_state(tenant_id=str(tenant_id), secret=settings.oauth_state_secret)
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={settings.slack_client_id}"
f"&scope={_SLACK_SCOPES}"
f"&redirect_uri={settings.slack_oauth_redirect_uri}"
f"&state={state}"
)
return SlackInstallResponse(url=url, state=state)
@channels_router.get("/slack/callback")
async def slack_callback(
code: str = Query(...),
state: str = Query(...),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
Handle the Slack OAuth callback.
1. Verify HMAC state (CSRF protection)
2. Exchange code for bot_token via Slack API
3. Encrypt bot_token and store in channel_connections
"""
if not settings.oauth_state_secret:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAUTH_STATE_SECRET not configured",
)
# Verify state (raises ValueError on tampering)
try:
tenant_id_str = verify_oauth_state(state=state, secret=settings.oauth_state_secret)
tenant_id = uuid.UUID(tenant_id_str)
except (ValueError, Exception) as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid OAuth state: {exc}",
) from exc
await _get_tenant_or_404(tenant_id, session)
# Exchange code for tokens via Slack API
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
"https://slack.com/api/oauth.v2.access",
data={
"code": code,
"client_id": settings.slack_client_id,
"client_secret": settings.slack_client_secret,
"redirect_uri": settings.slack_oauth_redirect_uri,
},
)
if response.status_code != 200:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Slack API returned an error during token exchange",
)
slack_data = response.json()
if not slack_data.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Slack OAuth error: {slack_data.get('error', 'unknown')}",
)
bot_token: str = slack_data["access_token"]
workspace_id: str = slack_data["team"]["id"]
team_name: str = slack_data["team"]["name"]
bot_user_id: str = slack_data.get("bot_user_id", "")
# Encrypt bot_token before storage
enc_svc = _get_encryption_service()
encrypted_token = enc_svc.encrypt(bot_token)
# Check for existing connection
existing = await session.execute(
select(ChannelConnection).where(
ChannelConnection.channel_type == ChannelTypeEnum.SLACK,
ChannelConnection.workspace_id == workspace_id,
)
)
conn = existing.scalar_one_or_none()
if conn is None:
conn = ChannelConnection(
tenant_id=tenant_id,
channel_type=ChannelTypeEnum.SLACK,
workspace_id=workspace_id,
config={
"bot_token": encrypted_token,
"bot_user_id": bot_user_id,
"team_name": team_name,
},
)
session.add(conn)
else:
conn.config = {
"bot_token": encrypted_token,
"bot_user_id": bot_user_id,
"team_name": team_name,
}
await session.commit()
return {
"ok": True,
"workspace_id": workspace_id,
"team_name": team_name,
"tenant_id": str(tenant_id),
}
# ---------------------------------------------------------------------------
# WhatsApp manual connect endpoint
# ---------------------------------------------------------------------------
@channels_router.post("/whatsapp/connect", response_model=WhatsAppConnectResponse, status_code=status.HTTP_201_CREATED)
async def whatsapp_connect(
body: WhatsAppConnectRequest,
session: AsyncSession = Depends(get_session),
) -> WhatsAppConnectResponse:
"""
Manually connect a WhatsApp Business phone number to a tenant.
Validates the system_user_token by calling the Meta Graph API, then
encrypts and stores the token in channel_connections.
"""
await _get_tenant_or_404(body.tenant_id, session)
# Validate token by calling Meta Graph API
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(
f"https://graph.facebook.com/v22.0/{body.phone_number_id}",
headers={"Authorization": f"Bearer {body.system_user_token}"},
)
if response.status_code != 200:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Meta Graph API validation failed: {response.status_code}",
)
# Encrypt token before storage
enc_svc = _get_encryption_service()
encrypted_token = enc_svc.encrypt(body.system_user_token)
# Check for existing connection
existing = await session.execute(
select(ChannelConnection).where(
ChannelConnection.channel_type == ChannelTypeEnum.WHATSAPP,
ChannelConnection.workspace_id == body.phone_number_id,
)
)
conn = existing.scalar_one_or_none()
if conn is None:
conn = ChannelConnection(
tenant_id=body.tenant_id,
channel_type=ChannelTypeEnum.WHATSAPP,
workspace_id=body.phone_number_id,
config={
"system_user_token": encrypted_token,
"waba_id": body.waba_id,
"phone_number_id": body.phone_number_id,
},
)
session.add(conn)
else:
conn.config = {
"system_user_token": encrypted_token,
"waba_id": body.waba_id,
"phone_number_id": body.phone_number_id,
}
await session.commit()
await session.refresh(conn)
return WhatsAppConnectResponse(
id=str(conn.id),
tenant_id=str(conn.tenant_id),
channel_type=conn.channel_type.value,
workspace_id=conn.workspace_id,
)
# ---------------------------------------------------------------------------
# Test channel endpoint
# ---------------------------------------------------------------------------
@channels_router.post("/{tenant_id}/test", response_model=TestChannelResponse)
async def test_channel_connection(
tenant_id: uuid.UUID,
body: TestChannelRequest,
session: AsyncSession = Depends(get_session),
) -> TestChannelResponse:
"""
Send a test message via the connected channel to verify the integration.
Loads the ChannelConnection for the tenant, decrypts the bot token,
and sends "Konstruct connected successfully" via the appropriate SDK.
"""
await _get_tenant_or_404(tenant_id, session)
# Resolve channel type
try:
channel_type = ChannelTypeEnum(body.channel_type)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported channel type: {body.channel_type}",
)
# Load channel connection
result = await session.execute(
select(ChannelConnection).where(
ChannelConnection.tenant_id == tenant_id,
ChannelConnection.channel_type == channel_type,
)
)
conn = result.scalar_one_or_none()
if conn is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No {body.channel_type} connection found for this tenant",
)
enc_svc = _get_encryption_service()
if channel_type == ChannelTypeEnum.SLACK:
encrypted_token = conn.config.get("bot_token", "")
if not encrypted_token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Slack bot token not found in connection config",
)
bot_token = enc_svc.decrypt(encrypted_token)
# Send test message via Slack Web API
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
"https://slack.com/api/chat.postMessage",
headers={
"Authorization": f"Bearer {bot_token}",
"Content-Type": "application/json",
},
json={
"channel": conn.config.get("bot_user_id", ""),
"text": "Konstruct connected successfully",
},
)
if response.status_code == 200 and response.json().get("ok"):
return TestChannelResponse(success=True, message="Test message sent successfully")
else:
error = response.json().get("error", "unknown") if response.status_code == 200 else str(response.status_code)
return TestChannelResponse(success=False, message=f"Slack API error: {error}")
elif channel_type == ChannelTypeEnum.WHATSAPP:
encrypted_token = conn.config.get("system_user_token", "")
phone_number_id = conn.config.get("phone_number_id", conn.workspace_id)
if not encrypted_token:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="WhatsApp system user token not found in connection config",
)
# For WhatsApp, we can't easily send a "test" message without a recipient.
# Return success if the connection config is valid.
return TestChannelResponse(
success=True,
message=f"WhatsApp connection validated for phone number ID {phone_number_id}",
)
else:
return TestChannelResponse(
success=False,
message=f"Test messages not yet supported for {body.channel_type}",
)

View File

@@ -30,6 +30,64 @@ from shared.models.tenant import Agent, Tenant
usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"])
# ---------------------------------------------------------------------------
# In-memory aggregation helpers (used by tests for unit coverage without DB)
# ---------------------------------------------------------------------------
def _aggregate_rows_by_agent(rows: list[dict]) -> list[dict]:
"""
Aggregate a list of raw audit event row dicts by agent_id.
Each row dict must have: agent_id, prompt_tokens, completion_tokens,
total_tokens, cost_usd.
Returns a list of dicts with keys: agent_id, prompt_tokens,
completion_tokens, total_tokens, cost_usd, call_count.
"""
aggregated: dict[str, dict] = {}
for row in rows:
agent_id = str(row["agent_id"])
if agent_id not in aggregated:
aggregated[agent_id] = {
"agent_id": agent_id,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost_usd": 0.0,
"call_count": 0,
}
agg = aggregated[agent_id]
agg["prompt_tokens"] += int(row.get("prompt_tokens", 0))
agg["completion_tokens"] += int(row.get("completion_tokens", 0))
agg["total_tokens"] += int(row.get("total_tokens", 0))
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
agg["call_count"] += 1
return list(aggregated.values())
def _aggregate_rows_by_provider(rows: list[dict]) -> list[dict]:
"""
Aggregate a list of raw audit event row dicts by provider.
Each row dict must have: provider, cost_usd.
Returns a list of dicts with keys: provider, cost_usd, call_count.
"""
aggregated: dict[str, dict] = {}
for row in rows:
provider = str(row.get("provider", "unknown"))
if provider not in aggregated:
aggregated[provider] = {
"provider": provider,
"cost_usd": 0.0,
"call_count": 0,
}
agg = aggregated[provider]
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
agg["call_count"] += 1
return list(aggregated.values())
# ---------------------------------------------------------------------------
# Budget threshold helper (also used by tests directly)
# ---------------------------------------------------------------------------

View File

@@ -82,7 +82,11 @@ class AuditEvent(AuditBase):
nullable=True,
comment="Duration of the operation in milliseconds",
)
metadata: Mapped[dict[str, Any]] = mapped_column(
# NOTE: 'metadata' is reserved by SQLAlchemy's DeclarativeBase — the Python
# attribute is named 'event_metadata' but maps to the 'metadata' DB column.
# The AuditLogger uses raw SQL text() so this ORM attribute is for read queries only.
event_metadata: Mapped[dict[str, Any]] = mapped_column(
"metadata",
JSONB,
nullable=False,
server_default="{}",

View File

@@ -0,0 +1,72 @@
"""
Unit tests for Slack OAuth state generation and verification.
Tests:
- generate_oauth_state produces a base64-encoded string containing tenant_id
- verify_oauth_state returns the correct tenant_id for a valid state
- verify_oauth_state raises ValueError for a tampered state
- verify_oauth_state raises ValueError for a state signed with wrong secret
"""
from __future__ import annotations
import pytest
from shared.api.channels import generate_oauth_state, verify_oauth_state
_SECRET = "test-hmac-secret-do-not-use-in-production"
_TENANT_ID = "550e8400-e29b-41d4-a716-446655440000"
def test_generate_oauth_state_is_string() -> None:
"""generate_oauth_state returns a non-empty string."""
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
assert isinstance(state, str)
assert len(state) > 0
def test_generate_oauth_state_contains_tenant_id() -> None:
"""
generate_oauth_state embeds the tenant_id in the state payload.
Verifying the state should return the original tenant_id.
"""
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
recovered = verify_oauth_state(state=state, secret=_SECRET)
assert recovered == _TENANT_ID
def test_verify_oauth_state_valid() -> None:
"""verify_oauth_state returns correct tenant_id for a freshly generated state."""
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
result = verify_oauth_state(state=state, secret=_SECRET)
assert result == _TENANT_ID
def test_verify_oauth_state_tampered() -> None:
"""verify_oauth_state raises ValueError if the state payload is tampered."""
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
# Tamper: append garbage to the state string
tampered = state + "TAMPERED"
with pytest.raises(ValueError):
verify_oauth_state(state=tampered, secret=_SECRET)
def test_verify_oauth_state_wrong_secret() -> None:
"""verify_oauth_state raises ValueError if verified with the wrong secret."""
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
with pytest.raises(ValueError):
verify_oauth_state(state=state, secret="wrong-secret")
def test_generate_oauth_state_nonce_differs() -> None:
"""Two calls to generate_oauth_state produce different states (random nonce)."""
state1 = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
state2 = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
# Different nonce means different state tokens
assert state1 != state2
# Both must still verify correctly
assert verify_oauth_state(state=state1, secret=_SECRET) == _TENANT_ID
assert verify_oauth_state(state=state2, secret=_SECRET) == _TENANT_ID

View File

@@ -0,0 +1,177 @@
"""
Unit tests for Stripe webhook handler logic.
Tests:
- process_stripe_event idempotency: same event_id processed twice returns
"already_processed" on the second call
- customer.subscription.updated: updates tenant subscription_status
- customer.subscription.deleted: sets status=canceled and deactivates agents
"""
from __future__ import annotations
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from shared.api.billing import process_stripe_event
@pytest.fixture()
def mock_session() -> AsyncMock:
"""Mock AsyncSession that tracks execute calls."""
session = AsyncMock()
# Default: no existing StripeEvent found (not yet processed)
session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
return session
@pytest.fixture()
def tenant_id() -> str:
return str(uuid.uuid4())
@pytest.fixture()
def agent_id() -> str:
return str(uuid.uuid4())
class _MockTenant:
def __init__(self, tenant_id: str) -> None:
self.id = uuid.UUID(tenant_id)
self.stripe_customer_id: str | None = None
self.stripe_subscription_id: str | None = None
self.stripe_subscription_item_id: str | None = None
self.subscription_status: str = "trialing"
self.trial_ends_at = None
self.agent_quota: int = 0
class _MockAgent:
def __init__(self, agent_id: str, tenant_id: str) -> None:
self.id = uuid.UUID(agent_id)
self.tenant_id = uuid.UUID(tenant_id)
self.is_active: bool = True
@pytest.mark.asyncio
async def test_stripe_webhook_idempotency(mock_session: AsyncMock, tenant_id: str) -> None:
"""
Processing the same event_id twice returns 'already_processed' on second call.
The second call should detect the existing StripeEvent and skip processing.
"""
event_id = "evt_test_idempotent_001"
event_data = {
"id": event_id,
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_test123",
"status": "active",
"customer": "cus_test123",
"trial_end": None,
"items": {"data": [{"id": "si_test123", "quantity": 2}]},
}
},
}
# First call: no existing event in DB
first_session = AsyncMock()
first_session.execute.return_value = MagicMock(
scalar_one_or_none=MagicMock(return_value=None),
scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))),
)
mock_tenant = _MockTenant(tenant_id)
with patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_tenant
result1 = await process_stripe_event(event_data, first_session)
assert result1 != "already_processed"
# Second call: event already in DB (simulating idempotent duplicate)
existing_event = MagicMock() # Non-None means already processed
second_session = AsyncMock()
second_session.execute.return_value = MagicMock(
scalar_one_or_none=MagicMock(return_value=existing_event)
)
result2 = await process_stripe_event(event_data, second_session)
assert result2 == "already_processed"
@pytest.mark.asyncio
async def test_stripe_subscription_updated(mock_session: AsyncMock, tenant_id: str) -> None:
"""
customer.subscription.updated event updates tenant subscription_status.
"""
event_data = {
"id": "evt_test_sub_updated",
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_updated123",
"status": "active",
"customer": "cus_tenant123",
"trial_end": None,
"items": {"data": [{"id": "si_updated123", "quantity": 3}]},
}
},
}
mock_tenant = _MockTenant(tenant_id)
assert mock_tenant.subscription_status == "trialing"
with patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get:
mock_get.return_value = mock_tenant
result = await process_stripe_event(event_data, mock_session)
assert result != "already_processed"
assert mock_tenant.subscription_status == "active"
assert mock_tenant.stripe_subscription_id == "sub_updated123"
assert mock_tenant.agent_quota == 3
@pytest.mark.asyncio
async def test_stripe_cancellation(mock_session: AsyncMock, tenant_id: str, agent_id: str) -> None:
"""
customer.subscription.deleted sets status=canceled and deactivates all tenant agents.
"""
event_data = {
"id": "evt_test_canceled",
"type": "customer.subscription.deleted",
"data": {
"object": {
"id": "sub_canceled123",
"status": "canceled",
"customer": "cus_tenant_cancel",
"trial_end": None,
"items": {"data": []},
}
},
}
mock_tenant = _MockTenant(tenant_id)
mock_tenant.subscription_status = "active"
mock_agent = _MockAgent(agent_id, tenant_id)
assert mock_agent.is_active is True
with (
patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get_tenant,
patch("shared.api.billing._deactivate_all_agents", new_callable=AsyncMock) as mock_deactivate,
):
mock_get_tenant.return_value = mock_tenant
# Simulate deactivation side effect
async def _do_deactivate(session, tenant_id_arg): # type: ignore[no-untyped-def]
mock_agent.is_active = False
mock_deactivate.side_effect = _do_deactivate
result = await process_stripe_event(event_data, mock_session)
assert result != "already_processed"
assert mock_tenant.subscription_status == "canceled"
assert mock_agent.is_active is False
mock_deactivate.assert_called_once()

View File

@@ -0,0 +1,141 @@
"""
Unit tests for usage aggregation logic.
Tests:
- aggregate_usage_by_agent groups prompt_tokens, completion_tokens, cost_usd by agent_id
- aggregate_usage_by_provider groups cost_usd by provider name
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
import pytest
from shared.api.usage import (
_aggregate_rows_by_agent,
_aggregate_rows_by_provider,
)
# ---------------------------------------------------------------------------
# Sample audit event rows (simulating DB query results)
# ---------------------------------------------------------------------------
def _make_row(
agent_id: str,
provider: str,
prompt_tokens: int = 100,
completion_tokens: int = 50,
cost_usd: float = 0.01,
) -> dict:
return {
"agent_id": agent_id,
"provider": provider,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"cost_usd": cost_usd,
}
# ---------------------------------------------------------------------------
# Tests: per-agent aggregation
# ---------------------------------------------------------------------------
def test_usage_group_by_agent_single_agent() -> None:
"""Single agent with two calls aggregates tokens and cost correctly."""
agent_id = str(uuid.uuid4())
rows = [
_make_row(agent_id, "openai", prompt_tokens=100, completion_tokens=50, cost_usd=0.01),
_make_row(agent_id, "openai", prompt_tokens=200, completion_tokens=80, cost_usd=0.02),
]
result = _aggregate_rows_by_agent(rows)
assert len(result) == 1
assert result[0]["agent_id"] == agent_id
assert result[0]["prompt_tokens"] == 300
assert result[0]["completion_tokens"] == 130
assert result[0]["total_tokens"] == 430
assert abs(result[0]["cost_usd"] - 0.03) < 1e-9
assert result[0]["call_count"] == 2
def test_usage_group_by_agent_multiple_agents() -> None:
"""Multiple agents are aggregated separately."""
agent_a = str(uuid.uuid4())
agent_b = str(uuid.uuid4())
rows = [
_make_row(agent_a, "anthropic", prompt_tokens=100, completion_tokens=40, cost_usd=0.005),
_make_row(agent_b, "openai", prompt_tokens=500, completion_tokens=200, cost_usd=0.05),
_make_row(agent_a, "anthropic", prompt_tokens=50, completion_tokens=20, cost_usd=0.002),
]
result = _aggregate_rows_by_agent(rows)
by_id = {r["agent_id"]: r for r in result}
assert agent_a in by_id
assert agent_b in by_id
assert by_id[agent_a]["prompt_tokens"] == 150
assert by_id[agent_a]["call_count"] == 2
assert by_id[agent_b]["prompt_tokens"] == 500
assert by_id[agent_b]["call_count"] == 1
def test_usage_group_by_agent_empty_rows() -> None:
"""Empty input returns empty list."""
result = _aggregate_rows_by_agent([])
assert result == []
# ---------------------------------------------------------------------------
# Tests: per-provider aggregation
# ---------------------------------------------------------------------------
def test_usage_group_by_provider_single_provider() -> None:
"""Single provider aggregates cost correctly."""
agent_id = str(uuid.uuid4())
rows = [
_make_row(agent_id, "openai", cost_usd=0.01),
_make_row(agent_id, "openai", cost_usd=0.02),
]
result = _aggregate_rows_by_provider(rows)
assert len(result) == 1
assert result[0]["provider"] == "openai"
assert abs(result[0]["cost_usd"] - 0.03) < 1e-9
assert result[0]["call_count"] == 2
def test_usage_group_by_provider_multiple_providers() -> None:
"""Multiple providers are grouped independently."""
agent_id = str(uuid.uuid4())
rows = [
_make_row(agent_id, "openai", cost_usd=0.01),
_make_row(agent_id, "anthropic", cost_usd=0.02),
_make_row(agent_id, "openai", cost_usd=0.03),
_make_row(agent_id, "anthropic", cost_usd=0.04),
]
result = _aggregate_rows_by_provider(rows)
by_provider = {r["provider"]: r for r in result}
assert "openai" in by_provider
assert "anthropic" in by_provider
assert abs(by_provider["openai"]["cost_usd"] - 0.04) < 1e-9
assert by_provider["openai"]["call_count"] == 2
assert abs(by_provider["anthropic"]["cost_usd"] - 0.06) < 1e-9
assert by_provider["anthropic"]["call_count"] == 2
def test_usage_group_by_provider_empty_rows() -> None:
"""Empty input returns empty list."""
result = _aggregate_rows_by_provider([])
assert result == []