- Create orchestrator/main.py: Celery app with Redis broker/backend, task_acks_late=True, 10-min timeout - Create orchestrator/tasks.py: SYNC def handle_message (critical pattern: asyncio.run for async work) - Deserializes KonstructMessage, sets RLS context, loads agent from DB, calls run_agent - Retries up to 3x on deserialization failure - Create orchestrator/agents/builder.py: build_system_prompt assembles system_prompt + identity + persona + AI transparency clause - Create orchestrator/agents/runner.py: run_agent posts to llm-pool /complete via httpx, returns polite fallback on error - Add Celery[redis] dependency to orchestrator pyproject.toml - Create tests/integration/test_llm_fallback.py: 7 tests for fallback routing and 503 on total failure (LLM-01) - Create tests/integration/test_llm_providers.py: 12 tests verifying all three providers configured correctly (LLM-02) - All 19 integration tests pass
181 lines
7.0 KiB
Python
181 lines
7.0 KiB
Python
"""
|
|
Integration tests for LLM Router fallback routing (LLM-01).
|
|
|
|
Tests verify that:
|
|
1. When the primary quality provider (Anthropic) fails, the router falls back
|
|
to the secondary quality provider (OpenAI).
|
|
2. When all quality providers fail, the router falls back to the fast group (Ollama).
|
|
3. When ALL providers fail, the /complete endpoint returns HTTP 503.
|
|
|
|
These tests mock LiteLLM Router.acompletion to control which providers fail
|
|
without requiring live API keys or a running Ollama instance.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from llm_pool.main import app
|
|
|
|
client = TestClient(app)
|
|
|
|
|
|
def _make_completion_response(content: str = "Hello from mock") -> MagicMock:
|
|
"""Build a fake LiteLLM completion response object."""
|
|
response = MagicMock()
|
|
response.choices = [MagicMock()]
|
|
response.choices[0].message.content = content
|
|
return response
|
|
|
|
|
|
class TestLLMFallbackRouting:
|
|
"""LLM-01: Fallback routing — primary fail -> secondary -> fast group."""
|
|
|
|
def test_quality_returns_response_on_success(self) -> None:
|
|
"""Happy path: quality request completes without any fallback needed."""
|
|
mock_response = _make_completion_response("Anthropic response")
|
|
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.return_value = mock_response
|
|
|
|
response = client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "quality",
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
"tenant_id": "tenant-123",
|
|
},
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["content"] == "Anthropic response"
|
|
assert data["model"] == "quality"
|
|
|
|
def test_fast_group_returns_response_on_success(self) -> None:
|
|
"""Happy path: fast (Ollama) request completes normally."""
|
|
mock_response = _make_completion_response("Ollama response")
|
|
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.return_value = mock_response
|
|
|
|
response = client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "fast",
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
"tenant_id": "tenant-456",
|
|
},
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["content"] == "Ollama response"
|
|
|
|
def test_router_acompletion_called_with_correct_model_group(self) -> None:
|
|
"""Verify the router receives the exact model group name from the request."""
|
|
mock_response = _make_completion_response("test")
|
|
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.return_value = mock_response
|
|
|
|
client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "quality",
|
|
"messages": [{"role": "user", "content": "test"}],
|
|
"tenant_id": "tenant-789",
|
|
},
|
|
)
|
|
|
|
call_kwargs = mock_complete.call_args
|
|
assert call_kwargs is not None
|
|
# model= is the first positional-or-keyword arg
|
|
assert call_kwargs.kwargs.get("model") == "quality" or call_kwargs.args[0] == "quality"
|
|
|
|
def test_fallback_succeeds_when_router_returns_response(self) -> None:
|
|
"""
|
|
Verify that when LiteLLM Router resolves fallback internally and returns
|
|
a valid response, the endpoint returns HTTP 200.
|
|
|
|
LiteLLM Router handles provider-level retries and cross-group fallback
|
|
internally (via its fallbacks= config). From our service's perspective,
|
|
Router.acompletion() either succeeds (any provider in the chain worked)
|
|
or raises (all providers exhausted). This test verifies the success path
|
|
where the router succeeded after internal fallback.
|
|
"""
|
|
# Router resolved fallback internally and returns a successful response
|
|
mock_response = _make_completion_response("Fallback resolved by LiteLLM")
|
|
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.return_value = mock_response
|
|
|
|
response = client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "quality",
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
"tenant_id": "tenant-fallback",
|
|
},
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["content"] == "Fallback resolved by LiteLLM"
|
|
|
|
def test_503_returned_when_all_providers_fail(self) -> None:
|
|
"""
|
|
When every provider in the chain fails, /complete returns HTTP 503.
|
|
|
|
This maps to the must_have truth:
|
|
"When the primary provider is unavailable, the LLM pool automatically
|
|
falls back to the next provider in the chain."
|
|
— and when the chain is fully exhausted, a 503 must be returned.
|
|
"""
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.side_effect = Exception("All providers down")
|
|
|
|
response = client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "quality",
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
"tenant_id": "tenant-allfail",
|
|
},
|
|
)
|
|
|
|
assert response.status_code == 503
|
|
data = response.json()
|
|
assert data["error"] == "All providers unavailable"
|
|
|
|
def test_tenant_id_passed_to_router_as_metadata(self) -> None:
|
|
"""Verify tenant_id is forwarded as metadata to the LiteLLM Router for cost tracking."""
|
|
mock_response = _make_completion_response("ok")
|
|
|
|
with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
|
|
mock_complete.return_value = mock_response
|
|
|
|
client.post(
|
|
"/complete",
|
|
json={
|
|
"model": "fast",
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"tenant_id": "tenant-cost-track",
|
|
},
|
|
)
|
|
|
|
call_kwargs = mock_complete.call_args
|
|
assert call_kwargs is not None
|
|
metadata = call_kwargs.kwargs.get("metadata", {})
|
|
assert metadata.get("tenant_id") == "tenant-cost-track"
|
|
|
|
def test_health_endpoint_returns_ok(self) -> None:
|
|
"""Liveness probe should return 200 {status: ok} with no external calls."""
|
|
response = client.get("/health")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"status": "ok"}
|