feat(10-01): Celery ingestion task, executor injection, KB search wiring

- Add ingest_document Celery task (sync def + asyncio.run per arch constraint)
- Add ingest_document_pipeline: MinIO download, extract, chunk, embed, store
- Add chunk_text sliding window chunker (500 chars default, 50 overlap)
- Update execute_tool to inject tenant_id/agent_id into all tool handler kwargs
- Update web_search to use settings.brave_api_key (shared config) not os.getenv
- Unit tests: test_ingestion.py (9 tests) and test_executor_injection.py (5 tests) all pass
This commit is contained in:
2026-03-26 09:09:36 -06:00
parent 08572fcc40
commit 9c7686a7b4
6 changed files with 748 additions and 4 deletions

View File

@@ -0,0 +1,186 @@
"""
Unit tests for executor tenant_id/agent_id injection.
Tests that execute_tool injects tenant_id and agent_id into handler kwargs
before calling the handler, so context-aware tools (kb_search, calendar_lookup)
receive tenant context without the LLM needing to provide it.
"""
from __future__ import annotations
import uuid
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
def _make_tool(handler: Any, requires_confirmation: bool = False) -> Any:
"""Create a minimal ToolDefinition-like object for tests."""
tool = MagicMock()
tool.handler = handler
tool.requires_confirmation = requires_confirmation
tool.parameters = {
"type": "object",
"properties": {
"query": {"type": "string"},
},
"required": ["query"],
}
return tool
class TestExecutorTenantInjection:
@pytest.mark.asyncio
async def test_tenant_id_injected_into_handler_kwargs(self) -> None:
"""Handler should receive tenant_id even though LLM didn't provide it."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "handler result"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "hello world"}',
}
}
result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert result == "handler result"
assert "tenant_id" in received_kwargs
assert received_kwargs["tenant_id"] == str(tenant_id)
@pytest.mark.asyncio
async def test_agent_id_injected_into_handler_kwargs(self) -> None:
"""Handler should receive agent_id even though LLM didn't provide it."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert "agent_id" in received_kwargs
assert received_kwargs["agent_id"] == str(agent_id)
@pytest.mark.asyncio
async def test_injected_ids_are_strings(self) -> None:
"""Injected tenant_id and agent_id should be strings, not UUIDs."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert isinstance(received_kwargs["tenant_id"], str)
assert isinstance(received_kwargs["agent_id"], str)
@pytest.mark.asyncio
async def test_llm_provided_args_preserved(self) -> None:
"""Original LLM-provided args should still be present after injection."""
from orchestrator.tools.executor import execute_tool
received_kwargs: dict[str, Any] = {}
async def mock_handler(**kwargs: Any) -> str:
received_kwargs.update(kwargs)
return "ok"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "search term from LLM"}',
}
}
await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert received_kwargs["query"] == "search term from LLM"
assert received_kwargs["tenant_id"] == str(tenant_id)
assert received_kwargs["agent_id"] == str(agent_id)
@pytest.mark.asyncio
async def test_injection_after_schema_validation(self) -> None:
"""Injection happens after validation — injected keys don't cause schema failures."""
from orchestrator.tools.executor import execute_tool
# Tool requires exactly 'query', nothing else in schema required
# Schema should pass even though we inject tenant_id/agent_id
async def mock_handler(**kwargs: Any) -> str:
return "passed"
tool = _make_tool(mock_handler)
registry = {"test_tool": tool}
tenant_id = uuid.uuid4()
agent_id = uuid.uuid4()
audit_logger = MagicMock()
audit_logger.log_tool_call = AsyncMock()
tool_call = {
"function": {
"name": "test_tool",
"arguments": '{"query": "test"}',
}
}
result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger)
assert result == "passed"

View File

@@ -0,0 +1,183 @@
"""
Unit tests for the KB ingestion pipeline.
Tests:
- chunk_text: sliding window chunker produces correctly-sized, overlapping chunks
- ingest_document_pipeline: downloads file from MinIO, extracts, chunks, embeds, stores
- ingest_document_pipeline: sets status='error' on failure
"""
from __future__ import annotations
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestChunkText:
def test_basic_chunking(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "a" * 1000
chunks = chunk_text(text, chunk_size=100, overlap=10)
assert len(chunks) > 0
for chunk in chunks:
assert len(chunk) <= 100
def test_overlap_between_chunks(self) -> None:
from orchestrator.tools.ingest import chunk_text
# Create text with identifiable segments
text = "AAAA" * 50 + "BBBB" * 50 # 400 chars
chunks = chunk_text(text, chunk_size=200, overlap=50)
# With overlap=50, consecutive chunks should share chars
assert len(chunks) >= 2
def test_short_text_returns_one_chunk(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "Hello world"
chunks = chunk_text(text, chunk_size=500, overlap=50)
assert len(chunks) == 1
assert chunks[0] == "Hello world"
def test_empty_text_returns_empty_list(self) -> None:
from orchestrator.tools.ingest import chunk_text
chunks = chunk_text("", chunk_size=500, overlap=50)
assert chunks == []
def test_whitespace_only_returns_empty_list(self) -> None:
from orchestrator.tools.ingest import chunk_text
chunks = chunk_text(" \n ", chunk_size=500, overlap=50)
assert chunks == []
def test_default_parameters(self) -> None:
from orchestrator.tools.ingest import chunk_text
text = "word " * 500 # 2500 chars
chunks = chunk_text(text)
assert len(chunks) > 1
# Default chunk_size is 500
for chunk in chunks:
assert len(chunk) <= 500
class TestIngestDocumentPipeline:
@pytest.mark.asyncio
async def test_file_upload_sets_status_ready(self) -> None:
"""Pipeline downloads file, extracts, chunks, embeds, stores, sets ready."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
mock_doc = MagicMock()
mock_doc.id = uuid.UUID(document_id)
mock_doc.tenant_id = uuid.UUID(tenant_id)
mock_doc.filename = "test.txt"
mock_doc.source_url = None
mock_doc.status = "processing"
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
patch("orchestrator.tools.ingest.extract_text", return_value="Test content " * 50) as mock_extract,
patch("orchestrator.tools.ingest.embed_texts", return_value=[[0.1] * 384]) as mock_embed,
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_doc
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# MinIO returns file bytes
minio_client = MagicMock()
response_obj = MagicMock()
response_obj.read.return_value = b"Test content " * 50
minio_client.get_object.return_value = response_obj
mock_minio.return_value = minio_client
await ingest_document_pipeline(document_id, tenant_id)
# Status should be set to 'ready' on the document
assert mock_doc.status == "ready"
assert mock_doc.chunk_count is not None
@pytest.mark.asyncio
async def test_pipeline_sets_error_on_exception(self) -> None:
"""Pipeline marks document as error when extraction fails."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
mock_doc = MagicMock()
mock_doc.id = uuid.UUID(document_id)
mock_doc.tenant_id = uuid.UUID(tenant_id)
mock_doc.filename = "test.txt"
mock_doc.source_url = None
mock_doc.status = "processing"
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_doc
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# MinIO raises an error
minio_client = MagicMock()
minio_client.get_object.side_effect = Exception("MinIO connection failed")
mock_minio.return_value = minio_client
await ingest_document_pipeline(document_id, tenant_id)
assert mock_doc.status == "error"
assert mock_doc.error_message is not None
@pytest.mark.asyncio
async def test_document_not_found_is_no_op(self) -> None:
"""If document doesn't exist, pipeline exits gracefully."""
from orchestrator.tools.ingest import ingest_document_pipeline
tenant_id = str(uuid.uuid4())
document_id = str(uuid.uuid4())
with (
patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
patch("orchestrator.tools.ingest.engine"),
patch("orchestrator.tools.ingest.configure_rls_hook"),
patch("orchestrator.tools.ingest.current_tenant_id"),
):
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None # Not found
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_sf.return_value = mock_session
# Should not raise
await ingest_document_pipeline(document_id, tenant_id)