Files
konstruct/packages/llm-pool/llm_pool/router.py
Adolfo Delorenzo f3e358b418 feat(streaming): add complete_stream() generator and POST /complete/stream NDJSON endpoint to llm-pool
- complete_stream() in router.py yields token strings via acompletion(stream=True)
- POST /complete/stream returns NDJSON: chunk lines then a done line
- Streaming path does not support tool calls (plain text only)
- Non-streaming POST /complete endpoint unchanged
2026-03-25 17:56:56 -06:00

222 lines
7.0 KiB
Python

"""
LiteLLM Router — multi-provider LLM backend with automatic fallback.
Provider groups:
"fast" → Ollama (local, low-latency, no cost)
"quality" → Anthropic claude-sonnet-4 (primary), OpenAI gpt-4o (fallback)
Fallback chain:
quality providers fail → fall back to "fast" group
NOTE: LiteLLM is pinned to ==1.82.5 in pyproject.toml.
Do NOT upgrade without testing — a September 2025 OOM regression affects
releases 1.83.x and later.
"""
from __future__ import annotations

import logging
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

from litellm import Router

from shared.config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Model list — Ollama-backed groups plus commercial deployments
# ---------------------------------------------------------------------------
def _ollama_entry(model_name: str) -> dict:
    """Build a model-list entry backed by the local Ollama instance.

    Args:
        model_name: LiteLLM group name this deployment is registered under.

    Returns:
        A LiteLLM model-list entry dict pointing at the configured Ollama
        model and base URL.
    """
    return {
        "model_name": model_name,
        "litellm_params": {
            "model": f"ollama/{settings.ollama_model}",
            "api_base": settings.ollama_base_url,
        },
    }


_model_list: list[dict] = [
    # "local", "fast", and "economy" are aliases for the same local Ollama
    # deployment (no API cost); "balanced" uses Ollama as its primary and
    # also gets a commercial entry below. Order matches the original list.
    *(_ollama_entry(name) for name in ("local", "fast", "economy", "balanced")),
    # ── balanced group — commercial deployment (Ollama entry above is primary) ──
    {
        "model_name": "balanced",
        "litellm_params": {
            "model": "anthropic/claude-sonnet-4-20250514",
            "api_key": settings.anthropic_api_key,
        },
    },
    # ── quality group — Anthropic primary, OpenAI fallback ──
    {
        "model_name": "quality",
        "litellm_params": {
            "model": "anthropic/claude-sonnet-4-20250514",
            "api_key": settings.anthropic_api_key,
        },
    },
    {
        "model_name": "quality",
        "litellm_params": {
            "model": "openai/gpt-4o",
            "api_key": settings.openai_api_key,
        },
    },
]
# ---------------------------------------------------------------------------
# Router — latency-based, 2 retries per provider, then cross-group fallback
# ---------------------------------------------------------------------------
llm_router = Router(
    model_list=_model_list,
    # If all quality providers fail, fall back to the fast group
    # (same cross-group fallback for "balanced").
    fallbacks=[{"quality": ["fast"]}, {"balanced": ["fast"]}],
    # Within a group, prefer the deployment with the lowest observed latency.
    routing_strategy="latency-based-routing",
    # Per-deployment retries before moving on to the next deployment/fallback.
    num_retries=2,
    # Keep LiteLLM's own debug logging off; this module logs via `logger`.
    set_verbose=False,
)
@dataclass
class LLMResponse:
    """
    Container for LLM completion response.

    A plain data holder, so it is expressed as a dataclass: the generated
    ``__init__`` accepts the same ``(content, tool_calls)`` positional/keyword
    arguments as before, and we gain ``__repr__`` and ``__eq__`` for free.

    Attributes:
        content: Text content of the response (empty string if tool_calls present).
        tool_calls: List of tool call dicts in OpenAI format, or empty list.
    """

    # Text content of the response; "" when the model returned tool calls.
    content: str
    # Tool call dicts in OpenAI format; [] when the model returned text.
    tool_calls: list[dict[str, Any]]
async def complete(
    model_group: str,
    messages: list[dict],
    tenant_id: str,
    tools: list[dict] | None = None,
) -> LLMResponse:
    """
    Request a completion from the LiteLLM Router.

    Args:
        model_group: "quality" or "fast" — selects the provider group.
        messages: OpenAI-format message list, e.g.
            [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
        tenant_id: Konstruct tenant UUID, attached to LiteLLM metadata for
            per-tenant cost tracking.
        tools: Optional list of OpenAI function-calling tool dicts. When provided,
            the LLM may return tool_calls instead of text content.

    Returns:
        LLMResponse with content (text) and tool_calls (list of tool call dicts).
        - If LLM returns text: content is non-empty, tool_calls is empty.
        - If LLM returns tool calls: content is empty, tool_calls contains calls.

    Raises:
        Exception: Propagated if all providers in the group (and fallbacks) fail.
    """
    logger.info("LLM request", extra={"model_group": model_group, "tenant_id": tenant_id})

    request: dict[str, Any] = {
        "model": model_group,
        "messages": messages,
        "metadata": {"tenant_id": tenant_id},
    }
    if tools:
        request["tools"] = tools

    result = await llm_router.acompletion(**request)
    message = result.choices[0].message

    # Normalize tool calls to plain OpenAI-format dicts.
    parsed_calls: list[dict[str, Any]] = []
    for call in getattr(message, "tool_calls", None) or []:
        if isinstance(call, dict):
            # Some providers already hand back plain dicts — pass through as-is.
            parsed_calls.append(call)
            continue
        try:
            # LiteLLM objects expose .id, .function.name, .function.arguments.
            parsed_calls.append({
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            })
        except AttributeError:
            # Unrecognized shape: skip it rather than fail the whole completion.
            pass

    return LLMResponse(content=message.content or "", tool_calls=parsed_calls)
async def complete_stream(
    model_group: str,
    messages: list[dict],
    tenant_id: str,
) -> AsyncGenerator[str, None]:
    """
    Stream a completion from the LiteLLM Router, yielding token strings.

    Only used for the web channel streaming path — does NOT support tool calls
    (tool-call responses are not streamed). The caller is responsible for
    assembling the full response from the yielded chunks.

    Args:
        model_group: "quality", "fast", etc. — selects the provider group.
        messages: OpenAI-format message list.
        tenant_id: Konstruct tenant UUID for cost tracking metadata.

    Yields:
        Token strings as they are generated by the LLM.

    Raises:
        Exception: Propagated if all providers (and fallbacks) fail.
    """
    logger.info(
        "LLM stream request",
        extra={"model_group": model_group, "tenant_id": tenant_id},
    )

    stream = await llm_router.acompletion(
        model=model_group,
        messages=messages,
        metadata={"tenant_id": tenant_id},
        stream=True,
    )

    async for part in stream:
        try:
            piece = getattr(part.choices[0].delta, "content", None)
        except (IndexError, AttributeError):
            # Malformed or empty chunk — skip it and keep streaming.
            continue
        if piece:
            yield piece