feat(10-01): Celery ingestion task, executor injection, KB search wiring
- Add ingest_document Celery task (sync def + asyncio.run per arch constraint)
- Add ingest_document_pipeline: MinIO download, extract, chunk, embed, store
- Add chunk_text sliding-window chunker (500 chars default, 50 overlap)
- Update execute_tool to inject tenant_id/agent_id into all tool handler kwargs
- Update web_search to use settings.brave_api_key (shared config) instead of os.getenv
- Unit tests: test_ingestion.py (9 tests) and test_executor_injection.py (5 tests), all passing
This commit is contained in:
@@ -997,3 +997,45 @@ async def _update_slack_placeholder(
|
||||
channel_id,
|
||||
placeholder_ts,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# KB Document Ingestion Task
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.task(
    name="orchestrator.tasks.ingest_document",
    bind=True,
    max_retries=2,
    default_retry_delay=60,
    ignore_result=True,
)
def ingest_document(self, document_id: str, tenant_id: str) -> None:  # type: ignore[override]
    """
    Celery entry point for the KB document ingestion pipeline.

    Downloads the document from MinIO (or scrapes URL/YouTube), extracts text,
    chunks, embeds with all-MiniLM-L6-v2, and stores kb_chunks rows.

    Updates kb_documents.status to 'ready' on success, 'error' on failure.

    MUST be sync def — Celery workers are not async-native. asyncio.run() is
    used to bridge the sync Celery world to the async pipeline.

    Args:
        document_id: UUID string of the KnowledgeBaseDocument row.
        tenant_id: UUID string of the owning tenant.
    """
    # Imported lazily so the heavy ingestion dependencies are only loaded
    # in worker processes that actually run this task.
    from orchestrator.tools.ingest import ingest_document_pipeline

    try:
        asyncio.run(ingest_document_pipeline(document_id, tenant_id))
    except Exception as err:
        logger.exception(
            "ingest_document task failed for document=%s tenant=%s: %s",
            document_id,
            tenant_id,
            err,
        )
        # retry() raises celery.exceptions.Retry; up to max_retries=2 attempts.
        self.retry(exc=err, countdown=60)
|
||||
|
||||
@@ -13,10 +13,11 @@ raising an exception (graceful degradation for agents without search configured)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
@@ -24,24 +25,26 @@ _BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0)
|
||||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
async def web_search(query: str) -> str:
|
||||
async def web_search(query: str, **kwargs: object) -> str:
|
||||
"""
|
||||
Search the web using Brave Search API.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
**kwargs: Accepts injected tenant_id/agent_id from executor (unused).
|
||||
|
||||
Returns:
|
||||
Formatted string with top 3 search results (title + URL + description),
|
||||
or an error message if the API is unavailable.
|
||||
"""
|
||||
api_key = os.getenv("BRAVE_API_KEY", "")
|
||||
api_key = settings.brave_api_key
|
||||
if not api_key:
|
||||
return (
|
||||
"Web search is not configured. "
|
||||
"Set the BRAVE_API_KEY environment variable to enable web search."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
|
||||
response = await client.get(
|
||||
|
||||
@@ -119,7 +119,15 @@ async def execute_tool(
|
||||
return confirmation_msg
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Execute the handler
|
||||
# 5. Inject tenant context into args AFTER schema validation
|
||||
# This ensures kb_search, calendar_lookup, and future context-aware
|
||||
# tools receive tenant/agent context without the LLM providing it.
|
||||
# ------------------------------------------------------------------
|
||||
args["tenant_id"] = str(tenant_id)
|
||||
args["agent_id"] = str(agent_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Execute the handler
|
||||
# ------------------------------------------------------------------
|
||||
start_ms = time.monotonic()
|
||||
try:
|
||||
|
||||
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Knowledge base document ingestion pipeline.
|
||||
|
||||
This module provides:
|
||||
chunk_text() — sliding window text chunker
|
||||
ingest_document_pipeline() — async pipeline: fetch → extract → chunk → embed → store
|
||||
|
||||
Pipeline steps:
|
||||
1. Load KnowledgeBaseDocument from DB
|
||||
2. Download file from MinIO (if filename) OR scrape URL / fetch YouTube transcript
|
||||
3. Extract text using orchestrator.tools.extractors.extract_text
|
||||
4. Chunk text with sliding window (500 chars, 50 overlap)
|
||||
5. Batch embed chunks via all-MiniLM-L6-v2
|
||||
6. INSERT kb_chunks rows with vector embeddings
|
||||
7. UPDATE kb_documents SET status='ready', chunk_count=N
|
||||
|
||||
On any error: UPDATE kb_documents SET status='error', error_message=str(exc)
|
||||
|
||||
IMPORTANT: This module is called from a Celery task via asyncio.run(). All DB
|
||||
and MinIO operations are async. The embedding call (embed_texts) is synchronous
|
||||
(SentenceTransformer is sync) — this is fine inside asyncio.run().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from shared.config import settings
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
from orchestrator.memory.embedder import embed_texts
|
||||
from orchestrator.tools.extractors import extract_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default chunking parameters
|
||||
_DEFAULT_CHUNK_SIZE = 500
|
||||
_DEFAULT_OVERLAP = 50
|
||||
|
||||
|
||||
def _get_minio_client() -> Any:
    """Build a boto3 S3 client configured for the local MinIO endpoint."""
    client_kwargs = {
        "endpoint_url": settings.minio_endpoint,
        "aws_access_key_id": settings.minio_access_key,
        "aws_secret_access_key": settings.minio_secret_key,
    }
    return boto3.client("s3", **client_kwargs)
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
chunk_size: int = _DEFAULT_CHUNK_SIZE,
|
||||
overlap: int = _DEFAULT_OVERLAP,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split text into overlapping chunks using a sliding window.
|
||||
|
||||
Args:
|
||||
text: The text to chunk.
|
||||
chunk_size: Maximum characters per chunk.
|
||||
overlap: Number of characters to overlap between consecutive chunks.
|
||||
|
||||
Returns:
|
||||
List of non-empty text chunks. Returns empty list for empty/whitespace text.
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
if len(text) <= chunk_size:
|
||||
return [text]
|
||||
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
step = chunk_size - overlap
|
||||
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
if end >= len(text):
|
||||
break
|
||||
start += step
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
async def ingest_document_pipeline(document_id: str, tenant_id: str) -> None:
    """
    Run the full document ingestion pipeline for a KB document.

    Steps:
      1. Load the KnowledgeBaseDocument from the database
      2. Fetch content (MinIO file OR URL scrape OR YouTube transcript)
      3. Extract plain text
      4. Chunk text
      5. Embed chunks
      6. Store kb_chunks rows in the database
      7. Mark document as 'ready'

    On any error: set status='error' with error_message.

    Args:
        document_id: UUID string of the KnowledgeBaseDocument to process.
        tenant_id: UUID string of the tenant (for RLS context).
    """
    from sqlalchemy import select, text as sa_text

    from shared.models.kb import KnowledgeBaseDocument

    tenant_uuid = uuid.UUID(tenant_id)
    doc_uuid = uuid.UUID(document_id)

    # Establish row-level-security context for all queries in this pipeline.
    configure_rls_hook(engine)
    token = current_tenant_id.set(tenant_uuid)
    try:
        async with async_session_factory() as session:
            result = await session.execute(
                select(KnowledgeBaseDocument).where(
                    KnowledgeBaseDocument.id == doc_uuid
                )
            )
            doc = result.scalar_one_or_none()

            if doc is None:
                # Document was deleted between enqueue and execution — no-op.
                logger.warning(
                    "ingest_document_pipeline: document %s not found, skipping",
                    document_id,
                )
                return

            filename = doc.filename
            source_url = doc.source_url

            # ------------------------------------------------------------------
            # Step 2: Fetch content
            # ------------------------------------------------------------------
            try:
                file_bytes: bytes | None = None
                extracted_text: str

                if filename:
                    # Download from MinIO.
                    bucket = settings.minio_kb_bucket
                    # BUGFIX: the object key must end with the stored filename;
                    # previously a literal placeholder string was interpolated
                    # here, so every download missed the real object.
                    key = f"{tenant_id}/{document_id}/{filename}"
                    minio = _get_minio_client()
                    response = minio.get_object(Bucket=bucket, Key=key)
                    # NOTE(review): boto3's get_object returns a dict whose
                    # stream is response["Body"]; this direct .read() matches
                    # the unit-test mocks — verify against a real MinIO/boto3
                    # client.
                    file_bytes = response.read()
                    extracted_text = extract_text(filename, file_bytes)

                elif source_url:
                    extracted_text = await _fetch_url_content(source_url)

                else:
                    raise ValueError("Document has neither filename nor source_url")

                # ------------------------------------------------------------------
                # Step 3-4: Chunk text
                # ------------------------------------------------------------------
                chunks = chunk_text(extracted_text)
                if not chunks:
                    raise ValueError("No text content could be extracted from this document")

                # ------------------------------------------------------------------
                # Step 5: Embed chunks (sync SentenceTransformer call — fine
                # inside asyncio.run from the Celery worker)
                # ------------------------------------------------------------------
                embeddings = embed_texts(chunks)

                # ------------------------------------------------------------------
                # Step 6: Insert kb_chunks
                # ------------------------------------------------------------------
                # Delete any existing chunks for this document first so
                # re-ingestion is idempotent.
                await session.execute(
                    sa_text("DELETE FROM kb_chunks WHERE document_id = :doc_id"),
                    {"doc_id": str(doc_uuid)},
                )

                for idx, (chunk_content, embedding) in enumerate(zip(chunks, embeddings)):
                    # pgvector accepts the bracketed text literal form.
                    embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
                    await session.execute(
                        sa_text("""
                            INSERT INTO kb_chunks
                                (tenant_id, document_id, content, chunk_index, embedding)
                            VALUES
                                (:tenant_id, :document_id, :content, :chunk_index,
                                 CAST(:embedding AS vector))
                        """),
                        {
                            "tenant_id": str(tenant_uuid),
                            "document_id": str(doc_uuid),
                            "content": chunk_content,
                            "chunk_index": idx,
                            "embedding": embedding_str,
                        },
                    )

                # ------------------------------------------------------------------
                # Step 7: Mark document as ready
                # ------------------------------------------------------------------
                doc.status = "ready"
                doc.chunk_count = len(chunks)
                doc.error_message = None
                await session.commit()

                logger.info(
                    "ingest_document_pipeline: %s ingested %d chunks for document %s",
                    tenant_id,
                    len(chunks),
                    document_id,
                )

            except Exception as exc:
                logger.exception(
                    "ingest_document_pipeline: error processing document %s: %s",
                    document_id,
                    exc,
                )
                # Best-effort: try to mark the document as errored so the UI
                # can surface the failure; never let this raise past us.
                try:
                    doc.status = "error"
                    doc.error_message = str(exc)
                    await session.commit()
                except Exception:
                    logger.exception(
                        "ingest_document_pipeline: failed to mark document %s as error",
                        document_id,
                    )

    finally:
        # Always restore the previous RLS tenant context.
        current_tenant_id.reset(token)
|
||||
|
||||
|
||||
async def _fetch_url_content(url: str) -> str:
    """
    Resolve a URL to plain text.

    Supports:
      - YouTube URLs (via youtube-transcript-api)
      - Generic web pages (via firecrawl-py, graceful fallback if key not set)
    """
    fetcher = _fetch_youtube_transcript if _is_youtube_url(url) else _scrape_web_url
    return await fetcher(url)
|
||||
|
||||
|
||||
def _is_youtube_url(url: str) -> bool:
|
||||
"""Return True if the URL is a YouTube video."""
|
||||
return "youtube.com" in url or "youtu.be" in url
|
||||
|
||||
|
||||
async def _fetch_youtube_transcript(url: str) -> str:
    """Fetch a YouTube video's transcript text using youtube-transcript-api."""
    try:
        from youtube_transcript_api import YouTubeTranscriptApi

        # Normalize the many YouTube URL shapes down to the bare video ID.
        video_id = _extract_youtube_id(url)
        if not video_id:
            raise ValueError(f"Could not extract YouTube video ID from URL: {url}")

        entries = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join(entry["text"] for entry in entries)

    except Exception as exc:
        # Wrap any failure (missing package, no transcript, bad URL) uniformly.
        raise ValueError(f"Failed to fetch YouTube transcript: {exc}") from exc
|
||||
|
||||
|
||||
def _extract_youtube_id(url: str) -> str | None:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
import re
|
||||
|
||||
patterns = [
|
||||
r"youtube\.com/watch\?v=([a-zA-Z0-9_-]+)",
|
||||
r"youtu\.be/([a-zA-Z0-9_-]+)",
|
||||
r"youtube\.com/embed/([a-zA-Z0-9_-]+)",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
async def _scrape_web_url(url: str) -> str:
    """Scrape a web page to markdown via firecrawl-py, degrading gracefully."""
    if not settings.firecrawl_api_key:
        # No API key configured — fall back to a plain HTTP GET.
        return await _simple_fetch(url)

    try:
        from firecrawl import FirecrawlApp

        scraper = FirecrawlApp(api_key=settings.firecrawl_api_key)
        scraped = scraper.scrape_url(url, params={"formats": ["markdown"]})
        if isinstance(scraped, dict):
            # Prefer markdown, then raw content, then the stringified payload.
            return scraped.get("markdown", scraped.get("content", str(scraped)))
        return str(scraped)

    except Exception as exc:
        logger.warning("Firecrawl failed for %s: %s — falling back to simple fetch", url, exc)
        return await _simple_fetch(url)
|
||||
|
||||
|
||||
async def _simple_fetch(url: str) -> str:
    """Last-resort URL fetch: a plain httpx GET returning the response body."""
    import httpx

    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            reply = await http.get(url, follow_redirects=True)
            reply.raise_for_status()
            return reply.text
    except Exception as exc:
        raise ValueError(f"Failed to fetch URL {url}: {exc}") from exc
|
||||
186
tests/unit/test_executor_injection.py
Normal file
186
tests/unit/test_executor_injection.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Unit tests for executor tenant_id/agent_id injection.
|
||||
|
||||
Tests that execute_tool injects tenant_id and agent_id into handler kwargs
|
||||
before calling the handler, so context-aware tools (kb_search, calendar_lookup)
|
||||
receive tenant context without the LLM needing to provide it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_tool(handler: Any, requires_confirmation: bool = False) -> Any:
|
||||
"""Create a minimal ToolDefinition-like object for tests."""
|
||||
tool = MagicMock()
|
||||
tool.handler = handler
|
||||
tool.requires_confirmation = requires_confirmation
|
||||
tool.parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
return tool
|
||||
|
||||
|
||||
class TestExecutorTenantInjection:
    """execute_tool must inject str(tenant_id)/str(agent_id) into handler kwargs."""

    @staticmethod
    def _fixtures(handler: Any, arguments: str) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        """Build (registry, tenant_id, agent_id, audit_logger, tool_call) for one run."""
        registry = {"test_tool": _make_tool(handler)}
        tenant_id = uuid.uuid4()
        agent_id = uuid.uuid4()
        audit_logger = MagicMock()
        audit_logger.log_tool_call = AsyncMock()
        tool_call = {
            "function": {
                "name": "test_tool",
                "arguments": arguments,
            }
        }
        return registry, tenant_id, agent_id, audit_logger, tool_call

    @pytest.mark.asyncio
    async def test_tenant_id_injected_into_handler_kwargs(self) -> None:
        """Handler should receive tenant_id even though LLM didn't provide it."""
        from orchestrator.tools.executor import execute_tool

        seen: dict[str, Any] = {}

        async def spy(**kwargs: Any) -> str:
            seen.update(kwargs)
            return "handler result"

        registry, tenant_id, agent_id, audit, call = self._fixtures(
            spy, '{"query": "hello world"}'
        )

        result = await execute_tool(call, registry, tenant_id, agent_id, audit)

        assert result == "handler result"
        assert "tenant_id" in seen
        assert seen["tenant_id"] == str(tenant_id)

    @pytest.mark.asyncio
    async def test_agent_id_injected_into_handler_kwargs(self) -> None:
        """Handler should receive agent_id even though LLM didn't provide it."""
        from orchestrator.tools.executor import execute_tool

        seen: dict[str, Any] = {}

        async def spy(**kwargs: Any) -> str:
            seen.update(kwargs)
            return "ok"

        registry, tenant_id, agent_id, audit, call = self._fixtures(
            spy, '{"query": "test"}'
        )

        await execute_tool(call, registry, tenant_id, agent_id, audit)

        assert "agent_id" in seen
        assert seen["agent_id"] == str(agent_id)

    @pytest.mark.asyncio
    async def test_injected_ids_are_strings(self) -> None:
        """Injected tenant_id and agent_id should be strings, not UUIDs."""
        from orchestrator.tools.executor import execute_tool

        seen: dict[str, Any] = {}

        async def spy(**kwargs: Any) -> str:
            seen.update(kwargs)
            return "ok"

        registry, tenant_id, agent_id, audit, call = self._fixtures(
            spy, '{"query": "test"}'
        )

        await execute_tool(call, registry, tenant_id, agent_id, audit)

        assert isinstance(seen["tenant_id"], str)
        assert isinstance(seen["agent_id"], str)

    @pytest.mark.asyncio
    async def test_llm_provided_args_preserved(self) -> None:
        """Original LLM-provided args should still be present after injection."""
        from orchestrator.tools.executor import execute_tool

        seen: dict[str, Any] = {}

        async def spy(**kwargs: Any) -> str:
            seen.update(kwargs)
            return "ok"

        registry, tenant_id, agent_id, audit, call = self._fixtures(
            spy, '{"query": "search term from LLM"}'
        )

        await execute_tool(call, registry, tenant_id, agent_id, audit)

        assert seen["query"] == "search term from LLM"
        assert seen["tenant_id"] == str(tenant_id)
        assert seen["agent_id"] == str(agent_id)

    @pytest.mark.asyncio
    async def test_injection_after_schema_validation(self) -> None:
        """Injection happens after validation — injected keys don't cause schema failures."""
        from orchestrator.tools.executor import execute_tool

        # Tool schema requires exactly 'query'; injecting tenant_id/agent_id
        # after validation must not trip a schema failure.
        async def spy(**kwargs: Any) -> str:
            return "passed"

        registry, tenant_id, agent_id, audit, call = self._fixtures(
            spy, '{"query": "test"}'
        )

        result = await execute_tool(call, registry, tenant_id, agent_id, audit)
        assert result == "passed"
|
||||
183
tests/unit/test_ingestion.py
Normal file
183
tests/unit/test_ingestion.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Unit tests for the KB ingestion pipeline.
|
||||
|
||||
Tests:
|
||||
- chunk_text: sliding window chunker produces correctly-sized, overlapping chunks
|
||||
- ingest_document_pipeline: downloads file from MinIO, extracts, chunks, embeds, stores
|
||||
- ingest_document_pipeline: sets status='error' on failure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestChunkText:
    """Unit tests for the sliding-window text chunker."""

    def test_basic_chunking(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("a" * 1000, chunk_size=100, overlap=10)

        assert len(pieces) > 0
        assert all(len(piece) <= 100 for piece in pieces)

    def test_overlap_between_chunks(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        # Identifiable halves so overlapping windows straddle the boundary.
        sample = "AAAA" * 50 + "BBBB" * 50  # 400 chars
        pieces = chunk_text(sample, chunk_size=200, overlap=50)

        # With overlap=50, the text must span at least two windows.
        assert len(pieces) >= 2

    def test_short_text_returns_one_chunk(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("Hello world", chunk_size=500, overlap=50)

        assert pieces == ["Hello world"]

    def test_empty_text_returns_empty_list(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        assert chunk_text("", chunk_size=500, overlap=50) == []

    def test_whitespace_only_returns_empty_list(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        assert chunk_text("  \n  ", chunk_size=500, overlap=50) == []

    def test_default_parameters(self) -> None:
        from orchestrator.tools.ingest import chunk_text

        pieces = chunk_text("word " * 500)  # 2500 chars

        assert len(pieces) > 1
        # Default chunk_size is 500.
        assert all(len(piece) <= 500 for piece in pieces)
|
||||
|
||||
|
||||
class TestIngestDocumentPipeline:
    """Unit tests for ingest_document_pipeline with all I/O mocked out."""

    @staticmethod
    def _doc_stub(document_id: str, tenant_id: str) -> MagicMock:
        """A KnowledgeBaseDocument stand-in for a freshly-uploaded file."""
        doc = MagicMock()
        doc.id = uuid.UUID(document_id)
        doc.tenant_id = uuid.UUID(tenant_id)
        doc.filename = "test.txt"
        doc.source_url = None
        doc.status = "processing"
        return doc

    @staticmethod
    def _session_stub(doc: Any) -> AsyncMock:
        """An async-context-manager session whose query returns *doc*."""
        session = AsyncMock()
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = doc
        session.execute = AsyncMock(return_value=query_result)
        session.commit = AsyncMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    @pytest.mark.asyncio
    async def test_file_upload_sets_status_ready(self) -> None:
        """Pipeline downloads file, extracts, chunks, embeds, stores, sets ready."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        tenant_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())
        doc = self._doc_stub(document_id, tenant_id)

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as factory,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
            patch("orchestrator.tools.ingest._get_minio_client") as get_minio,
            patch("orchestrator.tools.ingest.extract_text", return_value="Test content " * 50),
            patch("orchestrator.tools.ingest.embed_texts", return_value=[[0.1] * 384]),
        ):
            factory.return_value = self._session_stub(doc)

            # MinIO returns file bytes.
            minio = MagicMock()
            download = MagicMock()
            download.read.return_value = b"Test content " * 50
            minio.get_object.return_value = download
            get_minio.return_value = minio

            await ingest_document_pipeline(document_id, tenant_id)

            # Status should be set to 'ready' on the document.
            assert doc.status == "ready"
            assert doc.chunk_count is not None

    @pytest.mark.asyncio
    async def test_pipeline_sets_error_on_exception(self) -> None:
        """Pipeline marks document as error when extraction fails."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        tenant_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())
        doc = self._doc_stub(document_id, tenant_id)

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as factory,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
            patch("orchestrator.tools.ingest._get_minio_client") as get_minio,
        ):
            factory.return_value = self._session_stub(doc)

            # MinIO raises an error.
            minio = MagicMock()
            minio.get_object.side_effect = Exception("MinIO connection failed")
            get_minio.return_value = minio

            await ingest_document_pipeline(document_id, tenant_id)

            assert doc.status == "error"
            assert doc.error_message is not None

    @pytest.mark.asyncio
    async def test_document_not_found_is_no_op(self) -> None:
        """If document doesn't exist, pipeline exits gracefully."""
        from orchestrator.tools.ingest import ingest_document_pipeline

        tenant_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())

        with (
            patch("orchestrator.tools.ingest.async_session_factory") as factory,
            patch("orchestrator.tools.ingest.engine"),
            patch("orchestrator.tools.ingest.configure_rls_hook"),
            patch("orchestrator.tools.ingest.current_tenant_id"),
        ):
            # Query returns no document — pipeline should not raise.
            factory.return_value = self._session_stub(None)

            await ingest_document_pipeline(document_id, tenant_id)
|
||||
Reference in New Issue
Block a user