diff --git a/packages/llm-pool/llm_pool/main.py b/packages/llm-pool/llm_pool/main.py
index 217c533..443fb91 100644
--- a/packages/llm-pool/llm_pool/main.py
+++ b/packages/llm-pool/llm_pool/main.py
@@ -2,19 +2,23 @@
 LLM Backend Pool — FastAPI service on port 8004.
 
 Endpoints:
-    POST /complete — route a completion request through the LiteLLM Router.
-    GET  /health — liveness probe.
+    POST /complete        — route a completion request through the LiteLLM Router.
+    POST /complete/stream — streaming variant; returns NDJSON token chunks.
+    GET  /health          — liveness probe.
 """
 
 from __future__ import annotations
 
+import json
 import logging
 from typing import Any
 
 from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
 from llm_pool.router import complete as router_complete
+from llm_pool.router import complete_stream as router_complete_stream
 
 logger = logging.getLogger(__name__)
 
@@ -69,6 +73,19 @@ class HealthResponse(BaseModel):
     status: str
 
 
+class StreamCompleteRequest(BaseModel):
+    """Body for POST /complete/stream."""
+
+    model: str
+    """Model group name: "quality" or "fast"."""
+
+    messages: list[dict]
+    """OpenAI-format message list."""
+
+    tenant_id: str
+    """Konstruct tenant UUID for cost tracking."""
+
+
 # ---------------------------------------------------------------------------
 # Routes
 # ---------------------------------------------------------------------------
@@ -123,3 +140,44 @@ async def complete_endpoint(request: CompleteRequest) -> CompleteResponse:
         status_code=503,
         content={"error": "All providers unavailable"},
     )
+
+
+@app.post("/complete/stream")
+async def complete_stream_endpoint(request: StreamCompleteRequest) -> StreamingResponse:
+    """
+    Stream a completion through the LiteLLM Router using NDJSON.
+
+    Each line of the response body is a JSON object:
+        {"type": "chunk", "text": "<token text>"} — zero or more times
+        {"type": "done"} — final line, signals end of stream
+
+    On provider failure, yields:
+        {"type": "error", "message": "All providers unavailable"}
+
+    The caller (orchestrator runner) reads line-by-line and forwards chunks
+    to Redis pub-sub for the web WebSocket handler.
+
+    NOTE: Tool calls are NOT supported in this endpoint — only plain text
+    streaming. Use POST /complete for tool-call responses.
+    """
+    async def _generate() -> Any:
+        try:
+            async for token in router_complete_stream(
+                model_group=request.model,
+                messages=request.messages,
+                tenant_id=request.tenant_id,
+            ):
+                yield json.dumps({"type": "chunk", "text": token}) + "\n"
+            yield json.dumps({"type": "done"}) + "\n"
+        except Exception:
+            logger.exception(
+                "Streaming LLM failed for tenant=%s model=%s",
+                request.tenant_id,
+                request.model,
+            )
+            yield json.dumps({"type": "error", "message": "All providers unavailable"}) + "\n"
+
+    return StreamingResponse(
+        _generate(),
+        media_type="application/x-ndjson",
+    )
diff --git a/packages/llm-pool/llm_pool/router.py b/packages/llm-pool/llm_pool/router.py
index 4ba9df9..91ad180 100644
--- a/packages/llm-pool/llm_pool/router.py
+++ b/packages/llm-pool/llm_pool/router.py
@@ -16,6 +16,7 @@ NOTE: LiteLLM is pinned to ==1.82.5 in pyproject.toml.
 from __future__ import annotations
 
 import logging
+from collections.abc import AsyncGenerator
 from typing import Any
 
 from litellm import Router
@@ -173,3 +174,48 @@
     content: str = message.content or ""
 
     return LLMResponse(content=content, tool_calls=tool_calls)
+
+
+async def complete_stream(
+    model_group: str,
+    messages: list[dict],
+    tenant_id: str,
+) -> AsyncGenerator[str, None]:
+    """
+    Stream a completion from the LiteLLM Router, yielding token strings.
+
+    Only used for the web channel streaming path — does NOT support tool calls
+    (tool-call responses are not streamed). The caller is responsible for
+    assembling the full response from the yielded chunks.
+
+    Args:
+        model_group: "quality", "fast", etc. — selects the provider group.
+        messages: OpenAI-format message list.
+        tenant_id: Konstruct tenant UUID for cost tracking metadata.
+
+    Yields:
+        Token strings as they are generated by the LLM.
+
+    Raises:
+        Exception: Propagated if all providers (and fallbacks) fail.
+    """
+    logger.info(
+        "LLM stream request",
+        extra={"model_group": model_group, "tenant_id": tenant_id},
+    )
+
+    response = await llm_router.acompletion(
+        model=model_group,
+        messages=messages,
+        metadata={"tenant_id": tenant_id},
+        stream=True,
+    )
+
+    async for chunk in response:
+        try:
+            delta = chunk.choices[0].delta
+            token = getattr(delta, "content", None)
+            if token:
+                yield token
+        except (IndexError, AttributeError):
+            continue
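
Usage sketch (not part of the patch): a rough illustration of how a caller, such as the
orchestrator runner mentioned in the endpoint docstring, might consume POST /complete/stream
and reassemble the streamed text. The httpx dependency, the localhost base URL, and the
payload values below are illustrative assumptions; only the port (8004) and the NDJSON
line shapes come from the patch itself.

    # consume_stream.py: hypothetical NDJSON client for POST /complete/stream
    import asyncio
    import json

    import httpx


    async def consume_stream() -> str:
        """Read NDJSON lines and assemble the full completion text."""
        payload = {
            "model": "fast",                                  # model group name
            "messages": [{"role": "user", "content": "Hi"}],  # OpenAI-format messages
            "tenant_id": "00000000-0000-0000-0000-000000000000",  # example tenant UUID
        }
        parts: list[str] = []
        async with httpx.AsyncClient(base_url="http://localhost:8004") as client:
            async with client.stream("POST", "/complete/stream", json=payload) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    event = json.loads(line)
                    if event["type"] == "chunk":
                        parts.append(event["text"])  # e.g. forward to Redis pub-sub here
                    elif event["type"] == "error":
                        raise RuntimeError(event["message"])
                    elif event["type"] == "done":
                        break
        return "".join(parts)


    if __name__ == "__main__":
        print(asyncio.run(consume_stream()))

Because the endpoint terminates every successful stream with a {"type": "done"} line and
every failed one with a {"type": "error", ...} line, a client written this way never has to
rely on connection close alone to detect the end of a completion.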