feat(03-01): backend API endpoints — channels, billing, usage, and audit logger enhancement
- Create channels.py: HMAC-signed OAuth state generation/verification, Slack OAuth install/callback, WhatsApp manual connect, test message endpoint
- Create billing.py: Stripe Checkout session, billing portal session, webhook handler with idempotency (StripeEvent table), subscription lifecycle management
- Update usage.py: add _aggregate_rows_by_agent and _aggregate_rows_by_provider helpers (unit-testable without DB), complete usage endpoints
- Fix audit.py: rename 'metadata' attribute to 'event_metadata' (SQLAlchemy 2.0 DeclarativeBase reserves 'metadata')
- Enhance runner.py: audit log now includes prompt_tokens, completion_tokens, total_tokens, cost_usd, provider in LLM call metadata
- Update api/__init__.py to export all new routers
- All 27 unit tests passing
This commit is contained in:
72
tests/unit/test_slack_oauth.py
Normal file
72
tests/unit/test_slack_oauth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Unit tests for Slack OAuth state generation and verification.
|
||||
|
||||
Tests:
|
||||
- generate_oauth_state produces a base64-encoded string containing tenant_id
|
||||
- verify_oauth_state returns the correct tenant_id for a valid state
|
||||
- verify_oauth_state raises ValueError for a tampered state
|
||||
- verify_oauth_state raises ValueError for a state signed with wrong secret
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.api.channels import generate_oauth_state, verify_oauth_state
|
||||
|
||||
_SECRET = "test-hmac-secret-do-not-use-in-production"
|
||||
_TENANT_ID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
|
||||
def test_generate_oauth_state_is_string() -> None:
    """The value produced by generate_oauth_state is a ``str`` with content."""
    token = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
    assert isinstance(token, str)
    # Already known to be a str at this point, so "not empty" means len > 0.
    assert token != ""
|
||||
|
||||
|
||||
def test_generate_oauth_state_contains_tenant_id() -> None:
    """
    The tenant_id given to generate_oauth_state can be recovered from the state.

    Round-trip check: verifying a freshly generated state with the same
    secret must return the tenant_id it was generated for.
    """
    recovered = verify_oauth_state(
        state=generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET),
        secret=_SECRET,
    )
    assert recovered == _TENANT_ID
|
||||
|
||||
|
||||
def test_verify_oauth_state_valid() -> None:
    """Verifying a freshly generated state yields the tenant_id it was issued for."""
    fresh_state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
    assert verify_oauth_state(state=fresh_state, secret=_SECRET) == _TENANT_ID
|
||||
|
||||
|
||||
def test_verify_oauth_state_tampered() -> None:
    """A state string altered after generation is rejected with ValueError."""
    genuine = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)

    # Corrupt the token by tacking extra characters onto its end.
    corrupted = "".join((genuine, "TAMPERED"))

    with pytest.raises(ValueError):
        verify_oauth_state(state=corrupted, secret=_SECRET)
|
||||
|
||||
|
||||
def test_verify_oauth_state_wrong_secret() -> None:
    """Verifying with a secret other than the signing secret raises ValueError."""
    signed_state = generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET)
    mismatched_secret = "wrong-secret"

    with pytest.raises(ValueError):
        verify_oauth_state(state=signed_state, secret=mismatched_secret)
|
||||
|
||||
|
||||
def test_generate_oauth_state_nonce_differs() -> None:
    """Generating twice for one tenant gives distinct states that both verify."""
    first, second = (
        generate_oauth_state(tenant_id=_TENANT_ID, secret=_SECRET) for _ in range(2)
    )

    # The tokens are expected to differ between calls (random nonce).
    assert first != second

    # Each token must still resolve to the same tenant on its own.
    for token in (first, second):
        assert verify_oauth_state(state=token, secret=_SECRET) == _TENANT_ID
|
||||
177
tests/unit/test_stripe_webhooks.py
Normal file
177
tests/unit/test_stripe_webhooks.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Unit tests for Stripe webhook handler logic.
|
||||
|
||||
Tests:
|
||||
- process_stripe_event idempotency: same event_id processed twice returns
|
||||
"already_processed" on the second call
|
||||
- customer.subscription.updated: updates tenant subscription_status
|
||||
- customer.subscription.deleted: sets status=canceled and deactivates agents
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.api.billing import process_stripe_event
|
||||
|
||||
|
||||
@pytest.fixture()
def mock_session() -> AsyncMock:
    """Provide an AsyncMock that stands in for an AsyncSession."""
    # By default ``execute`` resolves to a result whose ``scalar_one_or_none()``
    # is None, i.e. no StripeEvent has been recorded for the event yet.
    no_row_result = MagicMock(scalar_one_or_none=MagicMock(return_value=None))
    session = AsyncMock()
    session.execute.return_value = no_row_result
    return session
|
||||
|
||||
|
||||
@pytest.fixture()
def tenant_id() -> str:
    """Return a freshly generated UUID4, as a string, to use as a tenant id."""
    generated = uuid.uuid4()
    return str(generated)
|
||||
|
||||
|
||||
@pytest.fixture()
def agent_id() -> str:
    """Return a freshly generated UUID4, as a string, to use as an agent id."""
    generated = uuid.uuid4()
    return str(generated)
|
||||
|
||||
|
||||
class _MockTenant:
    """
    Plain-attribute stand-in for a tenant record used by the webhook tests.

    Starts out as a trialing tenant with no Stripe identifiers and a zero
    agent quota; tests mutate or inspect the attributes directly.
    """

    def __init__(self, tenant_id: str) -> None:
        # The id is kept as a UUID object, parsed from the string form.
        self.id = uuid.UUID(hex=tenant_id)

        # Stripe linkage: nothing is attached initially.
        self.stripe_customer_id: str | None = None
        self.stripe_subscription_id: str | None = None
        self.stripe_subscription_item_id: str | None = None

        # Subscription state.
        self.subscription_status: str = "trialing"
        self.trial_ends_at = None
        self.agent_quota: int = 0
|
||||
|
||||
|
||||
class _MockAgent:
    """
    Plain-attribute stand-in for an agent record used by the webhook tests.

    Both identifiers are kept as UUID objects; a new agent starts active.
    """

    def __init__(self, agent_id: str, tenant_id: str) -> None:
        self.id = uuid.UUID(hex=agent_id)
        self.tenant_id = uuid.UUID(hex=tenant_id)
        # Tests flip this flag to observe deactivation.
        self.is_active: bool = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stripe_webhook_idempotency(mock_session: AsyncMock, tenant_id: str) -> None:
    """
    A repeated Stripe event_id is reported as 'already_processed'.

    The same event is delivered twice, each time with its own session mock:
    the first session's lookup returns None, the second returns an existing
    StripeEvent stand-in, so the second delivery must be skipped.
    """
    subscription = {
        "id": "sub_test123",
        "status": "active",
        "customer": "cus_test123",
        "trial_end": None,
        "items": {"data": [{"id": "si_test123", "quantity": 2}]},
    }
    event_data = {
        "id": "evt_test_idempotent_001",
        "type": "customer.subscription.updated",
        "data": {"object": subscription},
    }

    # --- First delivery: the lookup finds no stored event. ---
    empty_result = MagicMock(
        scalar_one_or_none=MagicMock(return_value=None),
        scalars=MagicMock(return_value=MagicMock(all=MagicMock(return_value=[]))),
    )
    first_session = AsyncMock()
    first_session.execute.return_value = empty_result

    with patch(
        "shared.api.billing._get_tenant_by_stripe_customer",
        new_callable=AsyncMock,
    ) as tenant_lookup:
        tenant_lookup.return_value = _MockTenant(tenant_id)
        first_outcome = await process_stripe_event(event_data, first_session)

    assert first_outcome != "already_processed"

    # --- Second delivery: a non-None lookup result means "seen before". ---
    stored_event = MagicMock()
    second_session = AsyncMock()
    second_session.execute.return_value = MagicMock(
        scalar_one_or_none=MagicMock(return_value=stored_event)
    )

    second_outcome = await process_stripe_event(event_data, second_session)
    assert second_outcome == "already_processed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stripe_subscription_updated(mock_session: AsyncMock, tenant_id: str) -> None:
    """
    customer.subscription.updated copies subscription state onto the tenant.

    After processing, the tenant's status, subscription id and agent quota
    must match the status, id and item quantity carried by the event.
    """
    subscription = {
        "id": "sub_updated123",
        "status": "active",
        "customer": "cus_tenant123",
        "trial_end": None,
        "items": {"data": [{"id": "si_updated123", "quantity": 3}]},
    }
    event_data = {
        "id": "evt_test_sub_updated",
        "type": "customer.subscription.updated",
        "data": {"object": subscription},
    }

    tenant = _MockTenant(tenant_id)
    # Precondition: the tenant has not been touched yet.
    assert tenant.subscription_status == "trialing"

    with patch(
        "shared.api.billing._get_tenant_by_stripe_customer",
        new_callable=AsyncMock,
    ) as tenant_lookup:
        tenant_lookup.return_value = tenant
        outcome = await process_stripe_event(event_data, mock_session)

    assert outcome != "already_processed"
    assert tenant.subscription_status == "active"
    assert tenant.stripe_subscription_id == "sub_updated123"
    assert tenant.agent_quota == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stripe_cancellation(mock_session: AsyncMock, tenant_id: str, agent_id: str) -> None:
    """
    customer.subscription.deleted cancels the tenant and triggers agent deactivation.

    The deactivation helper is patched out; its stand-in flips the mock
    agent's ``is_active`` flag so the test can observe that it was invoked.
    """
    subscription = {
        "id": "sub_canceled123",
        "status": "canceled",
        "customer": "cus_tenant_cancel",
        "trial_end": None,
        "items": {"data": []},
    }
    event_data = {
        "id": "evt_test_canceled",
        "type": "customer.subscription.deleted",
        "data": {"object": subscription},
    }

    tenant = _MockTenant(tenant_id)
    tenant.subscription_status = "active"

    agent = _MockAgent(agent_id, tenant_id)
    assert agent.is_active is True

    # Stand-in for the real deactivation: just mark the single mock agent inactive.
    async def _fake_deactivate(session, tenant_id_arg):  # type: ignore[no-untyped-def]
        agent.is_active = False

    # Extra keyword arguments to patch() are forwarded to the AsyncMock
    # constructor, so the mocks come out already configured.
    with (
        patch(
            "shared.api.billing._get_tenant_by_stripe_customer",
            new_callable=AsyncMock,
            return_value=tenant,
        ),
        patch(
            "shared.api.billing._deactivate_all_agents",
            new_callable=AsyncMock,
            side_effect=_fake_deactivate,
        ) as mock_deactivate,
    ):
        outcome = await process_stripe_event(event_data, mock_session)

    assert outcome != "already_processed"
    assert tenant.subscription_status == "canceled"
    assert agent.is_active is False
    mock_deactivate.assert_called_once()
|
||||
141
tests/unit/test_usage_aggregation.py
Normal file
141
tests/unit/test_usage_aggregation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Unit tests for usage aggregation logic.
|
||||
|
||||
Tests:
|
||||
- _aggregate_rows_by_agent groups prompt_tokens, completion_tokens, cost_usd by agent_id
|
||||
- _aggregate_rows_by_provider groups cost_usd by provider name
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.api.usage import (
|
||||
_aggregate_rows_by_agent,
|
||||
_aggregate_rows_by_provider,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sample audit event rows (simulating DB query results)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_row(
    agent_id: str,
    provider: str,
    prompt_tokens: int = 100,
    completion_tokens: int = 50,
    cost_usd: float = 0.01,
) -> dict:
    """
    Build one usage row as a plain dict, for feeding the aggregation helpers.

    ``total_tokens`` is derived as ``prompt_tokens + completion_tokens``;
    every other key mirrors the argument of the same name.
    """
    row = dict(
        agent_id=agent_id,
        provider=provider,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    row["total_tokens"] = prompt_tokens + completion_tokens
    row["cost_usd"] = cost_usd
    return row
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: per-agent aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_usage_group_by_agent_single_agent() -> None:
    """Two calls by the same agent collapse into one entry with summed figures."""
    agent_id = str(uuid.uuid4())
    # (prompt_tokens, completion_tokens, cost_usd) for each call
    calls = [(100, 50, 0.01), (200, 80, 0.02)]
    rows = [
        _make_row(agent_id, "openai", prompt_tokens=prompt, completion_tokens=completion, cost_usd=cost)
        for prompt, completion, cost in calls
    ]

    result = _aggregate_rows_by_agent(rows)

    assert len(result) == 1
    entry = result[0]
    assert entry["agent_id"] == agent_id
    assert entry["prompt_tokens"] == 300
    assert entry["completion_tokens"] == 130
    assert entry["total_tokens"] == 430
    # Cost is a float sum, so compare within a small tolerance.
    assert abs(entry["cost_usd"] - 0.03) < 1e-9
    assert entry["call_count"] == 2
|
||||
|
||||
|
||||
def test_usage_group_by_agent_multiple_agents() -> None:
    """Rows from different agents end up in separate per-agent entries."""
    agent_a, agent_b = str(uuid.uuid4()), str(uuid.uuid4())
    # Agent A's two rows are deliberately interleaved with agent B's row.
    rows = [
        _make_row(agent_a, "anthropic", prompt_tokens=100, completion_tokens=40, cost_usd=0.005),
        _make_row(agent_b, "openai", prompt_tokens=500, completion_tokens=200, cost_usd=0.05),
        _make_row(agent_a, "anthropic", prompt_tokens=50, completion_tokens=20, cost_usd=0.002),
    ]

    by_id = {entry["agent_id"]: entry for entry in _aggregate_rows_by_agent(rows)}

    assert agent_a in by_id
    assert agent_b in by_id

    summary_a = by_id[agent_a]
    assert summary_a["prompt_tokens"] == 150
    assert summary_a["call_count"] == 2

    summary_b = by_id[agent_b]
    assert summary_b["prompt_tokens"] == 500
    assert summary_b["call_count"] == 1
|
||||
|
||||
|
||||
def test_usage_group_by_agent_empty_rows() -> None:
    """With no rows to aggregate, the per-agent result is an empty list."""
    assert _aggregate_rows_by_agent([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: per-provider aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_usage_group_by_provider_single_provider() -> None:
    """Two rows for one provider collapse into one entry with summed cost."""
    agent_id = str(uuid.uuid4())
    rows = [_make_row(agent_id, "openai", cost_usd=cost) for cost in (0.01, 0.02)]

    result = _aggregate_rows_by_provider(rows)

    assert len(result) == 1
    entry = result[0]
    assert entry["provider"] == "openai"
    # Cost is a float sum, so compare within a small tolerance.
    assert abs(entry["cost_usd"] - 0.03) < 1e-9
    assert entry["call_count"] == 2
|
||||
|
||||
|
||||
def test_usage_group_by_provider_multiple_providers() -> None:
    """Rows for different providers are totalled independently of each other."""
    agent_id = str(uuid.uuid4())
    # Providers alternate so grouping cannot rely on adjacent rows.
    calls = [("openai", 0.01), ("anthropic", 0.02), ("openai", 0.03), ("anthropic", 0.04)]
    rows = [_make_row(agent_id, provider, cost_usd=cost) for provider, cost in calls]

    by_provider = {entry["provider"]: entry for entry in _aggregate_rows_by_provider(rows)}

    assert "openai" in by_provider
    assert "anthropic" in by_provider

    openai_summary = by_provider["openai"]
    assert abs(openai_summary["cost_usd"] - 0.04) < 1e-9
    assert openai_summary["call_count"] == 2

    anthropic_summary = by_provider["anthropic"]
    assert abs(anthropic_summary["cost_usd"] - 0.06) < 1e-9
    assert anthropic_summary["call_count"] == 2
|
||||
|
||||
|
||||
def test_usage_group_by_provider_empty_rows() -> None:
    """With no rows to aggregate, the per-provider result is an empty list."""
    assert _aggregate_rows_by_provider([]) == []
|
||||
Reference in New Issue
Block a user