feat(10-01): Celery ingestion task, executor injection, KB search wiring

- Add ingest_document Celery task (sync def + asyncio.run per arch constraint)
- Add ingest_document_pipeline: MinIO download, extract, chunk, embed, store
- Add chunk_text sliding window chunker (500 chars default, 50 overlap)
- Update execute_tool to inject tenant_id/agent_id into all tool handler kwargs
- Update web_search to use settings.brave_api_key (shared config) not os.getenv
- Unit tests: test_ingestion.py (9 tests) and test_executor_injection.py (5 tests) all pass
This commit is contained in:
2026-03-26 09:09:36 -06:00
parent 08572fcc40
commit 9c7686a7b4
6 changed files with 748 additions and 4 deletions

View File

@@ -997,3 +997,45 @@ async def _update_slack_placeholder(
channel_id, channel_id,
placeholder_ts, placeholder_ts,
) )
# =============================================================================
# KB Document Ingestion Task
# =============================================================================
@app.task(
name="orchestrator.tasks.ingest_document",
bind=True,
max_retries=2,
default_retry_delay=60,
ignore_result=True,
)
def ingest_document(self, document_id: str, tenant_id: str) -> None: # type: ignore[override]
"""
Celery task: run the KB document ingestion pipeline.
Downloads the document from MinIO (or scrapes URL/YouTube), extracts text,
chunks, embeds with all-MiniLM-L6-v2, and stores kb_chunks rows.
Updates kb_documents.status to 'ready' on success, 'error' on failure.
MUST be sync def — Celery workers are not async-native. asyncio.run() is
used to bridge the sync Celery world to the async pipeline.
Args:
document_id: UUID string of the KnowledgeBaseDocument row.
tenant_id: UUID string of the owning tenant.
"""
from orchestrator.tools.ingest import ingest_document_pipeline
try:
asyncio.run(ingest_document_pipeline(document_id, tenant_id))
except Exception as exc:
logger.exception(
"ingest_document task failed for document=%s tenant=%s: %s",
document_id,
tenant_id,
exc,
)
self.retry(exc=exc, countdown=60)

View File

@@ -13,10 +13,11 @@ raising an exception (graceful degradation for agents without search configured)
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
import httpx import httpx
from shared.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search" _BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
@@ -24,24 +25,26 @@ _BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0)
_MAX_RESULTS = 3 _MAX_RESULTS = 3
async def web_search(query: str) -> str: async def web_search(query: str, **kwargs: object) -> str:
""" """
Search the web using Brave Search API. Search the web using Brave Search API.
Args: Args:
query: The search query string. query: The search query string.
**kwargs: Accepts injected tenant_id/agent_id from executor (unused).
Returns: Returns:
Formatted string with top 3 search results (title + URL + description), Formatted string with top 3 search results (title + URL + description),
or an error message if the API is unavailable. or an error message if the API is unavailable.
""" """
api_key = os.getenv("BRAVE_API_KEY", "") api_key = settings.brave_api_key
if not api_key: if not api_key:
return ( return (
"Web search is not configured. " "Web search is not configured. "
"Set the BRAVE_API_KEY environment variable to enable web search." "Set the BRAVE_API_KEY environment variable to enable web search."
) )
try: try:
async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client: async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
response = await client.get( response = await client.get(

View File

@@ -119,7 +119,15 @@ async def execute_tool(
return confirmation_msg return confirmation_msg
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 5. Execute the handler # 5. Inject tenant context into args AFTER schema validation
# This ensures kb_search, calendar_lookup, and future context-aware
# tools receive tenant/agent context without the LLM providing it.
# ------------------------------------------------------------------
args["tenant_id"] = str(tenant_id)
args["agent_id"] = str(agent_id)
# ------------------------------------------------------------------
# 6. Execute the handler
# ------------------------------------------------------------------ # ------------------------------------------------------------------
start_ms = time.monotonic() start_ms = time.monotonic()
try: try:

View File

@@ -0,0 +1,322 @@
"""
Knowledge base document ingestion pipeline.
This module provides:
chunk_text() — sliding window text chunker
ingest_document_pipeline() — async pipeline: fetch → extract → chunk → embed → store
Pipeline steps:
1. Load KnowledgeBaseDocument from DB
2. Download file from MinIO (if filename) OR scrape URL / fetch YouTube transcript
3. Extract text using orchestrator.tools.extractors.extract_text
4. Chunk text with sliding window (500 chars, 50 overlap)
5. Batch embed chunks via all-MiniLM-L6-v2
6. INSERT kb_chunks rows with vector embeddings
7. UPDATE kb_documents SET status='ready', chunk_count=N
On any error: UPDATE kb_documents SET status='error', error_message=str(exc)
IMPORTANT: This module is called from a Celery task via asyncio.run(). All DB
and MinIO operations are async. The embedding call (embed_texts) is synchronous
(SentenceTransformer is sync) — this is fine inside asyncio.run().
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
import boto3
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.rls import configure_rls_hook, current_tenant_id
from orchestrator.memory.embedder import embed_texts
from orchestrator.tools.extractors import extract_text
logger = logging.getLogger(__name__)
# Default chunking parameters
_DEFAULT_CHUNK_SIZE = 500
_DEFAULT_OVERLAP = 50
def _get_minio_client() -> Any:
"""Create a boto3 S3 client pointed at MinIO."""
return boto3.client(
"s3",
endpoint_url=settings.minio_endpoint,
aws_access_key_id=settings.minio_access_key,
aws_secret_access_key=settings.minio_secret_key,
)
def chunk_text(
text: str,
chunk_size: int = _DEFAULT_CHUNK_SIZE,
overlap: int = _DEFAULT_OVERLAP,
) -> list[str]:
"""
Split text into overlapping chunks using a sliding window.
Args:
text: The text to chunk.
chunk_size: Maximum characters per chunk.
overlap: Number of characters to overlap between consecutive chunks.
Returns:
List of non-empty text chunks. Returns empty list for empty/whitespace text.
"""
text = text.strip()
if not text:
return []
if len(text) <= chunk_size:
return [text]
chunks: list[str] = []
start = 0
step = chunk_size - overlap
while start < len(text):
end = start + chunk_size
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
if end >= len(text):
break
start += step
return chunks
async def ingest_document_pipeline(document_id: str, tenant_id: str) -> None:
"""
Run the full document ingestion pipeline for a KB document.
Steps:
1. Load the KnowledgeBaseDocument from the database
2. Fetch content (MinIO file OR URL scrape OR YouTube transcript)
3. Extract plain text
4. Chunk text
5. Embed chunks
6. Store kb_chunks rows in the database
7. Mark document as 'ready'
On any error: set status='error' with error_message.
Args:
document_id: UUID string of the KnowledgeBaseDocument to process.
tenant_id: UUID string of the tenant (for RLS context).
"""
from sqlalchemy import select, text as sa_text
from shared.models.kb import KnowledgeBaseDocument
tenant_uuid = uuid.UUID(tenant_id)
doc_uuid = uuid.UUID(document_id)
configure_rls_hook(engine)
token = current_tenant_id.set(tenant_uuid)
try:
async with async_session_factory() as session:
result = await session.execute(
select(KnowledgeBaseDocument).where(
KnowledgeBaseDocument.id == doc_uuid
)
)
doc = result.scalar_one_or_none()
if doc is None:
logger.warning(
"ingest_document_pipeline: document %s not found, skipping",
document_id,
)
return
filename = doc.filename
source_url = doc.source_url
# ------------------------------------------------------------------
# Step 2: Fetch content
# ------------------------------------------------------------------
try:
file_bytes: bytes | None = None
extracted_text: str
if filename:
# Download from MinIO
bucket = settings.minio_kb_bucket
key = f"{tenant_id}/{document_id}/{filename}"
minio = _get_minio_client()
response = minio.get_object(Bucket=bucket, Key=key)
file_bytes = response.read()
extracted_text = extract_text(filename, file_bytes)
elif source_url:
extracted_text = await _fetch_url_content(source_url)
else:
raise ValueError("Document has neither filename nor source_url")
# ------------------------------------------------------------------
# Step 3-4: Chunk text
# ------------------------------------------------------------------
chunks = chunk_text(extracted_text)
if not chunks:
raise ValueError("No text content could be extracted from this document")
# ------------------------------------------------------------------
# Step 5: Embed chunks
# ------------------------------------------------------------------
embeddings = embed_texts(chunks)
# ------------------------------------------------------------------
# Step 6: Insert kb_chunks
# ------------------------------------------------------------------
# Delete any existing chunks for this document first
await session.execute(
sa_text("DELETE FROM kb_chunks WHERE document_id = :doc_id"),
{"doc_id": str(doc_uuid)},
)
for idx, (chunk_content, embedding) in enumerate(zip(chunks, embeddings)):
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
await session.execute(
sa_text("""
INSERT INTO kb_chunks
(tenant_id, document_id, content, chunk_index, embedding)
VALUES
(:tenant_id, :document_id, :content, :chunk_index,
CAST(:embedding AS vector))
"""),
{
"tenant_id": str(tenant_uuid),
"document_id": str(doc_uuid),
"content": chunk_content,
"chunk_index": idx,
"embedding": embedding_str,
},
)
# ------------------------------------------------------------------
# Step 7: Mark document as ready
# ------------------------------------------------------------------
doc.status = "ready"
doc.chunk_count = len(chunks)
doc.error_message = None
await session.commit()
logger.info(
"ingest_document_pipeline: %s ingested %d chunks for document %s",
tenant_id,
len(chunks),
document_id,
)
except Exception as exc:
logger.exception(
"ingest_document_pipeline: error processing document %s: %s",
document_id,
exc,
)
# Try to mark document as error
try:
doc.status = "error"
doc.error_message = str(exc)
await session.commit()
except Exception:
logger.exception(
"ingest_document_pipeline: failed to mark document %s as error",
document_id,
)
finally:
current_tenant_id.reset(token)
async def _fetch_url_content(url: str) -> str:
"""
Fetch text content from a URL.
Supports:
- YouTube URLs (via youtube-transcript-api)
- Generic web pages (via firecrawl-py, graceful fallback if key not set)
"""
if _is_youtube_url(url):
return await _fetch_youtube_transcript(url)
else:
return await _scrape_web_url(url)
def _is_youtube_url(url: str) -> bool:
"""Return True if the URL is a YouTube video."""
return "youtube.com" in url or "youtu.be" in url
async def _fetch_youtube_transcript(url: str) -> str:
"""Fetch YouTube video transcript using youtube-transcript-api."""
try:
from youtube_transcript_api import YouTubeTranscriptApi
# Extract video ID from URL
video_id = _extract_youtube_id(url)
if not video_id:
raise ValueError(f"Could not extract YouTube video ID from URL: {url}")
transcript = YouTubeTranscriptApi.get_transcript(video_id)
return " ".join(entry["text"] for entry in transcript)
except Exception as exc:
raise ValueError(f"Failed to fetch YouTube transcript: {exc}") from exc
def _extract_youtube_id(url: str) -> str | None:
"""Extract YouTube video ID from various URL formats."""
import re
patterns = [
r"youtube\.com/watch\?v=([a-zA-Z0-9_-]+)",
r"youtu\.be/([a-zA-Z0-9_-]+)",
r"youtube\.com/embed/([a-zA-Z0-9_-]+)",
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
return None
async def _scrape_web_url(url: str) -> str:
"""Scrape a web URL to markdown using firecrawl-py."""
if not settings.firecrawl_api_key:
# Fallback: try simple httpx fetch
return await _simple_fetch(url)
try:
from firecrawl import FirecrawlApp
app = FirecrawlApp(api_key=settings.firecrawl_api_key)
result = app.scrape_url(url, params={"formats": ["markdown"]})
if isinstance(result, dict):
return result.get("markdown", result.get("content", str(result)))
return str(result)
except Exception as exc:
logger.warning("Firecrawl failed for %s: %s — falling back to simple fetch", url, exc)
return await _simple_fetch(url)
async def _simple_fetch(url: str) -> str:
"""Simple httpx GET fetch as fallback for URL scraping."""
import httpx
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
return response.text
except Exception as exc:
raise ValueError(f"Failed to fetch URL {url}: {exc}") from exc

View File

@@ -0,0 +1,186 @@
"""
Unit tests for executor tenant_id/agent_id injection.
Tests that execute_tool injects tenant_id and agent_id into handler kwargs
before calling the handler, so context-aware tools (kb_search, calendar_lookup)
receive tenant context without the LLM needing to provide it.
"""
from __future__ import annotations
import uuid
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
def _make_tool(handler: Any, requires_confirmation: bool = False) -> Any:
"""Create a minimal ToolDefinition-like object for tests."""
tool = MagicMock()
tool.handler = handler
tool.requires_confirmation = requires_confirmation
tool.parameters = {
"type": "object",
"properties": {
"query": {"type": "string"},
},
"required": ["query"],
}
return tool
class TestExecutorTenantInjection:
@pytest.mark.asyncio
async def test_tenant_id_injected_into_handler_kwargs(self) -> None:
"""Handler should receive tenant_id even though LLM didn't provide it."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "handler result"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "hello world"}',
}
}
result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert result == "handler result"
assert "tenant_id" in received_kwargs
assert received_kwargs["tenant_id"] == str(tenant_id)
@pytest.mark.asyncio
async def test_agent_id_injected_into_handler_kwargs(self) -> None:
"""Handler should receive agent_id even though LLM didn't provide it."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert "agent_id" in received_kwargs
assert received_kwargs["agent_id"] == str(agent_id)
@pytest.mark.asyncio
async def test_injected_ids_are_strings(self) -> None:
"""Injected tenant_id and agent_id should be strings, not UUIDs."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert isinstance(received_kwargs["tenant_id"], str)
assert isinstance(received_kwargs["agent_id"], str)
@pytest.mark.asyncio
async def test_llm_provided_args_preserved(self) -> None:
"""Original LLM-provided args should still be present after injection."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "search term from LLM"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert received_kwargs["query"] == "search term from LLM"
assert received_kwargs["tenant_id"] == str(tenant_id)
assert received_kwargs["agent_id"] == str(agent_id)
@pytest.mark.asyncio
async def test_injection_after_schema_validation(self) -> None:
"""Injection happens after validation — injected keys don't cause schema failures."""
from orchestrator.tools.executor import execute_tool
# Tool requires exactly 'query', nothing else in schema required
# Schema should pass even though we inject tenant_id/agent_id
async def mock_handler(**kwargs: Any) -> str:
return "passed"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert result == "passed"

View File

@@ -0,0 +1,183 @@
"""
Unit tests for the KB ingestion pipeline.
Tests:
- chunk_text: sliding window chunker produces correctly-sized, overlapping chunks
- ingest_document_pipeline: downloads file from MinIO, extracts, chunks, embeds, stores
- ingest_document_pipeline: sets status='error' on failure
"""
from __future__ import annotations
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestChunkText:
def test_basic_chunking(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "a" * 1000
chunks = chunk_text(text, chunk_size=100, overlap=10)
assert len(chunks) > 0
for chunk in chunks:
assert len(chunk) <= 100
def test_overlap_between_chunks(self) -> None:
from orchestrator.tools.ingest import chunk_text
# Create text with identifiable segments
text = "AAAA" * 50 + "BBBB" * 50 # 400 chars
chunks = chunk_text(text, chunk_size=200, overlap=50)
# With overlap=50, consecutive chunks should share chars
assert len(chunks) >= 2
def test_short_text_returns_one_chunk(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "Hello world"
chunks = chunk_text(text, chunk_size=500, overlap=50)
assert len(chunks) == 1
assert chunks[0] == "Hello world"
def test_empty_text_returns_empty_list(self) -> None:
from orchestrator.tools.ingest import chunk_text
chunks = chunk_text("", chunk_size=500, overlap=50)
assert chunks == []
def test_whitespace_only_returns_empty_list(self) -> None:
from orchestrator.tools.ingest import chunk_text
chunks = chunk_text(" \n ", chunk_size=500, overlap=50)
assert chunks == []
def test_default_parameters(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "word " * 500 # 2500 chars
chunks = chunk_text(text)
assert len(chunks) > 1
# Default chunk_size is 500
for chunk in chunks:
assert len(chunk) <= 500
class TestIngestDocumentPipeline:
@pytest.mark.asyncio
async def test_file_upload_sets_status_ready(self) -> None:
"""Pipeline downloads file, extracts, chunks, embeds, stores, sets ready."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
mock_doc = MagicMock()
mock_doc.id = uuid.UUID(document_id)
mock_doc.tenant_id = uuid.UUID(tenant_id)
mock_doc.filename = "test.txt"
mock_doc.source_url = None
mock_doc.status = "processing"
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
patch("orchestrator.tools.ingest.extract_text", return_value="Test content " * 50) as mock_extract,
patch("orchestrator.tools.ingest.embed_texts", return_value=[[0.1] * 384]) as mock_embed,
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_doc
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# MinIO returns file bytes
minio_client = MagicMock()
response_obj = MagicMock()
response_obj.read.return_value = b"Test content " * 50
minio_client.get_object.return_value = response_obj
mock_minio.return_value = minio_client
await ingest_document_pipeline(document_id, tenant_id)
# Status should be set to 'ready' on the document
assert mock_doc.status == "ready"
assert mock_doc.chunk_count is not None
@pytest.mark.asyncio
async def test_pipeline_sets_error_on_exception(self) -> None:
"""Pipeline marks document as error when extraction fails."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
mock_doc = MagicMock()
mock_doc.id = uuid.UUID(document_id)
mock_doc.tenant_id = uuid.UUID(tenant_id)
mock_doc.filename = "test.txt"
mock_doc.source_url = None
mock_doc.status = "processing"
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_doc
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# MinIO raises an error
minio_client = MagicMock()
minio_client.get_object.side_effect = Exception("MinIO connection failed")
mock_minio.return_value = minio_client
await ingest_document_pipeline(document_id, tenant_id)
assert mock_doc.status == "error"
assert mock_doc.error_message is not None
@pytest.mark.asyncio
async def test_document_not_found_is_no_op(self) -> None:
"""If document doesn't exist, pipeline exits gracefully."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None # Not found
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# Should not raise
await ingest_document_pipeline(document_id, tenant_id)