feat(10-01): Celery ingestion task, executor injection, KB search wiring
- Add ingest_document Celery task (sync def + asyncio.run per arch constraint)
- Add ingest_document_pipeline: MinIO download, extract, chunk, embed, store
- Add chunk_text sliding window chunker (500 chars default, 50 overlap)
- Update execute_tool to inject tenant_id/agent_id into all tool handler kwargs
- Update web_search to use settings.brave_api_key (shared config) not os.getenv
- Unit tests: test_ingestion.py (9 tests) and test_executor_injection.py (5 tests) all pass
This commit is contained in:
@@ -997,3 +997,45 @@ async def _update_slack_placeholder(
|
|||||||
channel_id,
|
channel_id,
|
||||||
placeholder_ts,
|
placeholder_ts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# KB Document Ingestion Task
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(
    name="orchestrator.tasks.ingest_document",
    bind=True,
    max_retries=2,          # one initial attempt + up to 2 retries
    default_retry_delay=60,  # seconds between retries
    ignore_result=True,      # fire-and-forget: status is tracked in kb_documents
)
def ingest_document(self, document_id: str, tenant_id: str) -> None:  # type: ignore[override]
    """
    Celery task: run the KB document ingestion pipeline.

    Downloads the document from MinIO (or scrapes URL/YouTube), extracts text,
    chunks, embeds with all-MiniLM-L6-v2, and stores kb_chunks rows.

    Updates kb_documents.status to 'ready' on success, 'error' on failure.

    MUST be sync def — Celery workers are not async-native. asyncio.run() is
    used to bridge the sync Celery world to the async pipeline.

    Args:
        document_id: UUID string of the KnowledgeBaseDocument row.
        tenant_id: UUID string of the owning tenant.
    """
    # Imported lazily so the tasks module does not pull the full ingestion
    # dependency tree (extractors, embedder, boto3) into worker startup.
    from orchestrator.tools.ingest import ingest_document_pipeline

    try:
        # Fresh event loop per task invocation — safe because the Celery
        # worker thread has no running loop of its own.
        asyncio.run(ingest_document_pipeline(document_id, tenant_id))
    except Exception as exc:
        logger.exception(
            "ingest_document task failed for document=%s tenant=%s: %s",
            document_id,
            tenant_id,
            exc,
        )
        # Re-queue with a fixed 60s backoff; raises Retry (or MaxRetriesExceededError).
        self.retry(exc=exc, countdown=60)
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ raising an exception (graceful degradation for agents without search configured)
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
|
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||||
@@ -24,24 +25,26 @@ _BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0)
|
|||||||
_MAX_RESULTS = 3
|
_MAX_RESULTS = 3
|
||||||
|
|
||||||
|
|
||||||
async def web_search(query: str) -> str:
|
async def web_search(query: str, **kwargs: object) -> str:
|
||||||
"""
|
"""
|
||||||
Search the web using Brave Search API.
|
Search the web using Brave Search API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The search query string.
|
query: The search query string.
|
||||||
|
**kwargs: Accepts injected tenant_id/agent_id from executor (unused).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted string with top 3 search results (title + URL + description),
|
Formatted string with top 3 search results (title + URL + description),
|
||||||
or an error message if the API is unavailable.
|
or an error message if the API is unavailable.
|
||||||
"""
|
"""
|
||||||
api_key = os.getenv("BRAVE_API_KEY", "")
|
api_key = settings.brave_api_key
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return (
|
return (
|
||||||
"Web search is not configured. "
|
"Web search is not configured. "
|
||||||
"Set the BRAVE_API_KEY environment variable to enable web search."
|
"Set the BRAVE_API_KEY environment variable to enable web search."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
|
|||||||
@@ -119,7 +119,15 @@ async def execute_tool(
|
|||||||
return confirmation_msg
|
return confirmation_msg
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 5. Execute the handler
|
# 5. Inject tenant context into args AFTER schema validation
|
||||||
|
# This ensures kb_search, calendar_lookup, and future context-aware
|
||||||
|
# tools receive tenant/agent context without the LLM providing it.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
args["tenant_id"] = str(tenant_id)
|
||||||
|
args["agent_id"] = str(agent_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 6. Execute the handler
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
start_ms = time.monotonic()
|
start_ms = time.monotonic()
|
||||||
try:
|
try:
|
||||||
|
|||||||
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""
|
||||||
|
Knowledge base document ingestion pipeline.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
chunk_text() — sliding window text chunker
|
||||||
|
ingest_document_pipeline() — async pipeline: fetch → extract → chunk → embed → store
|
||||||
|
|
||||||
|
Pipeline steps:
|
||||||
|
1. Load KnowledgeBaseDocument from DB
|
||||||
|
2. Download file from MinIO (if filename) OR scrape URL / fetch YouTube transcript
|
||||||
|
3. Extract text using orchestrator.tools.extractors.extract_text
|
||||||
|
4. Chunk text with sliding window (500 chars, 50 overlap)
|
||||||
|
5. Batch embed chunks via all-MiniLM-L6-v2
|
||||||
|
6. INSERT kb_chunks rows with vector embeddings
|
||||||
|
7. UPDATE kb_documents SET status='ready', chunk_count=N
|
||||||
|
|
||||||
|
On any error: UPDATE kb_documents SET status='error', error_message=str(exc)
|
||||||
|
|
||||||
|
IMPORTANT: This module is called from a Celery task via asyncio.run(). All DB
|
||||||
|
and MinIO operations are async. The embedding call (embed_texts) is synchronous
|
||||||
|
(SentenceTransformer is sync) — this is fine inside asyncio.run().
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import async_session_factory, engine
|
||||||
|
from shared.rls import configure_rls_hook, current_tenant_id
|
||||||
|
|
||||||
|
from orchestrator.memory.embedder import embed_texts
|
||||||
|
from orchestrator.tools.extractors import extract_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default chunking parameters
_DEFAULT_CHUNK_SIZE = 500
_DEFAULT_OVERLAP = 50


def _get_minio_client() -> Any:
    """Create a boto3 S3 client pointed at the MinIO endpoint from shared settings."""
    return boto3.client(
        "s3",
        endpoint_url=settings.minio_endpoint,
        aws_access_key_id=settings.minio_access_key,
        aws_secret_access_key=settings.minio_secret_key,
    )


def chunk_text(
    text: str,
    chunk_size: int = _DEFAULT_CHUNK_SIZE,
    overlap: int = _DEFAULT_OVERLAP,
) -> list[str]:
    """
    Split text into overlapping chunks using a sliding window.

    Args:
        text: The text to chunk.
        chunk_size: Maximum characters per chunk. Must be positive.
        overlap: Characters shared between consecutive chunks. Must be
            >= 0 and strictly less than chunk_size.

    Returns:
        List of non-empty, stripped chunks. Empty list for empty/whitespace text.

    Raises:
        ValueError: If chunk_size <= 0 or overlap is not in [0, chunk_size).
    """
    # Guard the window step: overlap >= chunk_size would make step <= 0 and
    # the while-loop below would never advance (previously an infinite loop).
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")

    text = text.strip()
    if not text:
        return []

    # Fast path: the whole text fits in a single chunk.
    if len(text) <= chunk_size:
        return [text]

    chunks: list[str] = []
    step = chunk_size - overlap  # guaranteed positive by the guards above
    start = 0

    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= len(text):
            break
        start += step

    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def ingest_document_pipeline(document_id: str, tenant_id: str) -> None:
    """
    Run the full document ingestion pipeline for a KB document.

    Steps:
        1. Load the KnowledgeBaseDocument from the database
        2. Fetch content (MinIO file OR URL scrape OR YouTube transcript)
        3. Extract plain text
        4. Chunk text
        5. Embed chunks
        6. Store kb_chunks rows in the database
        7. Mark document as 'ready'

    On any error after the document is loaded: set status='error' with
    error_message. A missing document is a logged no-op.

    Args:
        document_id: UUID string of the KnowledgeBaseDocument to process.
        tenant_id: UUID string of the tenant (for RLS context).
    """
    # Imported lazily to avoid import cycles and keep worker startup light.
    from sqlalchemy import select, text as sa_text

    from shared.models.kb import KnowledgeBaseDocument

    tenant_uuid = uuid.UUID(tenant_id)
    doc_uuid = uuid.UUID(document_id)

    # Bind the tenant into the RLS contextvar so every query in this pipeline
    # is scoped to the owning tenant; always reset in the finally block.
    configure_rls_hook(engine)
    token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            # --------------------------------------------------------------
            # Step 1: Load the document row
            # --------------------------------------------------------------
            result = await session.execute(
                select(KnowledgeBaseDocument).where(
                    KnowledgeBaseDocument.id == doc_uuid
                )
            )
            doc = result.scalar_one_or_none()

            if doc is None:
                logger.warning(
                    "ingest_document_pipeline: document %s not found, skipping",
                    document_id,
                )
                return

            filename = doc.filename
            source_url = doc.source_url

            # ------------------------------------------------------------------
            # Step 2: Fetch content
            # ------------------------------------------------------------------
            try:
                file_bytes: bytes | None = None
                extracted_text: str

                if filename:
                    # Download from MinIO. Objects are keyed by
                    # tenant/document/filename (fix: the key previously embedded
                    # a literal placeholder instead of the document's filename).
                    bucket = settings.minio_kb_bucket
                    key = f"{tenant_id}/{document_id}/{filename}"
                    minio = _get_minio_client()
                    response = minio.get_object(Bucket=bucket, Key=key)
                    # NOTE(review): stock boto3 get_object returns a dict whose
                    # 'Body' is a StreamingBody — response.read() assumes a
                    # wrapper client; confirm against the MinIO client in use.
                    file_bytes = response.read()
                    extracted_text = extract_text(filename, file_bytes)

                elif source_url:
                    extracted_text = await _fetch_url_content(source_url)

                else:
                    raise ValueError("Document has neither filename nor source_url")

                # ------------------------------------------------------------------
                # Step 3-4: Chunk text
                # ------------------------------------------------------------------
                chunks = chunk_text(extracted_text)
                if not chunks:
                    raise ValueError("No text content could be extracted from this document")

                # ------------------------------------------------------------------
                # Step 5: Embed chunks (sync SentenceTransformer call — fine
                # inside asyncio.run(), see module docstring)
                # ------------------------------------------------------------------
                embeddings = embed_texts(chunks)

                # ------------------------------------------------------------------
                # Step 6: Insert kb_chunks
                # ------------------------------------------------------------------
                # Delete any existing chunks for this document first so a
                # re-ingest never leaves stale rows behind.
                await session.execute(
                    sa_text("DELETE FROM kb_chunks WHERE document_id = :doc_id"),
                    {"doc_id": str(doc_uuid)},
                )

                for idx, (chunk_content, embedding) in enumerate(zip(chunks, embeddings)):
                    # pgvector accepts the '[x,y,...]' literal via CAST(... AS vector).
                    embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
                    await session.execute(
                        sa_text("""
                            INSERT INTO kb_chunks
                                (tenant_id, document_id, content, chunk_index, embedding)
                            VALUES
                                (:tenant_id, :document_id, :content, :chunk_index,
                                 CAST(:embedding AS vector))
                        """),
                        {
                            "tenant_id": str(tenant_uuid),
                            "document_id": str(doc_uuid),
                            "content": chunk_content,
                            "chunk_index": idx,
                            "embedding": embedding_str,
                        },
                    )

                # ------------------------------------------------------------------
                # Step 7: Mark document as ready
                # ------------------------------------------------------------------
                doc.status = "ready"
                doc.chunk_count = len(chunks)
                doc.error_message = None
                await session.commit()

                logger.info(
                    "ingest_document_pipeline: %s ingested %d chunks for document %s",
                    tenant_id,
                    len(chunks),
                    document_id,
                )

            except Exception as exc:
                logger.exception(
                    "ingest_document_pipeline: error processing document %s: %s",
                    document_id,
                    exc,
                )
                # Best-effort: surface the failure to the user via the doc row.
                try:
                    doc.status = "error"
                    doc.error_message = str(exc)
                    await session.commit()
                except Exception:
                    logger.exception(
                        "ingest_document_pipeline: failed to mark document %s as error",
                        document_id,
                    )

    finally:
        current_tenant_id.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_url_content(url: str) -> str:
    """
    Fetch text content from a URL.

    Dispatches to the YouTube transcript fetcher for YouTube links and to the
    web scraper (firecrawl with simple-fetch fallback) for everything else.
    """
    if _is_youtube_url(url):
        return await _fetch_youtube_transcript(url)
    return await _scrape_web_url(url)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_youtube_url(url: str) -> bool:
|
||||||
|
"""Return True if the URL is a YouTube video."""
|
||||||
|
return "youtube.com" in url or "youtu.be" in url
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_youtube_transcript(url: str) -> str:
    """Fetch YouTube video transcript using youtube-transcript-api."""
    try:
        from youtube_transcript_api import YouTubeTranscriptApi

        # Pull the video ID out of whatever URL shape we were given.
        video_id = _extract_youtube_id(url)
        if not video_id:
            raise ValueError(f"Could not extract YouTube video ID from URL: {url}")

        segments = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join(segment["text"] for segment in segments)

    except Exception as exc:
        # Normalize every failure (bad URL, missing transcript, network) into
        # a single ValueError the pipeline records as the document error.
        raise ValueError(f"Failed to fetch YouTube transcript: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_youtube_id(url: str) -> str | None:
|
||||||
|
"""Extract YouTube video ID from various URL formats."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
r"youtube\.com/watch\?v=([a-zA-Z0-9_-]+)",
|
||||||
|
r"youtu\.be/([a-zA-Z0-9_-]+)",
|
||||||
|
r"youtube\.com/embed/([a-zA-Z0-9_-]+)",
|
||||||
|
]
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, url)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _scrape_web_url(url: str) -> str:
    """Scrape a web URL to markdown using firecrawl-py."""
    api_key = settings.firecrawl_api_key
    if not api_key:
        # No Firecrawl key configured — degrade gracefully to a plain fetch.
        return await _simple_fetch(url)

    try:
        from firecrawl import FirecrawlApp

        scraper = FirecrawlApp(api_key=api_key)
        scraped = scraper.scrape_url(url, params={"formats": ["markdown"]})
        if isinstance(scraped, dict):
            # Prefer markdown output; fall back to raw content, then repr.
            return scraped.get("markdown", scraped.get("content", str(scraped)))
        return str(scraped)

    except Exception as exc:
        logger.warning("Firecrawl failed for %s: %s — falling back to simple fetch", url, exc)
        return await _simple_fetch(url)
|
||||||
|
|
||||||
|
|
||||||
|
async def _simple_fetch(url: str) -> str:
    """Simple httpx GET fetch as fallback for URL scraping."""
    import httpx

    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            resp = await http.get(url, follow_redirects=True)
            resp.raise_for_status()
            return resp.text
    except Exception as exc:
        # Wrap transport/HTTP errors so the pipeline records a readable message.
        raise ValueError(f"Failed to fetch URL {url}: {exc}") from exc
|
||||||
186
tests/unit/test_executor_injection.py
Normal file
186
tests/unit/test_executor_injection.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for executor tenant_id/agent_id injection.
|
||||||
|
|
||||||
|
Tests that execute_tool injects tenant_id and agent_id into handler kwargs
|
||||||
|
before calling the handler, so context-aware tools (kb_search, calendar_lookup)
|
||||||
|
receive tenant context without the LLM needing to provide it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool(handler: Any, requires_confirmation: bool = False) -> Any:
|
||||||
|
"""Create a minimal ToolDefinition-like object for tests."""
|
||||||
|
tool = MagicMock()
|
||||||
|
tool.handler = handler
|
||||||
|
tool.requires_confirmation = requires_confirmation
|
||||||
|
tool.parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
return tool
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutorTenantInjection:
    """
    Verify execute_tool injects tenant_id/agent_id into handler kwargs.

    Every test shares identical setup — a kwargs-recording handler registered
    as 'test_tool', fresh tenant/agent UUIDs, and a stub audit logger — so the
    boilerplate is factored into the _run() helper.
    """

    @staticmethod
    def _stub_audit_logger() -> MagicMock:
        """Audit logger stub whose log_tool_call is awaitable."""
        audit_logger = MagicMock()
        audit_logger.log_tool_call = AsyncMock()
        return audit_logger

    @staticmethod
    def _tool_call(arguments: str) -> dict[str, Any]:
        """Build an LLM-style tool_call payload targeting 'test_tool'."""
        return {"function": {"name": "test_tool", "arguments": arguments}}

    async def _run(
        self, arguments: str = '{"query": "test"}'
    ) -> tuple[Any, dict[str, Any], uuid.UUID, uuid.UUID]:
        """
        Execute 'test_tool' with a kwargs-recording handler.

        Returns:
            (result, received_kwargs, tenant_id, agent_id)
        """
        from orchestrator.tools.executor import execute_tool

        received_kwargs: dict[str, Any] = {}

        async def mock_handler(**kwargs: Any) -> str:
            received_kwargs.update(kwargs)
            return "handler result"

        registry = {"test_tool": _make_tool(mock_handler)}
        tenant_id = uuid.uuid4()
        agent_id = uuid.uuid4()

        result = await execute_tool(
            self._tool_call(arguments),
            registry,
            tenant_id,
            agent_id,
            self._stub_audit_logger(),
        )
        return result, received_kwargs, tenant_id, agent_id

    @pytest.mark.asyncio
    async def test_tenant_id_injected_into_handler_kwargs(self) -> None:
        """Handler should receive tenant_id even though LLM didn't provide it."""
        result, kwargs, tenant_id, _ = await self._run('{"query": "hello world"}')

        assert result == "handler result"
        assert "tenant_id" in kwargs
        assert kwargs["tenant_id"] == str(tenant_id)

    @pytest.mark.asyncio
    async def test_agent_id_injected_into_handler_kwargs(self) -> None:
        """Handler should receive agent_id even though LLM didn't provide it."""
        _, kwargs, _, agent_id = await self._run()

        assert "agent_id" in kwargs
        assert kwargs["agent_id"] == str(agent_id)

    @pytest.mark.asyncio
    async def test_injected_ids_are_strings(self) -> None:
        """Injected tenant_id and agent_id should be strings, not UUIDs."""
        _, kwargs, _, _ = await self._run()

        assert isinstance(kwargs["tenant_id"], str)
        assert isinstance(kwargs["agent_id"], str)

    @pytest.mark.asyncio
    async def test_llm_provided_args_preserved(self) -> None:
        """Original LLM-provided args should still be present after injection."""
        _, kwargs, tenant_id, agent_id = await self._run('{"query": "search term from LLM"}')

        assert kwargs["query"] == "search term from LLM"
        assert kwargs["tenant_id"] == str(tenant_id)
        assert kwargs["agent_id"] == str(agent_id)

    @pytest.mark.asyncio
    async def test_injection_after_schema_validation(self) -> None:
        """Injection happens after validation — injected keys don't cause schema failures."""
        result, _, _, _ = await self._run()

        assert result == "handler result"
|
||||||
183
tests/unit/test_ingestion.py
Normal file
183
tests/unit/test_ingestion.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the KB ingestion pipeline.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- chunk_text: sliding window chunker produces correctly-sized, overlapping chunks
|
||||||
|
- ingest_document_pipeline: downloads file from MinIO, extracts, chunks, embeds, stores
|
||||||
|
- ingest_document_pipeline: sets status='error' on failure
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkText:
    """Behavioral tests for the sliding-window chunker."""

    def test_basic_chunking(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("a" * 1000, chunk_size=100, overlap=10)

        assert pieces
        assert all(len(piece) <= 100 for piece in pieces)

    def test_overlap_between_chunks(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        # Text with identifiable segments: 400 chars total.
        sample = "AAAA" * 50 + "BBBB" * 50
        pieces = chunk_text(sample, chunk_size=200, overlap=50)

        # With overlap=50, 400 chars cannot fit in a single window.
        assert len(pieces) >= 2

    def test_short_text_returns_one_chunk(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("Hello world", chunk_size=500, overlap=50)

        assert pieces == ["Hello world"]

    def test_empty_text_returns_empty_list(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        assert chunk_text("", chunk_size=500, overlap=50) == []

    def test_whitespace_only_returns_empty_list(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        assert chunk_text("   \n   ", chunk_size=500, overlap=50) == []

    def test_default_parameters(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("word " * 500)  # 2500 chars

        assert len(pieces) > 1
        # Default chunk_size is 500.
        assert all(len(piece) <= 500 for piece in pieces)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIngestDocumentPipeline:
    """
    Pipeline status-transition tests with DB, MinIO, extractor and embedder
    mocked. Shared mock construction is factored into _mock_doc/_mock_session.
    """

    @staticmethod
    def _mock_doc(document_id: str, tenant_id: str) -> MagicMock:
        """Build a KnowledgeBaseDocument stand-in for a file upload in 'processing'."""
        doc = MagicMock()
        doc.id = uuid.UUID(document_id)
        doc.tenant_id = uuid.UUID(tenant_id)
        doc.filename = "test.txt"
        doc.source_url = None
        doc.status = "processing"
        return doc

    @staticmethod
    def _mock_session(doc: MagicMock | None) -> AsyncMock:
        """Async-session mock whose document query resolves to *doc* (None = not found)."""
        session = AsyncMock()
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = doc
        session.execute = AsyncMock(return_value=query_result)
        session.commit = AsyncMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    @pytest.mark.asyncio
    async def test_file_upload_sets_status_ready(self) -> None:
        """Pipeline downloads file, extracts, chunks, embeds, stores, sets ready."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        tenant_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())
        mock_doc = self._mock_doc(document_id, tenant_id)

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
            patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
            patch("orchestrator.tools.ingest.extract_text", return_value="Test content " * 50),
            patch("orchestrator.tools.ingest.embed_texts", return_value=[[0.1] * 384]),
        ):
            mock_sf.return_value = self._mock_session(mock_doc)

            # MinIO returns file bytes.
            minio_client = MagicMock()
            response_obj = MagicMock()
            response_obj.read.return_value = b"Test content " * 50
            minio_client.get_object.return_value = response_obj
            mock_minio.return_value = minio_client

            await ingest_document_pipeline(document_id, tenant_id)

        # Status should be set to 'ready' on the document.
        assert mock_doc.status == "ready"
        assert mock_doc.chunk_count is not None

    @pytest.mark.asyncio
    async def test_pipeline_sets_error_on_exception(self) -> None:
        """Pipeline marks document as error when the MinIO download fails."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        tenant_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())
        mock_doc = self._mock_doc(document_id, tenant_id)

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
            patch("orchestrator.tools.ingest._get_minio_client") as mock_minio,
        ):
            mock_sf.return_value = self._mock_session(mock_doc)

            # MinIO raises an error.
            minio_client = MagicMock()
            minio_client.get_object.side_effect = Exception("MinIO connection failed")
            mock_minio.return_value = minio_client

            await ingest_document_pipeline(document_id, tenant_id)

        assert mock_doc.status == "error"
        assert mock_doc.error_message is not None

    @pytest.mark.asyncio
    async def test_document_not_found_is_no_op(self) -> None:
        """If document doesn't exist, pipeline exits gracefully."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as mock_sf,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
        ):
            mock_sf.return_value = self._mock_session(None)  # not found

            # Should not raise.
            await ingest_document_pipeline(str(uuid.uuid4()), str(uuid.uuid4()))
|
||||||
Reference in New Issue
Block a user