""" Unit tests for usage aggregation logic. Tests: - aggregate_usage_by_agent groups prompt_tokens, completion_tokens, cost_usd by agent_id - aggregate_usage_by_provider groups cost_usd by provider name """ from __future__ import annotations import uuid from datetime import datetime, timezone import pytest from shared.api.usage import ( _aggregate_rows_by_agent, _aggregate_rows_by_provider, ) # --------------------------------------------------------------------------- # Sample audit event rows (simulating DB query results) # --------------------------------------------------------------------------- def _make_row( agent_id: str, provider: str, prompt_tokens: int = 100, completion_tokens: int = 50, cost_usd: float = 0.01, ) -> dict: return { "agent_id": agent_id, "provider": provider, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, "cost_usd": cost_usd, } # --------------------------------------------------------------------------- # Tests: per-agent aggregation # --------------------------------------------------------------------------- def test_usage_group_by_agent_single_agent() -> None: """Single agent with two calls aggregates tokens and cost correctly.""" agent_id = str(uuid.uuid4()) rows = [ _make_row(agent_id, "openai", prompt_tokens=100, completion_tokens=50, cost_usd=0.01), _make_row(agent_id, "openai", prompt_tokens=200, completion_tokens=80, cost_usd=0.02), ] result = _aggregate_rows_by_agent(rows) assert len(result) == 1 assert result[0]["agent_id"] == agent_id assert result[0]["prompt_tokens"] == 300 assert result[0]["completion_tokens"] == 130 assert result[0]["total_tokens"] == 430 assert abs(result[0]["cost_usd"] - 0.03) < 1e-9 assert result[0]["call_count"] == 2 def test_usage_group_by_agent_multiple_agents() -> None: """Multiple agents are aggregated separately.""" agent_a = str(uuid.uuid4()) agent_b = str(uuid.uuid4()) rows = [ _make_row(agent_a, "anthropic", prompt_tokens=100, completion_tokens=40, cost_usd=0.005), _make_row(agent_b, "openai", prompt_tokens=500, completion_tokens=200, cost_usd=0.05), _make_row(agent_a, "anthropic", prompt_tokens=50, completion_tokens=20, cost_usd=0.002), ] result = _aggregate_rows_by_agent(rows) by_id = {r["agent_id"]: r for r in result} assert agent_a in by_id assert agent_b in by_id assert by_id[agent_a]["prompt_tokens"] == 150 assert by_id[agent_a]["call_count"] == 2 assert by_id[agent_b]["prompt_tokens"] == 500 assert by_id[agent_b]["call_count"] == 1 def test_usage_group_by_agent_empty_rows() -> None: """Empty input returns empty list.""" result = _aggregate_rows_by_agent([]) assert result == [] # --------------------------------------------------------------------------- # Tests: per-provider aggregation # --------------------------------------------------------------------------- def test_usage_group_by_provider_single_provider() -> None: """Single provider aggregates cost correctly.""" agent_id = str(uuid.uuid4()) rows = [ _make_row(agent_id, "openai", cost_usd=0.01), _make_row(agent_id, "openai", cost_usd=0.02), ] result = _aggregate_rows_by_provider(rows) assert len(result) == 1 assert result[0]["provider"] == "openai" assert abs(result[0]["cost_usd"] - 0.03) < 1e-9 assert result[0]["call_count"] == 2 def test_usage_group_by_provider_multiple_providers() -> None: """Multiple providers are grouped independently.""" agent_id = str(uuid.uuid4()) rows = [ _make_row(agent_id, "openai", cost_usd=0.01), _make_row(agent_id, "anthropic", cost_usd=0.02), _make_row(agent_id, "openai", cost_usd=0.03), _make_row(agent_id, "anthropic", cost_usd=0.04), ] result = _aggregate_rows_by_provider(rows) by_provider = {r["provider"]: r for r in result} assert "openai" in by_provider assert "anthropic" in by_provider assert abs(by_provider["openai"]["cost_usd"] - 0.04) < 1e-9 assert by_provider["openai"]["call_count"] == 2 assert abs(by_provider["anthropic"]["cost_usd"] - 0.06) < 1e-9 assert by_provider["anthropic"]["call_count"] == 2 def test_usage_group_by_provider_empty_rows() -> None: """Empty input returns empty list.""" result = _aggregate_rows_by_provider([]) assert result == []