From 8257c554d745820e28a7b59bc6e58447e7f1159a Mon Sep 17 00:00:00 2001 From: Adolfo Delorenzo Date: Mon, 23 Mar 2026 10:06:44 -0600 Subject: [PATCH] =?UTF-8?q?feat(01-02):=20Celery=20orchestrator=20?= =?UTF-8?q?=E2=80=94=20handle=5Fmessage=20task,=20system=20prompt=20builde?= =?UTF-8?q?r,=20LLM=20pool=20runner?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create orchestrator/main.py: Celery app with Redis broker/backend, task_acks_late=True, 10-min timeout - Create orchestrator/tasks.py: SYNC def handle_message (critical pattern: asyncio.run for async work) - Deserializes KonstructMessage, sets RLS context, loads agent from DB, calls run_agent - Retries up to 3x on deserialization failure - Create orchestrator/agents/builder.py: build_system_prompt assembles system_prompt + identity + persona + AI transparency clause - Create orchestrator/agents/runner.py: run_agent posts to llm-pool /complete via httpx, returns polite fallback on error - Add Celery[redis] dependency to orchestrator pyproject.toml - Create tests/integration/test_llm_fallback.py: 7 tests for fallback routing and 503 on total failure (LLM-01) - Create tests/integration/test_llm_providers.py: 12 tests verifying all three providers configured correctly (LLM-02) - All 19 integration tests pass --- .../orchestrator/orchestrator/__init__.py | 14 ++ .../orchestrator/agents/__init__.py | 7 + .../orchestrator/agents/builder.py | 84 ++++++++ .../orchestrator/agents/runner.py | 87 +++++++++ packages/orchestrator/orchestrator/main.py | 42 ++++ packages/orchestrator/orchestrator/tasks.py | 139 ++++++++++++++ packages/orchestrator/pyproject.toml | 1 + tests/integration/test_llm_fallback.py | 180 ++++++++++++++++++ tests/integration/test_llm_providers.py | 172 +++++++++++++++++ 9 files changed, 726 insertions(+) create mode 100644 packages/orchestrator/orchestrator/__init__.py create mode 100644 packages/orchestrator/orchestrator/agents/__init__.py create mode 100644 
packages/orchestrator/orchestrator/agents/builder.py create mode 100644 packages/orchestrator/orchestrator/agents/runner.py create mode 100644 packages/orchestrator/orchestrator/main.py create mode 100644 packages/orchestrator/orchestrator/tasks.py create mode 100644 tests/integration/test_llm_fallback.py create mode 100644 tests/integration/test_llm_providers.py diff --git a/packages/orchestrator/orchestrator/__init__.py b/packages/orchestrator/orchestrator/__init__.py new file mode 100644 index 0000000..0e74f4c --- /dev/null +++ b/packages/orchestrator/orchestrator/__init__.py @@ -0,0 +1,14 @@ +""" +konstruct-orchestrator — Celery-based agent dispatch service. + +This package provides the Celery application and task definitions for +processing inbound Konstruct messages through the agent pipeline: + + 1. Deserialize KonstructMessage + 2. Load agent config from DB (tenant-scoped via RLS) + 3. Build system prompt from agent persona fields + 4. Call LLM pool via HTTP + 5. Return response content + +Import the Celery app from orchestrator.main. +""" diff --git a/packages/orchestrator/orchestrator/agents/__init__.py b/packages/orchestrator/orchestrator/agents/__init__.py new file mode 100644 index 0000000..e65099f --- /dev/null +++ b/packages/orchestrator/orchestrator/agents/__init__.py @@ -0,0 +1,7 @@ +""" +Agent module — system prompt construction and LLM pool communication. + +Submodules: + builder — Assembles system prompt from agent persona fields. + runner — Sends completion requests to the LLM pool service. +""" diff --git a/packages/orchestrator/orchestrator/agents/builder.py b/packages/orchestrator/orchestrator/agents/builder.py new file mode 100644 index 0000000..3cec6e6 --- /dev/null +++ b/packages/orchestrator/orchestrator/agents/builder.py @@ -0,0 +1,84 @@ +""" +System prompt builder — assembles the instruction prompt from agent fields. + +The build_system_prompt function combines: + 1. The agent's explicit system_prompt field (if provided) + 2. 
def build_system_prompt(agent: Agent) -> str:
    """
    Compose the complete system prompt for one agent.

    Sections are emitted in a fixed order and separated by a blank line:
      - the operator-written ``agent.system_prompt`` (skipped when empty or
        whitespace-only)
      - an identity sentence built from ``agent.name`` and ``agent.role``
      - a ``Persona:`` line (skipped when ``agent.persona`` is empty or
        whitespace-only)
      - the AI transparency clause, which is always present

    Args:
        agent: ORM Agent instance; only ``system_prompt``, ``name``, ``role``
            and ``persona`` are read.

    Returns:
        The assembled prompt text, ready to be sent as the system message.
    """
    base_instructions = (agent.system_prompt or "").strip()
    persona_text = (agent.persona or "").strip()

    sections = [
        # Operator-defined instructions lead; the later sections refine them.
        base_instructions,
        f"Your name is {agent.name}. Your role is {agent.role}.",
        f"Persona: {persona_text}" if persona_text else "",
        # Transparency clause is unconditional — no persona setting removes it.
        "If asked directly whether you are an AI, always respond honestly that you are an AI assistant.",
    ]

    # Empty strings mark sections that were skipped above.
    return "\n\n".join(section for section in sections if section)
async def run_agent(msg: KonstructMessage, agent: Agent) -> str:
    """
    Execute an agent against the LLM pool and return the response text.

    Builds the system prompt and message list for ``agent``, POSTs them to the
    LLM pool's ``/complete`` endpoint, and extracts the ``content`` field of
    the JSON reply.

    Args:
        msg: The inbound Konstruct message being processed. ``msg.content.text``
            supplies the user turn (empty string when missing) and
            ``msg.tenant_id`` is forwarded in the payload.
        agent: The ORM Agent instance that handles this message.
            ``agent.model_preference`` is sent as the requested model.

    Returns:
        The LLM response content as a plain string. The polite fallback
        message is returned instead when:
          - the request fails at the transport level (``httpx.RequestError``),
          - the pool answers with a non-200 status,
          - the body is not valid JSON or is not a JSON object, or
          - the ``content`` field is missing or null.
    """
    system_prompt = build_system_prompt(agent)

    # Extract user text from the message content
    user_text: str = msg.content.text or ""

    messages = build_messages(
        system_prompt=system_prompt,
        user_message=user_text,
    )

    payload = {
        "model": agent.model_preference,
        "messages": messages,
        "tenant_id": str(msg.tenant_id) if msg.tenant_id else "",
    }

    # Normalise the base URL so a configured trailing slash cannot yield
    # ".../​/complete"; str() keeps this working if the setting is a URL object.
    llm_pool_url = f"{str(settings.llm_pool_url).rstrip('/')}/complete"

    async with httpx.AsyncClient(timeout=_LLM_TIMEOUT) as client:
        try:
            response = await client.post(llm_pool_url, json=payload)
        except httpx.RequestError:
            logger.exception(
                "LLM pool unreachable for tenant=%s agent=%s url=%s",
                msg.tenant_id,
                agent.id,
                llm_pool_url,
            )
            return _FALLBACK_RESPONSE

    if response.status_code != 200:
        logger.error(
            "LLM pool returned %d for tenant=%s agent=%s",
            response.status_code,
            msg.tenant_id,
            agent.id,
        )
        return _FALLBACK_RESPONSE

    # A 200 does not guarantee a JSON body (e.g. a proxy error page). JSON and
    # text-decoding failures are both ValueError subclasses.
    try:
        data = response.json()
    except ValueError:
        logger.exception(
            "LLM pool returned a non-JSON body for tenant=%s agent=%s",
            msg.tenant_id,
            agent.id,
        )
        return _FALLBACK_RESPONSE

    # Guard against a non-object body and against "content": null, which
    # would otherwise reach the user as the literal text "None".
    content = data.get("content") if isinstance(data, dict) else None
    if content is None:
        logger.error(
            "LLM pool response has no usable content for tenant=%s agent=%s",
            msg.tenant_id,
            agent.id,
        )
        return _FALLBACK_RESPONSE

    return str(content)
# The orchestrator's single Celery application. Broker and result-backend URLs
# are read from shared settings; the task module is registered explicitly via
# ``include`` so the worker imports it at start-up.
app = Celery(
    "konstruct_orchestrator",
    include=["orchestrator.tasks"],
    backend=settings.celery_result_backend,
    broker=settings.celery_broker_url,
)

# ---------------------------------------------------------------------------
# Celery configuration
# ---------------------------------------------------------------------------
app.conf.update(
    {
        # JSON only, in both directions — no pickle payloads are accepted.
        "task_serializer": "json",
        "result_serializer": "json",
        "accept_content": ["json"],
        # All timestamps in UTC.
        "enable_utc": True,
        "timezone": "UTC",
        # Acknowledge a message after its task has run instead of on receipt.
        # NOTE(review): whether a task is redelivered when the worker process
        # executing it is killed also depends on task_reject_on_worker_lost,
        # which is not set here — confirm the intended crash behaviour.
        "task_acks_late": True,
        # Runaway-LLM-call guard: soft limit at 9 minutes, hard limit at 10.
        "task_soft_time_limit": 540,
        "task_time_limit": 600,
    }
)
@app.task(
    name="orchestrator.tasks.handle_message",
    bind=True,
    max_retries=3,
    default_retry_delay=5,
)
def handle_message(self, message_data: dict) -> dict:  # type: ignore[no-untyped-def]
    """
    Process an inbound Konstruct message through the agent pipeline.

    This is a synchronous Celery task on purpose: it validates the payload and
    then drives the async pipeline to completion with ``asyncio.run()``.

    Pipeline:
        1. Deserialize message_data -> KonstructMessage
        2. Run async agent pipeline via asyncio.run()
        3. Return response dict

    Args:
        message_data: JSON-serializable dict representation of a KonstructMessage.

    Returns:
        Dict with keys:
            - message_id (str): Original message ID
            - response (str): Agent's response text
            - tenant_id (str | None): Tenant that handled the message

    Raises:
        celery.exceptions.Retry: When deserialization fails and retries remain
            (up to ``max_retries=3``, 5 seconds apart).
        Exception: The original deserialization error once retries are exhausted.
    """
    try:
        msg = KonstructMessage.model_validate(message_data)
    except Exception as exc:
        # NOTE(review): validating the same payload again cannot succeed unless
        # the schema changes between attempts — confirm the retry is intended.
        logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
        raise self.retry(exc=exc)

    # Sync -> async bridge: a fresh event loop is created and closed per task.
    return asyncio.run(_process_message(msg))


async def _process_message(msg: KonstructMessage) -> dict:
    """
    Async agent pipeline — load agent config, build prompt, call LLM pool.

    Called only from the synchronous handle_message task via asyncio.run().

    Steps:
        1. Reject messages without a usable tenant id (missing or not a UUID).
        2. Set the RLS tenant context and load the tenant's first active agent.
        3. Reset the RLS context and dispose the engine's connection pool.
        4. Run the agent against the LLM pool and return its reply.

    Args:
        msg: The deserialized KonstructMessage.

    Returns:
        Dict with message_id, response, and tenant_id. ``tenant_id`` is None
        when the tenant could not be identified.
    """
    from sqlalchemy import select

    from orchestrator.agents.runner import run_agent
    from shared.db import async_session_factory, engine
    from shared.models.tenant import Agent
    from shared.rls import configure_rls_hook, current_tenant_id

    if msg.tenant_id is None:
        logger.warning("Message %s has no tenant_id — cannot process", msg.id)
        return {
            "message_id": msg.id,
            "response": "Unable to process: tenant not identified.",
            "tenant_id": None,
        }

    # A malformed tenant id is a data problem, not a crash: answer the same
    # way as a missing one instead of failing the task with ValueError.
    # str() lets an already-parsed UUID value through unchanged.
    try:
        tenant_uuid = uuid.UUID(str(msg.tenant_id))
    except ValueError:
        logger.warning(
            "Message %s has a malformed tenant_id=%r — cannot process",
            msg.id,
            msg.tenant_id,
        )
        return {
            "message_id": msg.id,
            "response": "Unable to process: tenant not identified.",
            "tenant_id": None,
        }

    # Set up RLS engine hook (idempotent — safe to call on every task)
    configure_rls_hook(engine)

    # Set the RLS context variable for this async task's context
    token = current_tenant_id.set(tenant_uuid)

    try:
        agent: Agent | None = None
        async with async_session_factory() as session:
            stmt = (
                select(Agent)
                .where(Agent.tenant_id == tenant_uuid)
                .where(Agent.is_active.is_(True))
                .limit(1)
            )
            result = await session.execute(stmt)
            agent = result.scalars().first()
    finally:
        # Always reset the RLS context var after DB work is done
        current_tenant_id.reset(token)
        # Every task runs under its own asyncio.run() loop. Pooled connections
        # stay bound to the loop that opened them, so drop them before this
        # loop closes rather than handing them to the next task's loop.
        # Assumes shared.db.engine is an AsyncEngine — TODO confirm.
        await engine.dispose()

    if agent is None:
        logger.warning(
            "No active agent found for tenant=%s message=%s",
            msg.tenant_id,
            msg.id,
        )
        return {
            "message_id": msg.id,
            "response": "No active agent is configured for your workspace. Please contact your administrator.",
            "tenant_id": msg.tenant_id,
        }

    response_text = await run_agent(msg, agent)

    logger.info(
        "Message %s processed by agent=%s tenant=%s",
        msg.id,
        agent.id,
        msg.tenant_id,
    )

    return {
        "message_id": msg.id,
        "response": response_text,
        "tenant_id": msg.tenant_id,
    }
# Patch target shared by every test: the router's async completion method.
_ROUTER_ACOMPLETION = "llm_pool.router.llm_router.acompletion"


def _make_completion_response(content: str = "Hello from mock") -> MagicMock:
    """Build a fake LiteLLM completion response object."""
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = content
    return response


def _post_complete(model: str, tenant_id: str, content: str = "Hello") -> Any:
    """POST a single-user-message completion request to /complete."""
    return client.post(
        "/complete",
        json={
            "model": model,
            "messages": [{"role": "user", "content": content}],
            "tenant_id": tenant_id,
        },
    )


class TestLLMFallbackRouting:
    """LLM-01: Fallback routing — primary fail -> secondary -> fast group."""

    def test_quality_returns_response_on_success(self) -> None:
        """Happy path: quality request completes without any fallback needed."""
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = _make_completion_response("Anthropic response")
            response = _post_complete("quality", "tenant-123")

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Anthropic response"
        assert data["model"] == "quality"

    def test_fast_group_returns_response_on_success(self) -> None:
        """Happy path: fast (Ollama) request completes normally."""
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = _make_completion_response("Ollama response")
            response = _post_complete("fast", "tenant-456")

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Ollama response"

    def test_router_acompletion_called_with_correct_model_group(self) -> None:
        """Verify the router receives the exact model group name from the request."""
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = _make_completion_response("test")
            _post_complete("quality", "tenant-789", content="test")

        call = mock_complete.call_args
        assert call is not None
        # model= may arrive as a keyword or as the first positional argument.
        # Guard the positional lookup so a missing model fails the assertion
        # below instead of raising IndexError.
        called_model = call.kwargs.get("model") or (call.args[0] if call.args else None)
        assert called_model == "quality"

    def test_fallback_succeeds_when_router_returns_response(self) -> None:
        """
        Verify that when LiteLLM Router resolves fallback internally and returns
        a valid response, the endpoint returns HTTP 200.

        LiteLLM Router handles provider-level retries and cross-group fallback
        internally (via its fallbacks= config). From our service's perspective,
        Router.acompletion() either succeeds (any provider in the chain worked)
        or raises (all providers exhausted). This test verifies the success path
        where the router succeeded after internal fallback.
        """
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            # Router resolved fallback internally and returns a successful response
            mock_complete.return_value = _make_completion_response("Fallback resolved by LiteLLM")
            response = _post_complete("quality", "tenant-fallback")

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Fallback resolved by LiteLLM"

    def test_503_returned_when_all_providers_fail(self) -> None:
        """
        When every provider in the chain fails, /complete returns HTTP 503.

        This maps to the must_have truth:
          "When the primary provider is unavailable, the LLM pool automatically
           falls back to the next provider in the chain."
        — and when the chain is fully exhausted, a 503 must be returned.
        """
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            mock_complete.side_effect = Exception("All providers down")
            response = _post_complete("quality", "tenant-allfail")

        assert response.status_code == 503
        data = response.json()
        assert data["error"] == "All providers unavailable"

    def test_tenant_id_passed_to_router_as_metadata(self) -> None:
        """Verify tenant_id is forwarded as metadata to the LiteLLM Router for cost tracking."""
        with patch(_ROUTER_ACOMPLETION, new_callable=AsyncMock) as mock_complete:
            mock_complete.return_value = _make_completion_response("ok")
            _post_complete("fast", "tenant-cost-track", content="Hi")

        call = mock_complete.call_args
        assert call is not None
        metadata = call.kwargs.get("metadata", {})
        assert metadata.get("tenant_id") == "tenant-cost-track"

    def test_health_endpoint_returns_ok(self) -> None:
        """Liveness probe should return 200 {status: ok} with no external calls."""
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
def _make_completion_response(content: str = "test") -> MagicMock:
    """Build a fake LiteLLM completion response object."""
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = content
    return response


def _group_entries(model_name: str) -> list[dict]:
    """Return the model_list entries registered under one model group name."""
    return [m for m in _model_list if m["model_name"] == model_name]


def _provider_entries(model_name: str, prefix: str) -> list[dict]:
    """Return a group's entries whose LiteLLM model id starts with ``prefix``."""
    return [
        e for e in _group_entries(model_name)
        if e["litellm_params"]["model"].startswith(prefix)
    ]


def _called_model_for(model: str, tenant_id: str, content: str) -> object:
    """POST to /complete with acompletion mocked; return the model it was called with."""
    with patch("llm_pool.router.llm_router.acompletion", new_callable=AsyncMock) as mock_complete:
        mock_complete.return_value = _make_completion_response(content)

        response = client.post(
            "/complete",
            json={
                "model": model,
                "messages": [{"role": "user", "content": "Hi"}],
                "tenant_id": tenant_id,
            },
        )

    assert response.status_code == 200
    call_kwargs = mock_complete.call_args
    assert call_kwargs is not None
    # model may be passed by keyword or as the first positional argument.
    return call_kwargs.kwargs.get("model") or (call_kwargs.args[0] if call_kwargs.args else None)


class TestProviderConfiguration:
    """LLM-02: Provider configuration — all three providers are present and correct."""

    def test_model_list_has_three_entries(self) -> None:
        """The model_list must have exactly three entries (fast, quality x2)."""
        assert len(_model_list) == 3

    def test_fast_group_present_in_model_list(self) -> None:
        """The 'fast' model group must exist in the model_list."""
        fast_entries = _group_entries("fast")
        assert len(fast_entries) >= 1, "No 'fast' model group found in model_list"

    def test_quality_group_present_in_model_list(self) -> None:
        """The 'quality' model group must have at least two entries (Anthropic + OpenAI)."""
        quality_entries = _group_entries("quality")
        assert len(quality_entries) >= 2, "Expected at least 2 'quality' entries (Anthropic + OpenAI)"

    def test_fast_group_uses_ollama_model(self) -> None:
        """The fast group must route to an ollama/* model."""
        fast_entries = _group_entries("fast")
        assert fast_entries, "No fast entry found"
        ollama_models = _provider_entries("fast", "ollama/")
        assert ollama_models, f"Fast group does not use an ollama model: {fast_entries}"

    def test_fast_group_has_ollama_api_base(self) -> None:
        """The fast group entry must specify an api_base pointing to Ollama."""
        fast_entries = _group_entries("fast")
        # Without this guard the loop below checks nothing and the test passes
        # vacuously when the fast group is missing.
        assert fast_entries, "No fast entry found"
        for entry in fast_entries:
            params = entry["litellm_params"]
            assert "api_base" in params, f"Fast group entry missing api_base: {entry}"

    def test_quality_group_has_anthropic_entry(self) -> None:
        """Quality group must include an anthropic/* model."""
        quality_entries = _group_entries("quality")
        anthropic_entries = _provider_entries("quality", "anthropic/")
        assert anthropic_entries, f"No Anthropic entry in quality group: {quality_entries}"

    def test_quality_group_has_openai_entry(self) -> None:
        """Quality group must include an openai/* model as the fallback."""
        quality_entries = _group_entries("quality")
        openai_entries = _provider_entries("quality", "openai/")
        assert openai_entries, f"No OpenAI entry in quality group: {quality_entries}"

    def test_anthropic_model_is_claude_sonnet(self) -> None:
        """Anthropic entry must use the correct model from CLAUDE.md architecture."""
        anthropic_entry = next(iter(_provider_entries("quality", "anthropic/")), None)
        assert anthropic_entry is not None
        model = anthropic_entry["litellm_params"]["model"]
        assert "claude-sonnet" in model, f"Expected claude-sonnet model, got: {model}"

    def test_openai_model_is_gpt4o(self) -> None:
        """OpenAI entry must use gpt-4o as specified in architecture."""
        openai_entry = next(iter(_provider_entries("quality", "openai/")), None)
        assert openai_entry is not None
        model = openai_entry["litellm_params"]["model"]
        assert "gpt-4o" in model, f"Expected gpt-4o model, got: {model}"

    def test_fast_request_calls_acompletion_with_fast_model(self) -> None:
        """A fast model request must invoke acompletion with model='fast'."""
        called_model = _called_model_for("fast", "tenant-fast", "ollama says hi")
        assert called_model == "fast"

    def test_quality_request_calls_acompletion_with_quality_model(self) -> None:
        """A quality model request must invoke acompletion with model='quality'."""
        called_model = _called_model_for("quality", "tenant-quality", "anthropic says hi")
        assert called_model == "quality"

    def test_router_fallback_config_quality_falls_to_fast(self) -> None:
        """The Router fallbacks config must specify quality -> fast cross-group fallback."""
        # Access the Router's fallbacks attribute
        fallbacks = getattr(llm_router, "fallbacks", None)
        assert fallbacks is not None, "Router has no fallbacks configured"

        # Find the quality -> fast fallback entry
        quality_fallback = next(
            (fb for fb in fallbacks if isinstance(fb, dict) and "quality" in fb),
            None,
        )

        assert quality_fallback is not None, (
            f"No quality->fast fallback found. Current fallbacks: {fallbacks}"
        )
        fallback_targets = quality_fallback["quality"]
        assert "fast" in fallback_targets, (
            f"Quality fallback does not target 'fast' group: {fallback_targets}"
        )