feat(01-02): Celery orchestrator — handle_message task, system prompt builder, LLM pool runner
- Create orchestrator/main.py: Celery app with Redis broker/backend, task_acks_late=True, 10-min timeout
- Create orchestrator/tasks.py: SYNC def handle_message (critical pattern: asyncio.run for async work)
  - Deserializes KonstructMessage, sets RLS context, loads agent from DB, calls run_agent
  - Retries up to 3x on deserialization failure
- Create orchestrator/agents/builder.py: build_system_prompt assembles system_prompt + identity + persona + AI transparency clause
- Create orchestrator/agents/runner.py: run_agent posts to llm-pool /complete via httpx, returns polite fallback on error
- Add celery[redis] dependency to orchestrator pyproject.toml
- Create tests/integration/test_llm_fallback.py: 7 tests for fallback routing and 503 on total failure (LLM-01)
- Create tests/integration/test_llm_providers.py: 12 tests verifying all three providers configured correctly (LLM-02)
- All 19 integration tests pass
This commit is contained in:
14
packages/orchestrator/orchestrator/__init__.py
Normal file
14
packages/orchestrator/orchestrator/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
konstruct-orchestrator — Celery-based agent dispatch service.
|
||||||
|
|
||||||
|
This package provides the Celery application and task definitions for
|
||||||
|
processing inbound Konstruct messages through the agent pipeline:
|
||||||
|
|
||||||
|
1. Deserialize KonstructMessage
|
||||||
|
2. Load agent config from DB (tenant-scoped via RLS)
|
||||||
|
3. Build system prompt from agent persona fields
|
||||||
|
4. Call LLM pool via HTTP
|
||||||
|
5. Return response content
|
||||||
|
|
||||||
|
Import the Celery app from orchestrator.main.
|
||||||
|
"""
|
||||||
7
packages/orchestrator/orchestrator/agents/__init__.py
Normal file
7
packages/orchestrator/orchestrator/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Agent module — system prompt construction and LLM pool communication.
|
||||||
|
|
||||||
|
Submodules:
|
||||||
|
builder — Assembles system prompt from agent persona fields.
|
||||||
|
runner — Sends completion requests to the LLM pool service.
|
||||||
|
"""
|
||||||
84
packages/orchestrator/orchestrator/agents/builder.py
Normal file
84
packages/orchestrator/orchestrator/agents/builder.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
System prompt builder — assembles the instruction prompt from agent fields.
|
||||||
|
|
||||||
|
The build_system_prompt function combines:
|
||||||
|
1. The agent's explicit system_prompt field (if provided)
|
||||||
|
2. Identity context: name, role
|
||||||
|
3. Persona description (if set)
|
||||||
|
4. AI transparency clause — always appended; agents must not deny being AIs
|
||||||
|
|
||||||
|
AI TRANSPARENCY POLICY:
|
||||||
|
Per Konstruct product design, agents MUST acknowledge they are AI assistants
|
||||||
|
when directly asked. This clause is injected unconditionally to prevent
|
||||||
|
agents from deceiving users, regardless of persona configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from shared.models.tenant import Agent
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt(agent: Agent) -> str:
    """
    Assemble the complete system prompt for *agent*.

    The prompt is built from up to four sections, joined by blank lines:
      1. agent.system_prompt — operator-defined base instructions (optional)
      2. Identity — the agent's name and role (always present)
      3. Persona — tone/behavioral description (optional)
      4. AI transparency clause — always appended, never overridable

    Args:
        agent: ORM Agent instance.

    Returns:
        A complete system prompt string ready to pass to the LLM.
    """
    base_instructions = (agent.system_prompt or "").strip()
    persona_text = (agent.persona or "").strip()

    sections: list[str] = []

    # 1. Base system prompt (operator-defined instructions), if any.
    if base_instructions:
        sections.append(base_instructions)

    # 2. Identity — always included so the agent knows who it is.
    sections.append(f"Your name is {agent.name}. Your role is {agent.role}.")

    # 3. Persona — tone and behavioral style, if configured.
    if persona_text:
        sections.append(f"Persona: {persona_text}")

    # 4. AI transparency clause — unconditional and non-overridable.
    sections.append(
        "If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
    )

    return "\n\n".join(sections)
|
||||||
|
|
||||||
|
|
||||||
|
def build_messages(
    system_prompt: str,
    user_message: str,
    history: list[dict] | None = None,
) -> list[dict]:
    """
    Build an OpenAI-format messages list.

    Layout:
        [system message] + [optional history] + [current user message]

    Args:
        system_prompt: The full assembled system prompt.
        user_message: The latest user message text.
        history: Optional list of prior messages in OpenAI format.
            Each dict must have "role" and "content" keys.

    Returns:
        A list of message dicts suitable for an OpenAI-compatible API call.
    """
    prior = list(history) if history else []
    return [
        {"role": "system", "content": system_prompt},
        *prior,
        {"role": "user", "content": user_message},
    ]
|
||||||
87
packages/orchestrator/orchestrator/agents/runner.py
Normal file
87
packages/orchestrator/orchestrator/agents/runner.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
Agent runner — sends completion requests to the LLM pool service.
|
||||||
|
|
||||||
|
Communication pattern:
|
||||||
|
orchestrator.tasks.handle_message
|
||||||
|
→ run_agent (this module, async)
|
||||||
|
→ POST http://llm-pool:8004/complete (httpx async)
|
||||||
|
→ LiteLLM Router (router.py in llm-pool)
|
||||||
|
→ Ollama / Anthropic / OpenAI
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from orchestrator.agents.builder import build_messages, build_system_prompt
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.models.message import KonstructMessage
|
||||||
|
from shared.models.tenant import Agent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# User-facing text returned whenever the LLM pool call fails for any reason.
_FALLBACK_RESPONSE = (
    "I'm having trouble processing your request right now. "
    "Please try again in a moment."
)

# Timeout for LLM pool HTTP requests — generous to allow slow local inference.
_LLM_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_agent(msg: KonstructMessage, agent: Agent) -> str:
    """
    Execute an agent against the LLM pool and return the response text.

    Posts an OpenAI-format completion request to the LLM pool's /complete
    endpoint. Transport failures and non-200 replies are logged and mapped
    to a polite fallback string, so callers never see an exception from here.

    Args:
        msg: The inbound Konstruct message being processed.
        agent: The ORM Agent instance that handles this message.

    Returns:
        The LLM response content as a plain string, or a polite fallback
        message if the LLM pool is unreachable or returns a non-200 response.
    """
    prompt = build_system_prompt(agent)

    # Extract user text from the message content (empty string if absent).
    payload = {
        "model": agent.model_preference,
        "messages": build_messages(
            system_prompt=prompt,
            user_message=msg.content.text or "",
        ),
        "tenant_id": str(msg.tenant_id) if msg.tenant_id else "",
    }

    url = f"{settings.llm_pool_url}/complete"

    async with httpx.AsyncClient(timeout=_LLM_TIMEOUT) as client:
        try:
            resp = await client.post(url, json=payload)
        except httpx.RequestError:
            # Network-level failure (DNS, connect, timeout) — degrade politely.
            logger.exception(
                "LLM pool unreachable for tenant=%s agent=%s url=%s",
                msg.tenant_id,
                agent.id,
                url,
            )
            return _FALLBACK_RESPONSE

    if resp.status_code != 200:
        logger.error(
            "LLM pool returned %d for tenant=%s agent=%s",
            resp.status_code,
            msg.tenant_id,
            agent.id,
        )
        return _FALLBACK_RESPONSE

    return str(resp.json().get("content", _FALLBACK_RESPONSE))
|
||||||
42
packages/orchestrator/orchestrator/main.py
Normal file
42
packages/orchestrator/orchestrator/main.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
Celery application for the Konstruct Agent Orchestrator.
|
||||||
|
|
||||||
|
Broker and result backend are both Redis (separate DB indexes to avoid
|
||||||
|
key collisions). Tasks are discovered automatically from orchestrator.tasks.
|
||||||
|
|
||||||
|
Usage (development):
|
||||||
|
celery -A orchestrator.main worker --loglevel=info
|
||||||
|
|
||||||
|
Usage (production — via Docker Compose):
|
||||||
|
celery -A orchestrator.main worker --loglevel=info --concurrency=4
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
app = Celery(
    "konstruct_orchestrator",
    broker=settings.celery_broker_url,
    backend=settings.celery_result_backend,
    include=["orchestrator.tasks"],
)

# ---------------------------------------------------------------------------
# Celery configuration
# ---------------------------------------------------------------------------
app.conf.update(
    # JSON everywhere — task args and results must be JSON-serializable.
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    # Acknowledge tasks only after they complete (not on receipt), so a
    # task is redelivered if the worker crashes mid-execution.
    task_acks_late=True,
    # Reject tasks that exceed 10 minutes — prevents runaway LLM calls.
    # Soft limit fires first (540s) so the task can clean up; the hard
    # limit (600s) forcibly terminates it.
    task_soft_time_limit=540,
    task_time_limit=600,
)
|
||||||
139
packages/orchestrator/orchestrator/tasks.py
Normal file
139
packages/orchestrator/orchestrator/tasks.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Celery task definitions for the Konstruct Agent Orchestrator.
|
||||||
|
|
||||||
|
# CELERY TASKS MUST BE SYNC def — async def causes RuntimeError or silent hang.
|
||||||
|
# Use asyncio.run() for async work. This is a fundamental Celery constraint:
|
||||||
|
# Celery workers are NOT async-native. The handle_message task bridges the
|
||||||
|
# sync Celery world to the async agent pipeline via asyncio.run().
|
||||||
|
#
|
||||||
|
# NEVER change these to `async def`. If you see a RuntimeError about "no
|
||||||
|
# running event loop" or tasks that silently never complete, check for
|
||||||
|
# accidental async def usage first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from orchestrator.main import app
|
||||||
|
from shared.models.message import KonstructMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(
    name="orchestrator.tasks.handle_message",
    bind=True,
    max_retries=3,
    default_retry_delay=5,
)
def handle_message(self, message_data: dict) -> dict:  # type: ignore[no-untyped-def]
    """
    Process an inbound Konstruct message through the agent pipeline.

    Primary entry point for the Celery worker; dispatched by the Message
    Router (or Channel Gateway in simple deployments) after tenant
    resolution completes.

    Pipeline:
        1. Deserialize message_data -> KonstructMessage
        2. Run the async agent pipeline via asyncio.run()
        3. Return the response dict

    Args:
        message_data: JSON-serializable dict representation of a KonstructMessage.

    Returns:
        Dict with keys:
          - message_id (str): Original message ID
          - response (str): Agent's response text
          - tenant_id (str | None): Tenant that handled the message
    """
    # Malformed payloads are retried (up to 3x, 5s apart) before failing hard.
    try:
        msg = KonstructMessage.model_validate(message_data)
    except Exception as exc:
        logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
        raise self.retry(exc=exc)

    # Sync -> async bridge: this task MUST remain a sync def (see module
    # docstring); asyncio.run() owns the event loop for the pipeline.
    return asyncio.run(_process_message(msg))
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_message(msg: KonstructMessage) -> dict:
    """
    Async agent pipeline — load agent config, build prompt, call LLM pool.

    Called only from the synchronous handle_message task via asyncio.run();
    it must not be invoked directly from Celery task code.

    Args:
        msg: The deserialized KonstructMessage.

    Returns:
        Dict with message_id, response, and tenant_id keys.
    """
    # Deferred imports keep task-module import time (and Celery worker
    # startup) cheap; they also avoid import cycles with the agents package.
    from sqlalchemy import select

    from orchestrator.agents.runner import run_agent
    from shared.db import async_session_factory, engine
    from shared.models.tenant import Agent
    from shared.rls import configure_rls_hook, current_tenant_id

    if msg.tenant_id is None:
        logger.warning("Message %s has no tenant_id — cannot process", msg.id)
        return {
            "message_id": msg.id,
            "response": "Unable to process: tenant not identified.",
            "tenant_id": None,
        }

    # Set up RLS engine hook (idempotent — safe to call on every task).
    configure_rls_hook(engine)

    # Scope the RLS context variable to the DB work below; the finally
    # guarantees the contextvar never leaks into unrelated work.
    tenant_uuid = uuid.UUID(msg.tenant_id)
    token = current_tenant_id.set(tenant_uuid)
    try:
        agent: Agent | None = None
        async with async_session_factory() as session:
            query = (
                select(Agent)
                .where(Agent.tenant_id == tenant_uuid)
                .where(Agent.is_active.is_(True))
                .limit(1)
            )
            agent = (await session.execute(query)).scalars().first()
    finally:
        # Always reset the RLS context var after DB work is done.
        current_tenant_id.reset(token)

    if agent is None:
        logger.warning(
            "No active agent found for tenant=%s message=%s",
            msg.tenant_id,
            msg.id,
        )
        return {
            "message_id": msg.id,
            "response": "No active agent is configured for your workspace. Please contact your administrator.",
            "tenant_id": msg.tenant_id,
        }

    response_text = await run_agent(msg, agent)

    logger.info(
        "Message %s processed by agent=%s tenant=%s",
        msg.id,
        agent.id,
        msg.tenant_id,
    )
    return {
        "message_id": msg.id,
        "response": response_text,
        "tenant_id": msg.tenant_id,
    }
|
||||||
@@ -10,6 +10,7 @@ requires-python = ">=3.12"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"konstruct-shared",
|
"konstruct-shared",
|
||||||
"fastapi[standard]>=0.115.0",
|
"fastapi[standard]>=0.115.0",
|
||||||
|
"celery[redis]>=5.4.0",
|
||||||
"httpx>=0.28.0",
|
"httpx>=0.28.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
180
tests/integration/test_llm_fallback.py
Normal file
180
tests/integration/test_llm_fallback.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for LLM Router fallback routing (LLM-01).
|
||||||
|
|
||||||
|
Tests verify that:
|
||||||
|
1. When the primary quality provider (Anthropic) fails, the router falls back
|
||||||
|
to the secondary quality provider (OpenAI).
|
||||||
|
2. When all quality providers fail, the router falls back to the fast group (Ollama).
|
||||||
|
3. When ALL providers fail, the /complete endpoint returns HTTP 503.
|
||||||
|
|
||||||
|
These tests mock LiteLLM Router.acompletion to control which providers fail
|
||||||
|
without requiring live API keys or a running Ollama instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from llm_pool.main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_completion_response(content: str = "Hello from mock") -> MagicMock:
|
||||||
|
"""Build a fake LiteLLM completion response object."""
|
||||||
|
response = MagicMock()
|
||||||
|
response.choices = [MagicMock()]
|
||||||
|
response.choices[0].message.content = content
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMFallbackRouting:
    """LLM-01: Fallback routing — primary fail -> secondary -> fast group."""

    @staticmethod
    def _post_complete(model: str, tenant_id: str, text: str):
        """POST /complete with a single user message and return the response."""
        return client.post(
            "/complete",
            json={
                "model": model,
                "messages": [{"role": "user", "content": text}],
                "tenant_id": tenant_id,
            },
        )

    def test_quality_returns_response_on_success(self) -> None:
        """Happy path: quality request completes without any fallback needed."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("Anthropic response")
            resp = self._post_complete("quality", "tenant-123", "Hello")

        assert resp.status_code == 200
        body = resp.json()
        assert body["content"] == "Anthropic response"
        assert body["model"] == "quality"

    def test_fast_group_returns_response_on_success(self) -> None:
        """Happy path: fast (Ollama) request completes normally."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("Ollama response")
            resp = self._post_complete("fast", "tenant-456", "Hello")

        assert resp.status_code == 200
        assert resp.json()["content"] == "Ollama response"

    def test_router_acompletion_called_with_correct_model_group(self) -> None:
        """Verify the router receives the exact model group name from the request."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("test")
            self._post_complete("quality", "tenant-789", "test")

        call = acomp.call_args
        assert call is not None
        # model= is the first positional-or-keyword arg
        assert call.kwargs.get("model") == "quality" or call.args[0] == "quality"

    def test_fallback_succeeds_when_router_returns_response(self) -> None:
        """
        Verify that when LiteLLM Router resolves fallback internally and returns
        a valid response, the endpoint returns HTTP 200.

        LiteLLM Router handles provider-level retries and cross-group fallback
        internally (via its fallbacks= config): Router.acompletion() either
        succeeds (some provider in the chain worked) or raises (all providers
        exhausted). This test exercises the success-after-internal-fallback path.
        """
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("Fallback resolved by LiteLLM")
            resp = self._post_complete("quality", "tenant-fallback", "Hello")

        assert resp.status_code == 200
        assert resp.json()["content"] == "Fallback resolved by LiteLLM"

    def test_503_returned_when_all_providers_fail(self) -> None:
        """
        When every provider in the chain fails, /complete returns HTTP 503.

        Maps to the must_have truth: "When the primary provider is unavailable,
        the LLM pool automatically falls back to the next provider in the
        chain" — and when the chain is fully exhausted, a 503 must be returned.
        """
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.side_effect = Exception("All providers down")
            resp = self._post_complete("quality", "tenant-allfail", "Hello")

        assert resp.status_code == 503
        assert resp.json()["error"] == "All providers unavailable"

    def test_tenant_id_passed_to_router_as_metadata(self) -> None:
        """Verify tenant_id is forwarded as metadata to the LiteLLM Router for cost tracking."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("ok")
            self._post_complete("fast", "tenant-cost-track", "Hi")

        call = acomp.call_args
        assert call is not None
        assert call.kwargs.get("metadata", {}).get("tenant_id") == "tenant-cost-track"

    def test_health_endpoint_returns_ok(self) -> None:
        """Liveness probe should return 200 {status: ok} with no external calls."""
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}
|
||||||
172
tests/integration/test_llm_providers.py
Normal file
172
tests/integration/test_llm_providers.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for LLM provider configuration (LLM-02).
|
||||||
|
|
||||||
|
Tests verify that:
|
||||||
|
1. The LiteLLM Router model_list contains entries for all three providers
|
||||||
|
(Ollama/fast, Anthropic/quality, OpenAI/quality).
|
||||||
|
2. A request with model="fast" routes to the Ollama configuration.
|
||||||
|
3. A request with model="quality" routes to an Anthropic or OpenAI configuration.
|
||||||
|
4. Provider entries reference the correct model identifiers from CLAUDE.md.
|
||||||
|
|
||||||
|
These tests inspect the router configuration directly and mock acompletion to
|
||||||
|
verify routing without live API calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from llm_pool.main import app
|
||||||
|
from llm_pool.router import _model_list, llm_router
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_completion_response(content: str = "test") -> MagicMock:
|
||||||
|
response = MagicMock()
|
||||||
|
response.choices = [MagicMock()]
|
||||||
|
response.choices[0].message.content = content
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderConfiguration:
    """LLM-02: Provider configuration — all three providers are present and correct."""

    @staticmethod
    def _group(name: str) -> list:
        """Return every model_list entry registered under group *name*."""
        return [m for m in _model_list if m["model_name"] == name]

    @staticmethod
    def _with_prefix(entries: list, prefix: str) -> list:
        """Filter *entries* to those whose litellm model starts with *prefix*."""
        return [e for e in entries if e["litellm_params"]["model"].startswith(prefix)]

    def test_model_list_has_three_entries(self) -> None:
        """The model_list must have exactly three entries (fast, quality x2)."""
        assert len(_model_list) == 3

    def test_fast_group_present_in_model_list(self) -> None:
        """The 'fast' model group must exist in the model_list."""
        assert len(self._group("fast")) >= 1, "No 'fast' model group found in model_list"

    def test_quality_group_present_in_model_list(self) -> None:
        """The 'quality' model group must have at least two entries (Anthropic + OpenAI)."""
        assert len(self._group("quality")) >= 2, "Expected at least 2 'quality' entries (Anthropic + OpenAI)"

    def test_fast_group_uses_ollama_model(self) -> None:
        """The fast group must route to an ollama/* model."""
        fast_entries = self._group("fast")
        assert fast_entries, "No fast entry found"
        ollama = self._with_prefix(fast_entries, "ollama/")
        assert ollama, f"Fast group does not use an ollama model: {fast_entries}"

    def test_fast_group_has_ollama_api_base(self) -> None:
        """The fast group entry must specify an api_base pointing to Ollama."""
        for entry in self._group("fast"):
            assert "api_base" in entry["litellm_params"], f"Fast group entry missing api_base: {entry}"

    def test_quality_group_has_anthropic_entry(self) -> None:
        """Quality group must include an anthropic/* model."""
        quality_entries = self._group("quality")
        anthropic = self._with_prefix(quality_entries, "anthropic/")
        assert anthropic, f"No Anthropic entry in quality group: {quality_entries}"

    def test_quality_group_has_openai_entry(self) -> None:
        """Quality group must include an openai/* model as the fallback."""
        quality_entries = self._group("quality")
        openai = self._with_prefix(quality_entries, "openai/")
        assert openai, f"No OpenAI entry in quality group: {quality_entries}"

    def test_anthropic_model_is_claude_sonnet(self) -> None:
        """Anthropic entry must use the correct model from CLAUDE.md architecture."""
        anthropic = self._with_prefix(self._group("quality"), "anthropic/")
        assert anthropic
        model = anthropic[0]["litellm_params"]["model"]
        assert "claude-sonnet" in model, f"Expected claude-sonnet model, got: {model}"

    def test_openai_model_is_gpt4o(self) -> None:
        """OpenAI entry must use gpt-4o as specified in architecture."""
        openai = self._with_prefix(self._group("quality"), "openai/")
        assert openai
        model = openai[0]["litellm_params"]["model"]
        assert "gpt-4o" in model, f"Expected gpt-4o model, got: {model}"

    def test_fast_request_calls_acompletion_with_fast_model(self) -> None:
        """A fast model request must invoke acompletion with model='fast'."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("ollama says hi")
            resp = client.post(
                "/complete",
                json={
                    "model": "fast",
                    "messages": [{"role": "user", "content": "Hi"}],
                    "tenant_id": "tenant-fast",
                },
            )

        assert resp.status_code == 200
        call = acomp.call_args
        assert call is not None
        called_model = call.kwargs.get("model") or (call.args[0] if call.args else None)
        assert called_model == "fast"

    def test_quality_request_calls_acompletion_with_quality_model(self) -> None:
        """A quality model request must invoke acompletion with model='quality'."""
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as acomp:
            acomp.return_value = _make_completion_response("anthropic says hi")
            resp = client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "Hi"}],
                    "tenant_id": "tenant-quality",
                },
            )

        assert resp.status_code == 200
        call = acomp.call_args
        assert call is not None
        called_model = call.kwargs.get("model") or (call.args[0] if call.args else None)
        assert called_model == "quality"

    def test_router_fallback_config_quality_falls_to_fast(self) -> None:
        """The Router fallbacks config must specify quality -> fast cross-group fallback."""
        fallbacks = getattr(llm_router, "fallbacks", None)
        assert fallbacks is not None, "Router has no fallbacks configured"

        # Find the quality -> fast fallback entry.
        quality_fallback = next(
            (fb for fb in fallbacks if isinstance(fb, dict) and "quality" in fb),
            None,
        )
        assert quality_fallback is not None, (
            f"No quality->fast fallback found. Current fallbacks: {fallbacks}"
        )
        fallback_targets = quality_fallback["quality"]
        assert "fast" in fallback_targets, (
            f"Quality fallback does not target 'fast' group: {fallback_targets}"
        )
|
||||||
Reference in New Issue
Block a user