feat(01-02): Celery orchestrator — handle_message task, system prompt builder, LLM pool runner
- Create orchestrator/main.py: Celery app with Redis broker/backend, task_acks_late=True, 10-min timeout
- Create orchestrator/tasks.py: SYNC def handle_message (critical pattern: asyncio.run for async work)
  - Deserializes KonstructMessage, sets RLS context, loads agent from DB, calls run_agent
  - Retries up to 3x on deserialization failure
- Create orchestrator/agents/builder.py: build_system_prompt assembles system_prompt + identity + persona + AI transparency clause
- Create orchestrator/agents/runner.py: run_agent posts to llm-pool /complete via httpx, returns polite fallback on error
- Add celery[redis] dependency to orchestrator pyproject.toml
- Create tests/integration/test_llm_fallback.py: 7 tests for fallback routing and 503 on total failure (LLM-01)
- Create tests/integration/test_llm_providers.py: 12 tests verifying all three providers are configured correctly (LLM-02)
- All 19 integration tests pass
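The sync-task-plus-asyncio.run pattern called out above, as a minimal standalone sketch (illustrative names and broker URL — not the actual handle_message shown further down):

# Sketch: why Celery tasks stay sync and bridge to async via asyncio.run().
import asyncio

from celery import Celery

demo_app = Celery("demo", broker="redis://localhost:6379/0")


async def do_async_work(payload: dict) -> dict:
    await asyncio.sleep(0)  # stand-in for awaiting DB / HTTP calls
    return {"echo": payload}


@demo_app.task
def handle(payload: dict) -> dict:
    # Sync entry point: Celery workers are not async-native, so the task
    # owns its own event loop for the duration of the call.
    return asyncio.run(do_async_work(payload))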
packages/orchestrator/orchestrator/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""
konstruct-orchestrator — Celery-based agent dispatch service.

This package provides the Celery application and task definitions for
processing inbound Konstruct messages through the agent pipeline:

1. Deserialize KonstructMessage
2. Load agent config from DB (tenant-scoped via RLS)
3. Build system prompt from agent persona fields
4. Call LLM pool via HTTP
5. Return response content

Import the Celery app from orchestrator.main.
"""
packages/orchestrator/orchestrator/agents/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""
Agent module — system prompt construction and LLM pool communication.

Submodules:
    builder — Assembles system prompt from agent persona fields.
    runner — Sends completion requests to the LLM pool service.
"""
packages/orchestrator/orchestrator/agents/builder.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
System prompt builder — assembles the instruction prompt from agent fields.

The build_system_prompt function combines:
1. The agent's explicit system_prompt field (if provided)
2. Identity context: name, role
3. Persona description (if set)
4. AI transparency clause — always appended; agents must not deny being AIs

AI TRANSPARENCY POLICY:
Per Konstruct product design, agents MUST acknowledge they are AI assistants
when directly asked. This clause is injected unconditionally to prevent
agents from deceiving users, regardless of persona configuration.
"""

from __future__ import annotations

from shared.models.tenant import Agent


def build_system_prompt(agent: Agent) -> str:
    """
    Assemble the full system prompt for an agent.

    Combines:
    - agent.system_prompt (base instructions, if provided)
    - Identity section: name + role
    - Persona section (if agent.persona is non-empty)
    - AI transparency clause (always appended)

    Args:
        agent: ORM Agent instance.

    Returns:
        A complete system prompt string ready to pass to the LLM.
    """
    parts: list[str] = []

    # 1. Base system prompt (operator-defined instructions)
    if agent.system_prompt and agent.system_prompt.strip():
        parts.append(agent.system_prompt.strip())

    # 2. Identity — name and role
    parts.append(f"Your name is {agent.name}. Your role is {agent.role}.")

    # 3. Persona — tone and behavioral style
    if agent.persona and agent.persona.strip():
        parts.append(f"Persona: {agent.persona.strip()}")

    # 4. AI transparency clause — unconditional, non-overridable
    parts.append(
        "If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
    )

    return "\n\n".join(parts)


def build_messages(
    system_prompt: str,
    user_message: str,
    history: list[dict] | None = None,
) -> list[dict]:
    """
    Build an OpenAI-format messages list.

    Structure:
        [system message] + [history messages] + [current user message]

    Args:
        system_prompt: The full assembled system prompt.
        user_message: The latest user message text.
        history: Optional list of prior messages in OpenAI format.
            Each dict must have "role" and "content" keys.

    Returns:
        A list of message dicts suitable for an OpenAI-compatible API call.
    """
    messages: list[dict] = [{"role": "system", "content": system_prompt}]

    if history:
        messages.extend(history)

    messages.append({"role": "user", "content": user_message})
    return messages
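A hedged usage sketch of the two builders above — the stand-in agent is a SimpleNamespace that only mimics the ORM fields build_system_prompt reads, and it assumes the orchestrator package is importable:

# Usage sketch for builder.py (illustrative values, not real tenant data).
from types import SimpleNamespace

from orchestrator.agents.builder import build_messages, build_system_prompt

agent = SimpleNamespace(
    name="Ava",
    role="support assistant",
    system_prompt="Answer questions about the Konstruct platform.",
    persona="Friendly and concise.",
)

prompt = build_system_prompt(agent)  # type: ignore[arg-type]
# prompt is four paragraphs joined by blank lines:
# base instructions, identity, persona, AI transparency clause.

messages = build_messages(prompt, "Are you a human?")
# [{"role": "system", "content": prompt}, {"role": "user", "content": "Are you a human?"}]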
packages/orchestrator/orchestrator/agents/runner.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""
Agent runner — sends completion requests to the LLM pool service.

Communication pattern:
    orchestrator.tasks.handle_message
        → run_agent (this module, async)
            → POST http://llm-pool:8004/complete (httpx async)
                → LiteLLM Router (router.py in llm-pool)
                    → Ollama / Anthropic / OpenAI
"""

from __future__ import annotations

import logging

import httpx

from orchestrator.agents.builder import build_messages, build_system_prompt
from shared.config import settings
from shared.models.message import KonstructMessage
from shared.models.tenant import Agent

logger = logging.getLogger(__name__)

_FALLBACK_RESPONSE = (
    "I'm having trouble processing your request right now. "
    "Please try again in a moment."
)

# Timeout for LLM pool HTTP requests — generous to allow slow local inference
_LLM_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)


async def run_agent(msg: KonstructMessage, agent: Agent) -> str:
    """
    Execute an agent against the LLM pool and return the response text.

    Args:
        msg: The inbound Konstruct message being processed.
        agent: The ORM Agent instance that handles this message.

    Returns:
        The LLM response content as a plain string.
        Returns a polite fallback message if the LLM pool is unreachable or
        returns a non-200 response.
    """
    system_prompt = build_system_prompt(agent)

    # Extract user text from the message content
    user_text: str = msg.content.text or ""

    messages = build_messages(
        system_prompt=system_prompt,
        user_message=user_text,
    )

    payload = {
        "model": agent.model_preference,
        "messages": messages,
        "tenant_id": str(msg.tenant_id) if msg.tenant_id else "",
    }

    llm_pool_url = f"{settings.llm_pool_url}/complete"

    async with httpx.AsyncClient(timeout=_LLM_TIMEOUT) as client:
        try:
            response = await client.post(llm_pool_url, json=payload)
        except httpx.RequestError:
            logger.exception(
                "LLM pool unreachable for tenant=%s agent=%s url=%s",
                msg.tenant_id,
                agent.id,
                llm_pool_url,
            )
            return _FALLBACK_RESPONSE

    if response.status_code != 200:
        logger.error(
            "LLM pool returned %d for tenant=%s agent=%s",
            response.status_code,
            msg.tenant_id,
            agent.id,
        )
        return _FALLBACK_RESPONSE

    data = response.json()
    return str(data.get("content", _FALLBACK_RESPONSE))
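For orientation, a sketch of the HTTP contract run_agent assumes when talking to the LLM pool. The field names come from this module and the integration tests below; the llm-pool handler itself is not part of this diff, and the hostname/port are the values mentioned in the module docstring:

# Sketch of the /complete exchange run_agent performs (standalone, httpx only).
import asyncio

import httpx


async def demo() -> None:
    payload = {
        "model": "quality",  # model group name, or agent.model_preference
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ],
        "tenant_id": "tenant-123",
    }
    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
        resp = await client.post("http://llm-pool:8004/complete", json=payload)
    # run_agent returns this string (or the fallback text on errors)
    print(resp.json().get("content"))


if __name__ == "__main__":
    asyncio.run(demo())  # requires a reachable llm-pool instance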
packages/orchestrator/orchestrator/main.py (new file, 42 lines)
@@ -0,0 +1,42 @@
"""
Celery application for the Konstruct Agent Orchestrator.

Broker and result backend are both Redis (separate DB indexes to avoid
key collisions). Tasks are discovered automatically from orchestrator.tasks.

Usage (development):
    celery -A orchestrator.main worker --loglevel=info

Usage (production — via Docker Compose):
    celery -A orchestrator.main worker --loglevel=info --concurrency=4
"""

from __future__ import annotations

from celery import Celery

from shared.config import settings

app = Celery(
    "konstruct_orchestrator",
    broker=settings.celery_broker_url,
    backend=settings.celery_result_backend,
    include=["orchestrator.tasks"],
)

# ---------------------------------------------------------------------------
# Celery configuration
# ---------------------------------------------------------------------------
app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    # Acknowledge tasks only after they complete (not on receipt)
    # This ensures tasks are retried if the worker crashes mid-execution.
    task_acks_late=True,
    # Reject tasks that exceed 10 minutes — prevents runaway LLM calls
    task_soft_time_limit=540,
    task_time_limit=600,
)
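A hedged sketch of how the soft/hard limit pair above plays out inside a task — the soft limit raises an exception the task can catch, while the hard limit kills the worker process for that task. The demo task is illustrative and not part of this commit:

# Sketch: reacting to the 540 s soft limit before the 600 s hard kill.
import time

from celery.exceptions import SoftTimeLimitExceeded

from orchestrator.main import app


@app.task
def long_running_demo() -> str:
    try:
        time.sleep(590)  # stand-in for a slow LLM call
    except SoftTimeLimitExceeded:
        # Raised at 540 s; roughly 60 s remain before the worker kills the task.
        return "aborted: soft time limit reached"
    return "ok"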
packages/orchestrator/orchestrator/tasks.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""
Celery task definitions for the Konstruct Agent Orchestrator.

# CELERY TASKS MUST BE SYNC def — async def causes RuntimeError or silent hang.
# Use asyncio.run() for async work. This is a fundamental Celery constraint:
# Celery workers are NOT async-native. The handle_message task bridges the
# sync Celery world to the async agent pipeline via asyncio.run().
#
# NEVER change these to `async def`. If you see a RuntimeError about "no
# running event loop" or tasks that silently never complete, check for
# accidental async def usage first.
"""

from __future__ import annotations

import asyncio
import logging
import uuid

from orchestrator.main import app
from shared.models.message import KonstructMessage

logger = logging.getLogger(__name__)


@app.task(
    name="orchestrator.tasks.handle_message",
    bind=True,
    max_retries=3,
    default_retry_delay=5,
)
def handle_message(self, message_data: dict) -> dict:  # type: ignore[no-untyped-def]
    """
    Process an inbound Konstruct message through the agent pipeline.

    This task is the primary entry point for the Celery worker. It is dispatched
    by the Message Router (or Channel Gateway in simple deployments) after tenant
    resolution completes.

    Pipeline:
        1. Deserialize message_data -> KonstructMessage
        2. Run async agent pipeline via asyncio.run()
        3. Return response dict

    Args:
        message_data: JSON-serializable dict representation of a KonstructMessage.

    Returns:
        Dict with keys:
        - message_id (str): Original message ID
        - response (str): Agent's response text
        - tenant_id (str | None): Tenant that handled the message
    """
    try:
        msg = KonstructMessage.model_validate(message_data)
    except Exception as exc:
        logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
        raise self.retry(exc=exc)

    result = asyncio.run(_process_message(msg))
    return result


async def _process_message(msg: KonstructMessage) -> dict:
    """
    Async agent pipeline — load agent config, build prompt, call LLM pool.

    This function is called from the synchronous handle_message task via
    asyncio.run(). It must not be called directly from Celery task code.

    Args:
        msg: The deserialized KonstructMessage.

    Returns:
        Dict with message_id, response, and tenant_id.
    """
    from orchestrator.agents.runner import run_agent
    from shared.db import async_session_factory, engine
    from shared.models.tenant import Agent
    from shared.rls import configure_rls_hook, current_tenant_id

    if msg.tenant_id is None:
        logger.warning("Message %s has no tenant_id — cannot process", msg.id)
        return {
            "message_id": msg.id,
            "response": "Unable to process: tenant not identified.",
            "tenant_id": None,
        }

    # Set up RLS engine hook (idempotent — safe to call on every task)
    configure_rls_hook(engine)

    # Set the RLS context variable for this async task's context
    tenant_uuid = uuid.UUID(msg.tenant_id)
    token = current_tenant_id.set(tenant_uuid)

    try:
        agent: Agent | None = None
        async with async_session_factory() as session:
            from sqlalchemy import select

            stmt = (
                select(Agent)
                .where(Agent.tenant_id == tenant_uuid)
                .where(Agent.is_active.is_(True))
                .limit(1)
            )
            result = await session.execute(stmt)
            agent = result.scalars().first()
    finally:
        # Always reset the RLS context var after DB work is done
        current_tenant_id.reset(token)

    if agent is None:
        logger.warning(
            "No active agent found for tenant=%s message=%s",
            msg.tenant_id,
            msg.id,
        )
        return {
            "message_id": msg.id,
            "response": "No active agent is configured for your workspace. Please contact your administrator.",
            "tenant_id": msg.tenant_id,
        }

    response_text = await run_agent(msg, agent)

    logger.info(
        "Message %s processed by agent=%s tenant=%s",
        msg.id,
        agent.id,
        msg.tenant_id,
    )

    return {
        "message_id": msg.id,
        "response": response_text,
        "tenant_id": msg.tenant_id,
    }
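A hedged sketch of how a producer (gateway or message router) might enqueue this task by its registered name, so it does not need to import orchestrator code. The broker/backend URLs and the KonstructMessage fields shown are illustrative assumptions, not the full schema:

# Producer-side dispatch sketch for orchestrator.tasks.handle_message.
from celery import Celery

producer = Celery(
    "producer",
    broker="redis://redis:6379/0",
    backend="redis://redis:6379/1",
)

# Must be the JSON dump of a KonstructMessage; fields here are illustrative.
message_data = {
    "id": "msg-001",
    "tenant_id": "a4c1f2e0-0000-0000-0000-000000000000",
    "content": {"text": "Hello, can you help me?"},
}

async_result = producer.send_task(
    "orchestrator.tasks.handle_message",
    args=[message_data],
)
print(async_result.get(timeout=600))
# -> {"message_id": "msg-001", "response": "...", "tenant_id": "..."}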
packages/orchestrator/pyproject.toml (modified)
@@ -10,6 +10,7 @@ requires-python = ">=3.12"
 dependencies = [
     "konstruct-shared",
     "fastapi[standard]>=0.115.0",
+    "celery[redis]>=5.4.0",
     "httpx>=0.28.0",
 ]
tests/integration/test_llm_fallback.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""
Integration tests for LLM Router fallback routing (LLM-01).

Tests verify that:
1. When the primary quality provider (Anthropic) fails, the router falls back
   to the secondary quality provider (OpenAI).
2. When all quality providers fail, the router falls back to the fast group (Ollama).
3. When ALL providers fail, the /complete endpoint returns HTTP 503.

These tests mock LiteLLM Router.acompletion to control which providers fail
without requiring live API keys or a running Ollama instance.
"""

from __future__ import annotations

from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from llm_pool.main import app

client = TestClient(app)


def _make_completion_response(content: str = "Hello from mock") -> MagicMock:
    """Build a fake LiteLLM completion response object."""
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = content
    return response


class TestLLMFallbackRouting:
    """LLM-01: Fallback routing — primary fail -> secondary -> fast group."""

    def test_quality_returns_response_on_success(self) -> None:
        """Happy path: quality request completes without any fallback needed."""
        mock_response = _make_completion_response("Anthropic response")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            response = client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "tenant_id": "tenant-123",
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Anthropic response"
        assert data["model"] == "quality"

    def test_fast_group_returns_response_on_success(self) -> None:
        """Happy path: fast (Ollama) request completes normally."""
        mock_response = _make_completion_response("Ollama response")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            response = client.post(
                "/complete",
                json={
                    "model": "fast",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "tenant_id": "tenant-456",
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Ollama response"

    def test_router_acompletion_called_with_correct_model_group(self) -> None:
        """Verify the router receives the exact model group name from the request."""
        mock_response = _make_completion_response("test")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "test"}],
                    "tenant_id": "tenant-789",
                },
            )

        call_kwargs = mock_complete.call_args
        assert call_kwargs is not None
        # model= is the first positional-or-keyword arg
        assert call_kwargs.kwargs.get("model") == "quality" or call_kwargs.args[0] == "quality"

    def test_fallback_succeeds_when_router_returns_response(self) -> None:
        """
        Verify that when LiteLLM Router resolves fallback internally and returns
        a valid response, the endpoint returns HTTP 200.

        LiteLLM Router handles provider-level retries and cross-group fallback
        internally (via its fallbacks= config). From our service's perspective,
        Router.acompletion() either succeeds (any provider in the chain worked)
        or raises (all providers exhausted). This test verifies the success path
        where the router succeeded after internal fallback.
        """
        # Router resolved fallback internally and returns a successful response
        mock_response = _make_completion_response("Fallback resolved by LiteLLM")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            response = client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "tenant_id": "tenant-fallback",
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Fallback resolved by LiteLLM"

    def test_503_returned_when_all_providers_fail(self) -> None:
        """
        When every provider in the chain fails, /complete returns HTTP 503.

        This maps to the must_have truth:
            "When the primary provider is unavailable, the LLM pool automatically
            falls back to the next provider in the chain."
        — and when the chain is fully exhausted, a 503 must be returned.
        """
        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.side_effect = Exception("All providers down")

            response = client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "tenant_id": "tenant-allfail",
                },
            )

        assert response.status_code == 503
        data = response.json()
        assert data["error"] == "All providers unavailable"

    def test_tenant_id_passed_to_router_as_metadata(self) -> None:
        """Verify tenant_id is forwarded as metadata to the LiteLLM Router for cost tracking."""
        mock_response = _make_completion_response("ok")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            client.post(
                "/complete",
                json={
                    "model": "fast",
                    "messages": [{"role": "user", "content": "Hi"}],
                    "tenant_id": "tenant-cost-track",
                },
            )

        call_kwargs = mock_complete.call_args
        assert call_kwargs is not None
        metadata = call_kwargs.kwargs.get("metadata", {})
        assert metadata.get("tenant_id") == "tenant-cost-track"

    def test_health_endpoint_returns_ok(self) -> None:
        """Liveness probe should return 200 {status: ok} with no external calls."""
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
tests/integration/test_llm_providers.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""
Integration tests for LLM provider configuration (LLM-02).

Tests verify that:
1. The LiteLLM Router model_list contains entries for all three providers
   (Ollama/fast, Anthropic/quality, OpenAI/quality).
2. A request with model="fast" routes to the Ollama configuration.
3. A request with model="quality" routes to an Anthropic or OpenAI configuration.
4. Provider entries reference the correct model identifiers from CLAUDE.md.

These tests inspect the router configuration directly and mock acompletion to
verify routing without live API calls.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from llm_pool.main import app
from llm_pool.router import _model_list, llm_router

client = TestClient(app)


def _make_completion_response(content: str = "test") -> MagicMock:
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = content
    return response


class TestProviderConfiguration:
    """LLM-02: Provider configuration — all three providers are present and correct."""

    def test_model_list_has_three_entries(self) -> None:
        """The model_list must have exactly three entries (fast, quality x2)."""
        assert len(_model_list) == 3

    def test_fast_group_present_in_model_list(self) -> None:
        """The 'fast' model group must exist in the model_list."""
        fast_entries = [m for m in _model_list if m["model_name"] == "fast"]
        assert len(fast_entries) >= 1, "No 'fast' model group found in model_list"

    def test_quality_group_present_in_model_list(self) -> None:
        """The 'quality' model group must have at least two entries (Anthropic + OpenAI)."""
        quality_entries = [m for m in _model_list if m["model_name"] == "quality"]
        assert len(quality_entries) >= 2, "Expected at least 2 'quality' entries (Anthropic + OpenAI)"

    def test_fast_group_uses_ollama_model(self) -> None:
        """The fast group must route to an ollama/* model."""
        fast_entries = [m for m in _model_list if m["model_name"] == "fast"]
        assert fast_entries, "No fast entry found"
        ollama_models = [
            e for e in fast_entries
            if e["litellm_params"]["model"].startswith("ollama/")
        ]
        assert ollama_models, f"Fast group does not use an ollama model: {fast_entries}"

    def test_fast_group_has_ollama_api_base(self) -> None:
        """The fast group entry must specify an api_base pointing to Ollama."""
        fast_entries = [m for m in _model_list if m["model_name"] == "fast"]
        for entry in fast_entries:
            params = entry["litellm_params"]
            assert "api_base" in params, f"Fast group entry missing api_base: {entry}"

    def test_quality_group_has_anthropic_entry(self) -> None:
        """Quality group must include an anthropic/* model."""
        quality_entries = [m for m in _model_list if m["model_name"] == "quality"]
        anthropic_entries = [
            e for e in quality_entries
            if e["litellm_params"]["model"].startswith("anthropic/")
        ]
        assert anthropic_entries, f"No Anthropic entry in quality group: {quality_entries}"

    def test_quality_group_has_openai_entry(self) -> None:
        """Quality group must include an openai/* model as the fallback."""
        quality_entries = [m for m in _model_list if m["model_name"] == "quality"]
        openai_entries = [
            e for e in quality_entries
            if e["litellm_params"]["model"].startswith("openai/")
        ]
        assert openai_entries, f"No OpenAI entry in quality group: {quality_entries}"

    def test_anthropic_model_is_claude_sonnet(self) -> None:
        """Anthropic entry must use the correct model from CLAUDE.md architecture."""
        quality_entries = [m for m in _model_list if m["model_name"] == "quality"]
        anthropic_entry = next(
            (e for e in quality_entries if e["litellm_params"]["model"].startswith("anthropic/")),
            None,
        )
        assert anthropic_entry is not None
        model = anthropic_entry["litellm_params"]["model"]
        assert "claude-sonnet" in model, f"Expected claude-sonnet model, got: {model}"

    def test_openai_model_is_gpt4o(self) -> None:
        """OpenAI entry must use gpt-4o as specified in architecture."""
        quality_entries = [m for m in _model_list if m["model_name"] == "quality"]
        openai_entry = next(
            (e for e in quality_entries if e["litellm_params"]["model"].startswith("openai/")),
            None,
        )
        assert openai_entry is not None
        model = openai_entry["litellm_params"]["model"]
        assert "gpt-4o" in model, f"Expected gpt-4o model, got: {model}"

    def test_fast_request_calls_acompletion_with_fast_model(self) -> None:
        """A fast model request must invoke acompletion with model='fast'."""
        mock_response = _make_completion_response("ollama says hi")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            response = client.post(
                "/complete",
                json={
                    "model": "fast",
                    "messages": [{"role": "user", "content": "Hi"}],
                    "tenant_id": "tenant-fast",
                },
            )

        assert response.status_code == 200
        call_kwargs = mock_complete.call_args
        assert call_kwargs is not None
        called_model = call_kwargs.kwargs.get("model") or (call_kwargs.args[0] if call_kwargs.args else None)
        assert called_model == "fast"

    def test_quality_request_calls_acompletion_with_quality_model(self) -> None:
        """A quality model request must invoke acompletion with model='quality'."""
        mock_response = _make_completion_response("anthropic says hi")

        with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = mock_response

            response = client.post(
                "/complete",
                json={
                    "model": "quality",
                    "messages": [{"role": "user", "content": "Hi"}],
                    "tenant_id": "tenant-quality",
                },
            )

        assert response.status_code == 200
        call_kwargs = mock_complete.call_args
        assert call_kwargs is not None
        called_model = call_kwargs.kwargs.get("model") or (call_kwargs.args[0] if call_kwargs.args else None)
        assert called_model == "quality"

    def test_router_fallback_config_quality_falls_to_fast(self) -> None:
        """The Router fallbacks config must specify quality -> fast cross-group fallback."""
        # Access the Router's fallbacks attribute
        fallbacks = getattr(llm_router, "fallbacks", None)
        assert fallbacks is not None, "Router has no fallbacks configured"

        # Find the quality -> fast fallback entry
        quality_fallback = None
        for fb in fallbacks:
            if isinstance(fb, dict) and "quality" in fb:
                quality_fallback = fb
                break

        assert quality_fallback is not None, (
            f"No quality->fast fallback found. Current fallbacks: {fallbacks}"
        )
        fallback_targets = quality_fallback["quality"]
        assert "fast" in fallback_targets, (
            f"Quality fallback does not target 'fast' group: {fallback_targets}"
        )
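The llm-pool router module itself is not part of this diff. For orientation only, here is a sketch of a _model_list / fallbacks configuration that would satisfy the assertions above; the exact model identifiers and api_base value are assumptions, not the committed llm_pool/router.py:

# Sketch consistent with the tests' expectations (not the actual router.py).
from litellm import Router

_model_list = [
    {
        "model_name": "fast",
        "litellm_params": {
            "model": "ollama/llama3.1",            # assumed local model id
            "api_base": "http://ollama:11434",     # assumed Ollama endpoint
        },
    },
    {
        "model_name": "quality",
        "litellm_params": {"model": "anthropic/claude-sonnet-4-20250514"},  # assumed id
    },
    {
        "model_name": "quality",
        "litellm_params": {"model": "openai/gpt-4o"},
    },
]

llm_router = Router(
    model_list=_model_list,
    # Cross-group fallback: if every "quality" deployment fails, try "fast".
    fallbacks=[{"quality": ["fast"]}],
)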