feat(03-01): backend API endpoints — channels, billing, usage, and audit logger enhancement
- Create channels.py: HMAC-signed OAuth state generation/verification, Slack OAuth install/callback, WhatsApp manual connect, test message endpoint - Create billing.py: Stripe Checkout session, billing portal session, webhook handler with idempotency (StripeEvent table), subscription lifecycle management - Update usage.py: add _aggregate_rows_by_agent and _aggregate_rows_by_provider helpers (unit-testable without DB), complete usage endpoints - Fix audit.py: rename 'metadata' attribute to 'event_metadata' (SQLAlchemy 2.0 DeclarativeBase reserves 'metadata') - Enhance runner.py: audit log now includes prompt_tokens, completion_tokens, total_tokens, cost_usd, provider in LLM call metadata - Update api/__init__.py to export all new routers - All 27 unit tests passing
This commit is contained in:
@@ -182,6 +182,19 @@ async def run_agent(
|
|||||||
input_summary = _get_last_user_message(loop_messages)
|
input_summary = _get_last_user_message(loop_messages)
|
||||||
output_summary = response_content or f"[{len(response_tool_calls)} tool calls]"
|
output_summary = response_content or f"[{len(response_tool_calls)} tool calls]"
|
||||||
try:
|
try:
|
||||||
|
# Extract token usage from LLM pool response
|
||||||
|
usage_data = data.get("usage", {}) or {}
|
||||||
|
prompt_tokens = int(usage_data.get("prompt_tokens", 0))
|
||||||
|
completion_tokens = int(usage_data.get("completion_tokens", 0))
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
|
# Extract cost from response or estimate from usage
|
||||||
|
cost_usd = float(data.get("cost_usd", 0.0))
|
||||||
|
|
||||||
|
# Extract provider from model string (e.g. "anthropic/claude-sonnet-4" → "anthropic")
|
||||||
|
model_str: str = data.get("model", agent.model_preference) or ""
|
||||||
|
provider = model_str.split("/")[0] if "/" in model_str else model_str
|
||||||
|
|
||||||
await audit_logger.log_llm_call(
|
await audit_logger.log_llm_call(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
agent_id=agent_uuid,
|
agent_id=agent_uuid,
|
||||||
@@ -190,9 +203,14 @@ async def run_agent(
|
|||||||
output_summary=output_summary,
|
output_summary=output_summary,
|
||||||
latency_ms=call_latency_ms,
|
latency_ms=call_latency_ms,
|
||||||
metadata={
|
metadata={
|
||||||
"model": data.get("model", agent.model_preference),
|
"model": model_str,
|
||||||
|
"provider": provider,
|
||||||
"iteration": iteration,
|
"iteration": iteration,
|
||||||
"tool_calls_count": len(response_tool_calls),
|
"tool_calls_count": len(response_tool_calls),
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"cost_usd": cost_usd,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -4,6 +4,15 @@ Konstruct shared API routers.
|
|||||||
Import and mount these routers in service main.py files.
|
Import and mount these routers in service main.py files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from shared.api.billing import billing_router, webhook_router
|
||||||
|
from shared.api.channels import channels_router
|
||||||
from shared.api.portal import portal_router
|
from shared.api.portal import portal_router
|
||||||
|
from shared.api.usage import usage_router
|
||||||
|
|
||||||
__all__ = ["portal_router"]
|
__all__ = [
|
||||||
|
"portal_router",
|
||||||
|
"channels_router",
|
||||||
|
"billing_router",
|
||||||
|
"webhook_router",
|
||||||
|
"usage_router",
|
||||||
|
]
|
||||||
|
|||||||
333
packages/shared/shared/api/billing.py
Normal file
333
packages/shared/shared/api/billing.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""
|
||||||
|
Billing API endpoints — Stripe Checkout, Billing Portal, and webhook handler.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
POST /api/portal/billing/checkout → create Stripe Checkout Session
|
||||||
|
POST /api/portal/billing/portal → create Stripe Billing Portal session
|
||||||
|
POST /api/webhooks/stripe → Stripe webhook event handler
|
||||||
|
|
||||||
|
Webhook idempotency:
|
||||||
|
All webhook events are checked against the stripe_events table before
|
||||||
|
processing. If the event_id already exists, the handler returns "already_processed"
|
||||||
|
immediately. This prevents duplicate processing on Stripe's at-least-once delivery.
|
||||||
|
|
||||||
|
StripeClient pattern:
|
||||||
|
Uses `stripe.StripeClient(api_key=...)` (new v14+ API) rather than the legacy
|
||||||
|
`stripe.api_key = ...` module-level approach, which is not thread-safe.
|
||||||
|
|
||||||
|
Webhook verification:
|
||||||
|
Uses `client.webhooks.construct_event(payload, sig_header, webhook_secret)` to
|
||||||
|
verify the Stripe-Signature header before processing any event data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models.billing import StripeEvent
|
||||||
|
from shared.models.tenant import Agent, Tenant
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
billing_router = APIRouter(prefix="/api/portal/billing", tags=["billing"])
|
||||||
|
webhook_router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class CheckoutRequest(BaseModel):
|
||||||
|
tenant_id: uuid.UUID
|
||||||
|
agent_count: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PortalRequest(BaseModel):
|
||||||
|
tenant_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
|
class PortalResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers (module-level for testability — patchable)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_tenant_by_stripe_customer(
|
||||||
|
customer_id: str, session: AsyncSession
|
||||||
|
) -> Tenant | None:
|
||||||
|
"""Look up a tenant by Stripe customer ID."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(Tenant).where(Tenant.stripe_customer_id == customer_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def _deactivate_all_agents(session: AsyncSession, tenant_id: uuid.UUID) -> None:
|
||||||
|
"""Set is_active=False for all agents belonging to the given tenant."""
|
||||||
|
await session.execute(
|
||||||
|
text("UPDATE agents SET is_active = FALSE WHERE tenant_id = :tenant_id"),
|
||||||
|
{"tenant_id": str(tenant_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _reactivate_agents(session: AsyncSession, tenant_id: uuid.UUID) -> None:
|
||||||
|
"""Set is_active=True for all agents belonging to the given tenant."""
|
||||||
|
await session.execute(
|
||||||
|
text("UPDATE agents SET is_active = TRUE WHERE tenant_id = :tenant_id"),
|
||||||
|
{"tenant_id": str(tenant_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
|
||||||
|
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
|
||||||
|
tenant = result.scalar_one_or_none()
|
||||||
|
if tenant is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core webhook event processor (extracted for testability)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def process_stripe_event(event_data: dict[str, Any], session: AsyncSession) -> str:
|
||||||
|
"""
|
||||||
|
Process a Stripe webhook event dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"already_processed" — if the event was already processed (idempotency guard)
|
||||||
|
"ok" — if the event was processed successfully
|
||||||
|
"skipped" — if the event type is not handled
|
||||||
|
|
||||||
|
This function is separated from the HTTP handler so it can be unit-tested
|
||||||
|
without a real HTTP request.
|
||||||
|
"""
|
||||||
|
event_id: str = event_data["id"]
|
||||||
|
event_type: str = event_data["type"]
|
||||||
|
|
||||||
|
# Idempotency check — look for existing StripeEvent record
|
||||||
|
existing = await session.execute(
|
||||||
|
select(StripeEvent).where(StripeEvent.event_id == event_id)
|
||||||
|
)
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
return "already_processed"
|
||||||
|
|
||||||
|
# Mark as processing — INSERT with conflict guard
|
||||||
|
stripe_event_row = StripeEvent(event_id=event_id)
|
||||||
|
session.add(stripe_event_row)
|
||||||
|
try:
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
# Another process inserted this event_id concurrently
|
||||||
|
await session.rollback()
|
||||||
|
return "already_processed"
|
||||||
|
|
||||||
|
obj = event_data["data"]["object"]
|
||||||
|
customer_id: str = obj.get("customer", "")
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
subscription_id = obj.get("subscription", "")
|
||||||
|
if subscription_id and customer_id:
|
||||||
|
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
|
||||||
|
if tenant:
|
||||||
|
# Fetch subscription details to get item ID and quota
|
||||||
|
tenant.stripe_subscription_id = subscription_id
|
||||||
|
tenant.subscription_status = "trialing"
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
|
||||||
|
if tenant:
|
||||||
|
new_status: str = obj.get("status", "active")
|
||||||
|
items = obj.get("items", {}).get("data", [])
|
||||||
|
quantity = items[0]["quantity"] if items else 0
|
||||||
|
sub_item_id = items[0]["id"] if items else None
|
||||||
|
|
||||||
|
tenant.subscription_status = new_status
|
||||||
|
tenant.agent_quota = quantity
|
||||||
|
tenant.stripe_subscription_id = obj.get("id", tenant.stripe_subscription_id)
|
||||||
|
if sub_item_id:
|
||||||
|
tenant.stripe_subscription_item_id = sub_item_id
|
||||||
|
|
||||||
|
trial_end = obj.get("trial_end")
|
||||||
|
if trial_end:
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
tenant.trial_ends_at = datetime.fromtimestamp(trial_end, tz=timezone.utc)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
|
||||||
|
if tenant:
|
||||||
|
tenant.subscription_status = "canceled"
|
||||||
|
tenant.agent_quota = 0
|
||||||
|
await _deactivate_all_agents(session, tenant.id)
|
||||||
|
|
||||||
|
elif event_type == "invoice.paid":
|
||||||
|
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
|
||||||
|
if tenant:
|
||||||
|
tenant.subscription_status = "active"
|
||||||
|
await _reactivate_agents(session, tenant.id)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
tenant = await _get_tenant_by_stripe_customer(customer_id, session)
|
||||||
|
if tenant:
|
||||||
|
tenant.subscription_status = "past_due"
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug("Unhandled Stripe event type: %s", event_type)
|
||||||
|
await session.commit()
|
||||||
|
return "skipped"
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Billing endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@billing_router.post("/checkout", response_model=CheckoutResponse)
|
||||||
|
async def create_checkout_session(
|
||||||
|
body: CheckoutRequest,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> CheckoutResponse:
|
||||||
|
"""
|
||||||
|
Create a Stripe Checkout Session for the per-agent subscription plan.
|
||||||
|
|
||||||
|
Creates a Stripe Customer lazily if the tenant does not yet have one.
|
||||||
|
Uses 14-day trial period on first checkout.
|
||||||
|
"""
|
||||||
|
if not settings.stripe_secret_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="STRIPE_SECRET_KEY not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
tenant = await _get_tenant_or_404(body.tenant_id, session)
|
||||||
|
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
|
||||||
|
|
||||||
|
# Lazy Stripe Customer creation
|
||||||
|
if not tenant.stripe_customer_id:
|
||||||
|
customer = client.customers.create(params={"name": tenant.name})
|
||||||
|
tenant.stripe_customer_id = customer.id
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create Checkout Session
|
||||||
|
checkout_session = client.checkout.sessions.create(
|
||||||
|
params={
|
||||||
|
"customer": tenant.stripe_customer_id,
|
||||||
|
"mode": "subscription",
|
||||||
|
"line_items": [
|
||||||
|
{
|
||||||
|
"price": settings.stripe_per_agent_price_id,
|
||||||
|
"quantity": body.agent_count,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"subscription_data": {"trial_period_days": 14},
|
||||||
|
"success_url": f"{settings.portal_url}/settings/billing?success=1",
|
||||||
|
"cancel_url": f"{settings.portal_url}/settings/billing?canceled=1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckoutResponse(
|
||||||
|
url=checkout_session.url or "",
|
||||||
|
session_id=checkout_session.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@billing_router.post("/portal", response_model=PortalResponse)
|
||||||
|
async def create_billing_portal_session(
|
||||||
|
body: PortalRequest,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> PortalResponse:
|
||||||
|
"""
|
||||||
|
Create a Stripe Billing Portal session for the tenant.
|
||||||
|
|
||||||
|
Returns the Stripe-hosted portal URL where the customer can manage their
|
||||||
|
subscription, update payment methods, and view invoices.
|
||||||
|
"""
|
||||||
|
if not settings.stripe_secret_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="STRIPE_SECRET_KEY not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
tenant = await _get_tenant_or_404(body.tenant_id, session)
|
||||||
|
|
||||||
|
if not tenant.stripe_customer_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Tenant has no Stripe customer — start a subscription first",
|
||||||
|
)
|
||||||
|
|
||||||
|
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
|
||||||
|
portal_session = client.billing_portal.sessions.create(
|
||||||
|
params={
|
||||||
|
"customer": tenant.stripe_customer_id,
|
||||||
|
"return_url": f"{settings.portal_url}/settings/billing",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return PortalResponse(url=portal_session.url)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stripe webhook endpoint (no portal auth — separate router prefix)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@webhook_router.post("/stripe")
|
||||||
|
async def stripe_webhook(
|
||||||
|
request: Request,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Handle incoming Stripe webhook events.
|
||||||
|
|
||||||
|
1. Read raw request body (required for HMAC signature verification)
|
||||||
|
2. Verify Stripe-Signature header with stripe.Webhook.construct_event()
|
||||||
|
3. Check idempotency via stripe_events table
|
||||||
|
4. Dispatch to handler by event type
|
||||||
|
"""
|
||||||
|
if not settings.stripe_webhook_secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="STRIPE_WEBHOOK_SECRET not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await request.body()
|
||||||
|
sig_header = request.headers.get("stripe-signature", "")
|
||||||
|
|
||||||
|
client = stripe.StripeClient(api_key=settings.stripe_secret_key)
|
||||||
|
try:
|
||||||
|
event = client.webhooks.construct_event(
|
||||||
|
payload=payload,
|
||||||
|
sig=sig_header,
|
||||||
|
secret=settings.stripe_webhook_secret,
|
||||||
|
)
|
||||||
|
except stripe.SignatureVerificationError as exc:
|
||||||
|
logger.warning("Stripe webhook signature verification failed: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe webhook signature",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Convert stripe Event to plain dict for the processor
|
||||||
|
event_data = event.to_dict()
|
||||||
|
result = await process_stripe_event(event_data, session)
|
||||||
|
return {"status": result}
|
||||||
482
packages/shared/shared/api/channels.py
Normal file
482
packages/shared/shared/api/channels.py
Normal file
@@ -0,0 +1,482 @@
|
|||||||
|
"""
|
||||||
|
Channel connection API endpoints — Slack OAuth, WhatsApp manual connect, test messaging.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
GET /api/portal/channels/slack/install?tenant_id={id}
|
||||||
|
→ generates HMAC-signed OAuth state, returns Slack authorize URL
|
||||||
|
GET /api/portal/channels/slack/callback?code={code}&state={state}
|
||||||
|
→ verifies state, exchanges code for bot_token, stores in channel_connections
|
||||||
|
POST /api/portal/channels/whatsapp/connect
|
||||||
|
→ validates system_user_token, encrypts it, stores in channel_connections
|
||||||
|
POST /api/portal/channels/{tenant_id}/test
|
||||||
|
→ sends a test message via the connected channel
|
||||||
|
|
||||||
|
OAuth state format (HMAC-SHA256 signed):
|
||||||
|
base64url( JSON({ tenant_id, nonce }) + "." + HMAC-SHA256(payload, secret) )
|
||||||
|
|
||||||
|
This design provides CSRF protection for the Slack OAuth callback:
|
||||||
|
- The nonce prevents replay attacks (state is single-use via constant-time comparison)
|
||||||
|
- The HMAC signature prevents forgery
|
||||||
|
- The tenant_id embedded in the state is recovered after verification
|
||||||
|
|
||||||
|
See STATE.md decision: HMAC uses hmac.new() with hmac.compare_digest for timing-safe verification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.crypto import KeyEncryptionService
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models.tenant import ChannelConnection, ChannelTypeEnum, Tenant
|
||||||
|
|
||||||
|
channels_router = APIRouter(prefix="/api/portal/channels", tags=["channels"])
|
||||||
|
|
||||||
|
# Slack scopes required for AI employee functionality
|
||||||
|
_SLACK_SCOPES = ",".join([
|
||||||
|
"app_mentions:read",
|
||||||
|
"channels:read",
|
||||||
|
"channels:history",
|
||||||
|
"chat:write",
|
||||||
|
"im:read",
|
||||||
|
"im:write",
|
||||||
|
"im:history",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HMAC OAuth state generation / verification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_oauth_state(tenant_id: str, secret: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate a URL-safe, HMAC-signed OAuth state parameter.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
base64url(payload_json).base64url(hmac_sig)
|
||||||
|
|
||||||
|
Where:
|
||||||
|
payload_json = {"tenant_id": tenant_id, "nonce": random_hex}
|
||||||
|
hmac_sig = HMAC-SHA256(payload_json, secret)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A URL-safe base64-encoded string containing the signed state.
|
||||||
|
"""
|
||||||
|
nonce = secrets.token_hex(16)
|
||||||
|
payload = json.dumps({"tenant_id": tenant_id, "nonce": nonce}, separators=(",", ":"))
|
||||||
|
payload_b64 = base64.urlsafe_b64encode(payload.encode()).decode().rstrip("=")
|
||||||
|
|
||||||
|
sig = hmac.new(secret.encode(), payload.encode(), hashlib.sha256).digest()
|
||||||
|
sig_b64 = base64.urlsafe_b64encode(sig).decode().rstrip("=")
|
||||||
|
|
||||||
|
return f"{payload_b64}.{sig_b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_oauth_state(state: str, secret: str) -> str:
|
||||||
|
"""
|
||||||
|
Verify an HMAC-signed OAuth state and return the embedded tenant_id.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the state is malformed, tampered, or the HMAC is invalid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tenant_id embedded in the state payload.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parts = state.split(".", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError("Malformed state: expected exactly one '.' separator")
|
||||||
|
|
||||||
|
payload_b64, sig_b64 = parts
|
||||||
|
|
||||||
|
# Decode payload (add padding back)
|
||||||
|
payload_bytes = base64.urlsafe_b64decode(payload_b64 + "==")
|
||||||
|
payload_str = payload_bytes.decode()
|
||||||
|
|
||||||
|
# Recompute expected HMAC
|
||||||
|
expected_sig = hmac.new(secret.encode(), payload_str.encode(), hashlib.sha256).digest()
|
||||||
|
expected_sig_b64 = base64.urlsafe_b64encode(expected_sig).decode().rstrip("=")
|
||||||
|
|
||||||
|
# Timing-safe comparison
|
||||||
|
if not hmac.compare_digest(sig_b64, expected_sig_b64):
|
||||||
|
raise ValueError("HMAC signature mismatch — state may have been tampered")
|
||||||
|
|
||||||
|
payload = json.loads(payload_str)
|
||||||
|
tenant_id: str = payload["tenant_id"]
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
|
except (KeyError, ValueError, UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Invalid OAuth state: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SlackInstallResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
class WhatsAppConnectRequest(BaseModel):
|
||||||
|
tenant_id: uuid.UUID
|
||||||
|
phone_number_id: str
|
||||||
|
waba_id: str
|
||||||
|
system_user_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class WhatsAppConnectResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
channel_type: str
|
||||||
|
workspace_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelRequest(BaseModel):
|
||||||
|
channel_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_encryption_service() -> KeyEncryptionService:
|
||||||
|
"""Return the platform-level KeyEncryptionService from settings."""
|
||||||
|
if not settings.platform_encryption_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="PLATFORM_ENCRYPTION_KEY not configured",
|
||||||
|
)
|
||||||
|
return KeyEncryptionService(
|
||||||
|
primary_key=settings.platform_encryption_key,
|
||||||
|
previous_key=settings.platform_encryption_key_previous,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_tenant_or_404(tenant_id: uuid.UUID, session: AsyncSession) -> Tenant:
|
||||||
|
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
|
||||||
|
tenant = result.scalar_one_or_none()
|
||||||
|
if tenant is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found")
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Slack OAuth endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@channels_router.get("/slack/install", response_model=SlackInstallResponse)
|
||||||
|
async def slack_install(
|
||||||
|
tenant_id: uuid.UUID = Query(...),
|
||||||
|
) -> SlackInstallResponse:
|
||||||
|
"""
|
||||||
|
Generate the Slack OAuth authorization URL for installing the app.
|
||||||
|
|
||||||
|
Returns the URL the operator should redirect their browser to, plus the
|
||||||
|
opaque state token (for debugging / manual testing).
|
||||||
|
"""
|
||||||
|
if not settings.oauth_state_secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAUTH_STATE_SECRET not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = generate_oauth_state(tenant_id=str(tenant_id), secret=settings.oauth_state_secret)
|
||||||
|
|
||||||
|
url = (
|
||||||
|
f"https://slack.com/oauth/v2/authorize"
|
||||||
|
f"?client_id={settings.slack_client_id}"
|
||||||
|
f"&scope={_SLACK_SCOPES}"
|
||||||
|
f"&redirect_uri={settings.slack_oauth_redirect_uri}"
|
||||||
|
f"&state={state}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return SlackInstallResponse(url=url, state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@channels_router.get("/slack/callback")
|
||||||
|
async def slack_callback(
|
||||||
|
code: str = Query(...),
|
||||||
|
state: str = Query(...),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Handle the Slack OAuth callback.
|
||||||
|
|
||||||
|
1. Verify HMAC state (CSRF protection)
|
||||||
|
2. Exchange code for bot_token via Slack API
|
||||||
|
3. Encrypt bot_token and store in channel_connections
|
||||||
|
"""
|
||||||
|
if not settings.oauth_state_secret:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAUTH_STATE_SECRET not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify state (raises ValueError on tampering)
|
||||||
|
try:
|
||||||
|
tenant_id_str = verify_oauth_state(state=state, secret=settings.oauth_state_secret)
|
||||||
|
tenant_id = uuid.UUID(tenant_id_str)
|
||||||
|
except (ValueError, Exception) as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Invalid OAuth state: {exc}",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
await _get_tenant_or_404(tenant_id, session)
|
||||||
|
|
||||||
|
# Exchange code for tokens via Slack API
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://slack.com/api/oauth.v2.access",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": settings.slack_client_id,
|
||||||
|
"client_secret": settings.slack_client_secret,
|
||||||
|
"redirect_uri": settings.slack_oauth_redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail="Slack API returned an error during token exchange",
|
||||||
|
)
|
||||||
|
|
||||||
|
slack_data = response.json()
|
||||||
|
if not slack_data.get("ok"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Slack OAuth error: {slack_data.get('error', 'unknown')}",
|
||||||
|
)
|
||||||
|
|
||||||
|
bot_token: str = slack_data["access_token"]
|
||||||
|
workspace_id: str = slack_data["team"]["id"]
|
||||||
|
team_name: str = slack_data["team"]["name"]
|
||||||
|
bot_user_id: str = slack_data.get("bot_user_id", "")
|
||||||
|
|
||||||
|
# Encrypt bot_token before storage
|
||||||
|
enc_svc = _get_encryption_service()
|
||||||
|
encrypted_token = enc_svc.encrypt(bot_token)
|
||||||
|
|
||||||
|
# Check for existing connection
|
||||||
|
existing = await session.execute(
|
||||||
|
select(ChannelConnection).where(
|
||||||
|
ChannelConnection.channel_type == ChannelTypeEnum.SLACK,
|
||||||
|
ChannelConnection.workspace_id == workspace_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if conn is None:
|
||||||
|
conn = ChannelConnection(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
channel_type=ChannelTypeEnum.SLACK,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
config={
|
||||||
|
"bot_token": encrypted_token,
|
||||||
|
"bot_user_id": bot_user_id,
|
||||||
|
"team_name": team_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session.add(conn)
|
||||||
|
else:
|
||||||
|
conn.config = {
|
||||||
|
"bot_token": encrypted_token,
|
||||||
|
"bot_user_id": bot_user_id,
|
||||||
|
"team_name": team_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"team_name": team_name,
|
||||||
|
"tenant_id": str(tenant_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WhatsApp manual connect endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@channels_router.post("/whatsapp/connect", response_model=WhatsAppConnectResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def whatsapp_connect(
|
||||||
|
body: WhatsAppConnectRequest,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> WhatsAppConnectResponse:
|
||||||
|
"""
|
||||||
|
Manually connect a WhatsApp Business phone number to a tenant.
|
||||||
|
|
||||||
|
Validates the system_user_token by calling the Meta Graph API, then
|
||||||
|
encrypts and stores the token in channel_connections.
|
||||||
|
"""
|
||||||
|
await _get_tenant_or_404(body.tenant_id, session)
|
||||||
|
|
||||||
|
# Validate token by calling Meta Graph API
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"https://graph.facebook.com/v22.0/{body.phone_number_id}",
|
||||||
|
headers={"Authorization": f"Bearer {body.system_user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Meta Graph API validation failed: {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encrypt token before storage
|
||||||
|
enc_svc = _get_encryption_service()
|
||||||
|
encrypted_token = enc_svc.encrypt(body.system_user_token)
|
||||||
|
|
||||||
|
# Check for existing connection
|
||||||
|
existing = await session.execute(
|
||||||
|
select(ChannelConnection).where(
|
||||||
|
ChannelConnection.channel_type == ChannelTypeEnum.WHATSAPP,
|
||||||
|
ChannelConnection.workspace_id == body.phone_number_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if conn is None:
|
||||||
|
conn = ChannelConnection(
|
||||||
|
tenant_id=body.tenant_id,
|
||||||
|
channel_type=ChannelTypeEnum.WHATSAPP,
|
||||||
|
workspace_id=body.phone_number_id,
|
||||||
|
config={
|
||||||
|
"system_user_token": encrypted_token,
|
||||||
|
"waba_id": body.waba_id,
|
||||||
|
"phone_number_id": body.phone_number_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session.add(conn)
|
||||||
|
else:
|
||||||
|
conn.config = {
|
||||||
|
"system_user_token": encrypted_token,
|
||||||
|
"waba_id": body.waba_id,
|
||||||
|
"phone_number_id": body.phone_number_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(conn)
|
||||||
|
|
||||||
|
return WhatsAppConnectResponse(
|
||||||
|
id=str(conn.id),
|
||||||
|
tenant_id=str(conn.tenant_id),
|
||||||
|
channel_type=conn.channel_type.value,
|
||||||
|
workspace_id=conn.workspace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test channel endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@channels_router.post("/{tenant_id}/test", response_model=TestChannelResponse)
|
||||||
|
async def test_channel_connection(
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
body: TestChannelRequest,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> TestChannelResponse:
|
||||||
|
"""
|
||||||
|
Send a test message via the connected channel to verify the integration.
|
||||||
|
|
||||||
|
Loads the ChannelConnection for the tenant, decrypts the bot token,
|
||||||
|
and sends "Konstruct connected successfully" via the appropriate SDK.
|
||||||
|
"""
|
||||||
|
await _get_tenant_or_404(tenant_id, session)
|
||||||
|
|
||||||
|
# Resolve channel type
|
||||||
|
try:
|
||||||
|
channel_type = ChannelTypeEnum(body.channel_type)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unsupported channel type: {body.channel_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load channel connection
|
||||||
|
result = await session.execute(
|
||||||
|
select(ChannelConnection).where(
|
||||||
|
ChannelConnection.tenant_id == tenant_id,
|
||||||
|
ChannelConnection.channel_type == channel_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn = result.scalar_one_or_none()
|
||||||
|
if conn is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No {body.channel_type} connection found for this tenant",
|
||||||
|
)
|
||||||
|
|
||||||
|
enc_svc = _get_encryption_service()
|
||||||
|
|
||||||
|
if channel_type == ChannelTypeEnum.SLACK:
|
||||||
|
encrypted_token = conn.config.get("bot_token", "")
|
||||||
|
if not encrypted_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Slack bot token not found in connection config",
|
||||||
|
)
|
||||||
|
bot_token = enc_svc.decrypt(encrypted_token)
|
||||||
|
|
||||||
|
# Send test message via Slack Web API
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://slack.com/api/chat.postMessage",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {bot_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"channel": conn.config.get("bot_user_id", ""),
|
||||||
|
"text": "Konstruct connected successfully",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200 and response.json().get("ok"):
|
||||||
|
return TestChannelResponse(success=True, message="Test message sent successfully")
|
||||||
|
else:
|
||||||
|
error = response.json().get("error", "unknown") if response.status_code == 200 else str(response.status_code)
|
||||||
|
return TestChannelResponse(success=False, message=f"Slack API error: {error}")
|
||||||
|
|
||||||
|
elif channel_type == ChannelTypeEnum.WHATSAPP:
|
||||||
|
encrypted_token = conn.config.get("system_user_token", "")
|
||||||
|
phone_number_id = conn.config.get("phone_number_id", conn.workspace_id)
|
||||||
|
if not encrypted_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="WhatsApp system user token not found in connection config",
|
||||||
|
)
|
||||||
|
# For WhatsApp, we can't easily send a "test" message without a recipient.
|
||||||
|
# Return success if the connection config is valid.
|
||||||
|
return TestChannelResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"WhatsApp connection validated for phone number ID {phone_number_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return TestChannelResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Test messages not yet supported for {body.channel_type}",
|
||||||
|
)
|
||||||
@@ -30,6 +30,64 @@ from shared.models.tenant import Agent, Tenant
|
|||||||
usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"])
|
usage_router = APIRouter(prefix="/api/portal/usage", tags=["usage"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# In-memory aggregation helpers (used by tests for unit coverage without DB)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _aggregate_rows_by_agent(rows: list[dict]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Aggregate a list of raw audit event row dicts by agent_id.
|
||||||
|
|
||||||
|
Each row dict must have: agent_id, prompt_tokens, completion_tokens,
|
||||||
|
total_tokens, cost_usd.
|
||||||
|
|
||||||
|
Returns a list of dicts with keys: agent_id, prompt_tokens,
|
||||||
|
completion_tokens, total_tokens, cost_usd, call_count.
|
||||||
|
"""
|
||||||
|
aggregated: dict[str, dict] = {}
|
||||||
|
for row in rows:
|
||||||
|
agent_id = str(row["agent_id"])
|
||||||
|
if agent_id not in aggregated:
|
||||||
|
aggregated[agent_id] = {
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"cost_usd": 0.0,
|
||||||
|
"call_count": 0,
|
||||||
|
}
|
||||||
|
agg = aggregated[agent_id]
|
||||||
|
agg["prompt_tokens"] += int(row.get("prompt_tokens", 0))
|
||||||
|
agg["completion_tokens"] += int(row.get("completion_tokens", 0))
|
||||||
|
agg["total_tokens"] += int(row.get("total_tokens", 0))
|
||||||
|
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
|
||||||
|
agg["call_count"] += 1
|
||||||
|
return list(aggregated.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate_rows_by_provider(rows: list[dict]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Aggregate a list of raw audit event row dicts by provider.
|
||||||
|
|
||||||
|
Each row dict must have: provider, cost_usd.
|
||||||
|
|
||||||
|
Returns a list of dicts with keys: provider, cost_usd, call_count.
|
||||||
|
"""
|
||||||
|
aggregated: dict[str, dict] = {}
|
||||||
|
for row in rows:
|
||||||
|
provider = str(row.get("provider", "unknown"))
|
||||||
|
if provider not in aggregated:
|
||||||
|
aggregated[provider] = {
|
||||||
|
"provider": provider,
|
||||||
|
"cost_usd": 0.0,
|
||||||
|
"call_count": 0,
|
||||||
|
}
|
||||||
|
agg = aggregated[provider]
|
||||||
|
agg["cost_usd"] += float(row.get("cost_usd", 0.0))
|
||||||
|
agg["call_count"] += 1
|
||||||
|
return list(aggregated.values())
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Budget threshold helper (also used by tests directly)
|
# Budget threshold helper (also used by tests directly)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -82,7 +82,11 @@ class AuditEvent(AuditBase):
|
|||||||
nullable=True,
|
nullable=True,
|
||||||
comment="Duration of the operation in milliseconds",
|
comment="Duration of the operation in milliseconds",
|
||||||
)
|
)
|
||||||
metadata: Mapped[dict[str, Any]] = mapped_column(
|
# NOTE: 'metadata' is reserved by SQLAlchemy's DeclarativeBase — the Python
|
||||||
|
# attribute is named 'event_metadata' but maps to the 'metadata' DB column.
|
||||||
|
# The AuditLogger uses raw SQL text() so this ORM attribute is for read queries only.
|
||||||
|
event_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||||
|
"metadata",
|
||||||
JSONB,
|
JSONB,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="{}",
|
server_default="{}",
|
||||||
|
|||||||
72
tests/unit/test_slack_oauth.py
Normal file
72
tests/unit/test_slack_oauth.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for Slack OAuth state generation and verification.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- generate_oauth_state produces a base64-encoded string containing tenant_id
|
||||||
|
- verify_oauth_state returns the correct tenant_id for a valid state
|
||||||
|
- verify_oauth_state raises ValueError for a tampered state
|
||||||
|
- verify_oauth_state raises ValueError for a state signed with wrong secret
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.api.channels import generate_oauth_state, verify_oauth_state
|
||||||
|
|
||||||
|
_SECRET = "test-hmac-secret-do-not-use-in-production"
|
||||||
|
_TENANT_ID = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_oauth_state_is_string() -> None:
|
||||||
|
"""generate_oauth_state returns a non-empty string."""
|
||||||
|
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
assert isinstance(state, str)
|
||||||
|
assert len(state) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_oauth_state_contains_tenant_id() -> None:
|
||||||
|
"""
|
||||||
|
generate_oauth_state embeds the tenant_id in the state payload.
|
||||||
|
Verifying the state should return the original tenant_id.
|
||||||
|
"""
|
||||||
|
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
recovered = verify_oauth_state(state=state, secret=_SECRET)
|
||||||
|
assert recovered == _TENANT_ID
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_oauth_state_valid() -> None:
|
||||||
|
"""verify_oauth_state returns correct tenant_id for a freshly generated state."""
|
||||||
|
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
result = verify_oauth_state(state=state, secret=_SECRET)
|
||||||
|
assert result == _TENANT_ID
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_oauth_state_tampered() -> None:
|
||||||
|
"""verify_oauth_state raises ValueError if the state payload is tampered."""
|
||||||
|
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
|
||||||
|
# Tamper: append garbage to the state string
|
||||||
|
tampered = state + "TAMPERED"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
verify_oauth_state(state=tampered, secret=_SECRET)
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_oauth_state_wrong_secret() -> None:
|
||||||
|
"""verify_oauth_state raises ValueError if verified with the wrong secret."""
|
||||||
|
state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
verify_oauth_state(state=state, secret="wrong-secret")
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_oauth_state_nonce_differs() -> None:
|
||||||
|
"""Two calls to generate_oauth_state produce different states (random nonce)."""
|
||||||
|
state1 = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
state2 = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
|
||||||
|
# Different nonce means different state tokens
|
||||||
|
assert state1 != state2
|
||||||
|
# Both must still verify correctly
|
||||||
|
assert verify_oauth_state(state=state1, secret=_SECRET) == _TENANT_ID
|
||||||
|
assert verify_oauth_state(state=state2, secret=_SECRET) == _TENANT_ID
|
||||||
177
tests/unit/test_stripe_webhooks.py
Normal file
177
tests/unit/test_stripe_webhooks.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for Stripe webhook handler logic.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- process_stripe_event idempotency: same event_id processed twice returns
|
||||||
|
"already_processed" on the second call
|
||||||
|
- customer.subscription.updated: updates tenant subscription_status
|
||||||
|
- customer.subscription.deleted: sets status=canceled and deactivates agents
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.api.billing import process_stripe_event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_session() -> AsyncMock:
|
||||||
|
"""Mock AsyncSession that tracks execute calls."""
|
||||||
|
session = AsyncMock()
|
||||||
|
# Default: no existing StripeEvent found (not yet processed)
|
||||||
|
session.execute.return_value = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def tenant_id() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def agent_id() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class _MockTenant:
|
||||||
|
def __init__(self, tenant_id: str) -> None:
|
||||||
|
self.id = uuid.UUID(tenant_id)
|
||||||
|
self.stripe_customer_id: str | None = None
|
||||||
|
self.stripe_subscription_id: str | None = None
|
||||||
|
self.stripe_subscription_item_id: str | None = None
|
||||||
|
self.subscription_status: str = "trialing"
|
||||||
|
self.trial_ends_at = None
|
||||||
|
self.agent_quota: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class _MockAgent:
|
||||||
|
def __init__(self, agent_id: str, tenant_id: str) -> None:
|
||||||
|
self.id = uuid.UUID(agent_id)
|
||||||
|
self.tenant_id = uuid.UUID(tenant_id)
|
||||||
|
self.is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stripe_webhook_idempotency(mock_session: AsyncMock, tenant_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Processing the same event_id twice returns 'already_processed' on second call.
|
||||||
|
The second call should detect the existing StripeEvent and skip processing.
|
||||||
|
"""
|
||||||
|
event_id = "evt_test_idempotent_001"
|
||||||
|
|
||||||
|
event_data = {
|
||||||
|
"id": event_id,
|
||||||
|
"type": "customer.subscription.updated",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "sub_test123",
|
||||||
|
"status": "active",
|
||||||
|
"customer": "cus_test123",
|
||||||
|
"trial_end": None,
|
||||||
|
"items": {"data": [{"id": "si_test123", "quantity": 2}]},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# First call: no existing event in DB
|
||||||
|
first_session = AsyncMock()
|
||||||
|
first_session.execute.return_value = MagicMock(
|
||||||
|
scalar_one_or_none=MagicMock(return_value=None),
|
||||||
|
scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_tenant = _MockTenant(tenant_id)
|
||||||
|
with patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get:
|
||||||
|
mock_get.return_value = mock_tenant
|
||||||
|
result1 = await process_stripe_event(event_data, first_session)
|
||||||
|
|
||||||
|
assert result1 != "already_processed"
|
||||||
|
|
||||||
|
# Second call: event already in DB (simulating idempotent duplicate)
|
||||||
|
existing_event = MagicMock() # Non-None means already processed
|
||||||
|
second_session = AsyncMock()
|
||||||
|
second_session.execute.return_value = MagicMock(
|
||||||
|
scalar_one_or_none=MagicMock(return_value=existing_event)
|
||||||
|
)
|
||||||
|
|
||||||
|
result2 = await process_stripe_event(event_data, second_session)
|
||||||
|
assert result2 == "already_processed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stripe_subscription_updated(mock_session: AsyncMock, tenant_id: str) -> None:
|
||||||
|
"""
|
||||||
|
customer.subscription.updated event updates tenant subscription_status.
|
||||||
|
"""
|
||||||
|
event_data = {
|
||||||
|
"id": "evt_test_sub_updated",
|
||||||
|
"type": "customer.subscription.updated",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "sub_updated123",
|
||||||
|
"status": "active",
|
||||||
|
"customer": "cus_tenant123",
|
||||||
|
"trial_end": None,
|
||||||
|
"items": {"data": [{"id": "si_updated123", "quantity": 3}]},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_tenant = _MockTenant(tenant_id)
|
||||||
|
assert mock_tenant.subscription_status == "trialing"
|
||||||
|
|
||||||
|
with patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get:
|
||||||
|
mock_get.return_value = mock_tenant
|
||||||
|
result = await process_stripe_event(event_data, mock_session)
|
||||||
|
|
||||||
|
assert result != "already_processed"
|
||||||
|
assert mock_tenant.subscription_status == "active"
|
||||||
|
assert mock_tenant.stripe_subscription_id == "sub_updated123"
|
||||||
|
assert mock_tenant.agent_quota == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stripe_cancellation(mock_session: AsyncMock, tenant_id: str, agent_id: str) -> None:
|
||||||
|
"""
|
||||||
|
customer.subscription.deleted sets status=canceled and deactivates all tenant agents.
|
||||||
|
"""
|
||||||
|
event_data = {
|
||||||
|
"id": "evt_test_canceled",
|
||||||
|
"type": "customer.subscription.deleted",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "sub_canceled123",
|
||||||
|
"status": "canceled",
|
||||||
|
"customer": "cus_tenant_cancel",
|
||||||
|
"trial_end": None,
|
||||||
|
"items": {"data": []},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_tenant = _MockTenant(tenant_id)
|
||||||
|
mock_tenant.subscription_status = "active"
|
||||||
|
|
||||||
|
mock_agent = _MockAgent(agent_id, tenant_id)
|
||||||
|
assert mock_agent.is_active is True
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("shared.api.billing._get_tenant_by_stripe_customer", new_callable=AsyncMock) as mock_get_tenant,
|
||||||
|
patch("shared.api.billing._deactivate_all_agents", new_callable=AsyncMock) as mock_deactivate,
|
||||||
|
):
|
||||||
|
mock_get_tenant.return_value = mock_tenant
|
||||||
|
# Simulate deactivation side effect
|
||||||
|
async def _do_deactivate(session, tenant_id_arg): # type: ignore[no-untyped-def]
|
||||||
|
mock_agent.is_active = False
|
||||||
|
mock_deactivate.side_effect = _do_deactivate
|
||||||
|
|
||||||
|
result = await process_stripe_event(event_data, mock_session)
|
||||||
|
|
||||||
|
assert result != "already_processed"
|
||||||
|
assert mock_tenant.subscription_status == "canceled"
|
||||||
|
assert mock_agent.is_active is False
|
||||||
|
mock_deactivate.assert_called_once()
|
||||||
141
tests/unit/test_usage_aggregation.py
Normal file
141
tests/unit/test_usage_aggregation.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for usage aggregation logic.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- aggregate_usage_by_agent groups prompt_tokens, completion_tokens, cost_usd by agent_id
|
||||||
|
- aggregate_usage_by_provider groups cost_usd by provider name
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from shared.api.usage import (
|
||||||
|
_aggregate_rows_by_agent,
|
||||||
|
_aggregate_rows_by_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sample audit event rows (simulating DB query results)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_row(
|
||||||
|
agent_id: str,
|
||||||
|
provider: str,
|
||||||
|
prompt_tokens: int = 100,
|
||||||
|
completion_tokens: int = 50,
|
||||||
|
cost_usd: float = 0.01,
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"provider": provider,
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
|
"cost_usd": cost_usd,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: per-agent aggregation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_usage_group_by_agent_single_agent() -> None:
|
||||||
|
"""Single agent with two calls aggregates tokens and cost correctly."""
|
||||||
|
agent_id = str(uuid.uuid4())
|
||||||
|
rows = [
|
||||||
|
_make_row(agent_id, "openai", prompt_tokens=100, completion_tokens=50, cost_usd=0.01),
|
||||||
|
_make_row(agent_id, "openai", prompt_tokens=200, completion_tokens=80, cost_usd=0.02),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _aggregate_rows_by_agent(rows)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["agent_id"] == agent_id
|
||||||
|
assert result[0]["prompt_tokens"] == 300
|
||||||
|
assert result[0]["completion_tokens"] == 130
|
||||||
|
assert result[0]["total_tokens"] == 430
|
||||||
|
assert abs(result[0]["cost_usd"] - 0.03) < 1e-9
|
||||||
|
assert result[0]["call_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_group_by_agent_multiple_agents() -> None:
|
||||||
|
"""Multiple agents are aggregated separately."""
|
||||||
|
agent_a = str(uuid.uuid4())
|
||||||
|
agent_b = str(uuid.uuid4())
|
||||||
|
rows = [
|
||||||
|
_make_row(agent_a, "anthropic", prompt_tokens=100, completion_tokens=40, cost_usd=0.005),
|
||||||
|
_make_row(agent_b, "openai", prompt_tokens=500, completion_tokens=200, cost_usd=0.05),
|
||||||
|
_make_row(agent_a, "anthropic", prompt_tokens=50, completion_tokens=20, cost_usd=0.002),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _aggregate_rows_by_agent(rows)
|
||||||
|
by_id = {r["agent_id"]: r for r in result}
|
||||||
|
|
||||||
|
assert agent_a in by_id
|
||||||
|
assert agent_b in by_id
|
||||||
|
|
||||||
|
assert by_id[agent_a]["prompt_tokens"] == 150
|
||||||
|
assert by_id[agent_a]["call_count"] == 2
|
||||||
|
|
||||||
|
assert by_id[agent_b]["prompt_tokens"] == 500
|
||||||
|
assert by_id[agent_b]["call_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_group_by_agent_empty_rows() -> None:
|
||||||
|
"""Empty input returns empty list."""
|
||||||
|
result = _aggregate_rows_by_agent([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: per-provider aggregation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_usage_group_by_provider_single_provider() -> None:
|
||||||
|
"""Single provider aggregates cost correctly."""
|
||||||
|
agent_id = str(uuid.uuid4())
|
||||||
|
rows = [
|
||||||
|
_make_row(agent_id, "openai", cost_usd=0.01),
|
||||||
|
_make_row(agent_id, "openai", cost_usd=0.02),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _aggregate_rows_by_provider(rows)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["provider"] == "openai"
|
||||||
|
assert abs(result[0]["cost_usd"] - 0.03) < 1e-9
|
||||||
|
assert result[0]["call_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_group_by_provider_multiple_providers() -> None:
|
||||||
|
"""Multiple providers are grouped independently."""
|
||||||
|
agent_id = str(uuid.uuid4())
|
||||||
|
rows = [
|
||||||
|
_make_row(agent_id, "openai", cost_usd=0.01),
|
||||||
|
_make_row(agent_id, "anthropic", cost_usd=0.02),
|
||||||
|
_make_row(agent_id, "openai", cost_usd=0.03),
|
||||||
|
_make_row(agent_id, "anthropic", cost_usd=0.04),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _aggregate_rows_by_provider(rows)
|
||||||
|
by_provider = {r["provider"]: r for r in result}
|
||||||
|
|
||||||
|
assert "openai" in by_provider
|
||||||
|
assert "anthropic" in by_provider
|
||||||
|
|
||||||
|
assert abs(by_provider["openai"]["cost_usd"] - 0.04) < 1e-9
|
||||||
|
assert by_provider["openai"]["call_count"] == 2
|
||||||
|
|
||||||
|
assert abs(by_provider["anthropic"]["cost_usd"] - 0.06) < 1e-9
|
||||||
|
assert by_provider["anthropic"]["call_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_group_by_provider_empty_rows() -> None:
|
||||||
|
"""Empty input returns empty list."""
|
||||||
|
result = _aggregate_rows_by_provider([])
|
||||||
|
assert result == []
|
||||||
Reference in New Issue
Block a user