feat(10-01): Celery ingestion task, executor injection, KB search wiring
- Add ingest_document Celery task (sync def + asyncio.run per arch constraint) - Add ingest_document_pipeline: MinIO download, extract, chunk, embed, store - Add chunk_text sliding window chunker (500 chars default, 50 overlap) - Update execute_tool to inject tenant_id/agent_id into all tool handler kwargs - Update web_search to use settings.brave_api_key (shared config) not os.getenv - Unit tests: test_ingestion.py (9 tests) and test_executor_injection.py (5 tests) all pass
This commit is contained in:
@@ -997,3 +997,45 @@ async def _update_slack_placeholder(
|
||||
channel_id,
|
||||
placeholder_ts,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# KB Document Ingestion Task
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.task(
|
||||
name="orchestrator.tasks.ingest_document",
|
||||
bind=True,
|
||||
max_retries=2,
|
||||
default_retry_delay=60,
|
||||
ignore_result=True,
|
||||
)
|
||||
def ingest_document(self, document_id: str, tenant_id: str) -> None: # type: ignore[override]
|
||||
"""
|
||||
Celery task: run the KB document ingestion pipeline.
|
||||
|
||||
Downloads the document from MinIO (or scrapes URL/YouTube), extracts text,
|
||||
chunks, embeds with all-MiniLM-L6-v2, and stores kb_chunks rows.
|
||||
|
||||
Updates kb_documents.status to 'ready' on success, 'error' on failure.
|
||||
|
||||
MUST be sync def — Celery workers are not async-native. asyncio.run() is
|
||||
used to bridge the sync Celery world to the async pipeline.
|
||||
|
||||
Args:
|
||||
document_id: UUID string of the KnowledgeBaseDocument row.
|
||||
tenant_id: UUID string of the owning tenant.
|
||||
"""
|
||||
from orchestrator.tools.ingest import ingest_document_pipeline
|
||||
|
||||
try:
|
||||
asyncio.run(ingest_document_pipeline(document_id, tenant_id))
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"ingest_document task failed for document=%s tenant=%s: %s",
|
||||
document_id,
|
||||
tenant_id,
|
||||
exc,
|
||||
)
|
||||
self.retry(exc=exc, countdown=60)
|
||||
|
||||
@@ -13,10 +13,11 @@ raising an exception (graceful degradation for agents without search configured)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
@@ -24,24 +25,26 @@ _BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0)
|
||||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
async def web_search(query: str) -> str:
|
||||
async def web_search(query: str, **kwargs: object) -> str:
|
||||
"""
|
||||
Search the web using Brave Search API.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
**kwargs: Accepts injected tenant_id/agent_id from executor (unused).
|
||||
|
||||
Returns:
|
||||
Formatted string with top 3 search results (title + URL + description),
|
||||
or an error message if the API is unavailable.
|
||||
"""
|
||||
api_key = os.getenv("BRAVE_API_KEY", "")
|
||||
api_key = settings.brave_api_key
|
||||
if not api_key:
|
||||
return (
|
||||
"Web search is not configured. "
|
||||
"Set the BRAVE_API_KEY environment variable to enable web search."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
|
||||
response = await client.get(
|
||||
|
||||
@@ -119,7 +119,15 @@ async def execute_tool(
|
||||
return confirmation_msg
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Execute the handler
|
||||
# 5. Inject tenant context into args AFTER schema validation
|
||||
# This ensures kb_search, calendar_lookup, and future context-aware
|
||||
# tools receive tenant/agent context without the LLM providing it.
|
||||
# ------------------------------------------------------------------
|
||||
args["tenant_id"] = str(tenant_id)
|
||||
args["agent_id"] = str(agent_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Execute the handler
|
||||
# ------------------------------------------------------------------
|
||||
start_ms = time.monotonic()
|
||||
try:
|
||||
|
||||
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
322
packages/orchestrator/orchestrator/tools/ingest.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Knowledge base document ingestion pipeline.
|
||||
|
||||
This module provides:
|
||||
chunk_text() — sliding window text chunker
|
||||
ingest_document_pipeline() — async pipeline: fetch → extract → chunk → embed → store
|
||||
|
||||
Pipeline steps:
|
||||
1. Load KnowledgeBaseDocument from DB
|
||||
2. Download file from MinIO (if filename) OR scrape URL / fetch YouTube transcript
|
||||
3. Extract text using orchestrator.tools.extractors.extract_text
|
||||
4. Chunk text with sliding window (500 chars, 50 overlap)
|
||||
5. Batch embed chunks via all-MiniLM-L6-v2
|
||||
6. INSERT kb_chunks rows with vector embeddings
|
||||
7. UPDATE kb_documents SET status='ready', chunk_count=N
|
||||
|
||||
On any error: UPDATE kb_documents SET status='error', error_message=str(exc)
|
||||
|
||||
IMPORTANT: This module is called from a Celery task via asyncio.run(). All DB
|
||||
and MinIO operations are async. The embedding call (embed_texts) is synchronous
|
||||
(SentenceTransformer is sync) — this is fine inside asyncio.run().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from shared.config import settings
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
from orchestrator.memory.embedder import embed_texts
|
||||
from orchestrator.tools.extractors import extract_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default chunking parameters
|
||||
_DEFAULT_CHUNK_SIZE = 500
|
||||
_DEFAULT_OVERLAP = 50
|
||||
|
||||
|
||||
def _get_minio_client() -> Any:
|
||||
"""Create a boto3 S3 client pointed at MinIO."""
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=settings.minio_endpoint,
|
||||
aws_access_key_id=settings.minio_access_key,
|
||||
aws_secret_access_key=settings.minio_secret_key,
|
||||
)
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
chunk_size: int = _DEFAULT_CHUNK_SIZE,
|
||||
overlap: int = _DEFAULT_OVERLAP,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split text into overlapping chunks using a sliding window.
|
||||
|
||||
Args:
|
||||
text: The text to chunk.
|
||||
chunk_size: Maximum characters per chunk.
|
||||
overlap: Number of characters to overlap between consecutive chunks.
|
||||
|
||||
Returns:
|
||||
List of non-empty text chunks. Returns empty list for empty/whitespace text.
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
if len(text) <= chunk_size:
|
||||
return [text]
|
||||
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
step = chunk_size - overlap
|
||||
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
if end >= len(text):
|
||||
break
|
||||
start += step
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
async def ingest_document_pipeline(document_id: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Run the full document ingestion pipeline for a KB document.
|
||||
|
||||
Steps:
|
||||
1. Load the KnowledgeBaseDocument from the database
|
||||
2. Fetch content (MinIO file OR URL scrape OR YouTube transcript)
|
||||
3. Extract plain text
|
||||
4. Chunk text
|
||||
5. Embed chunks
|
||||
6. Store kb_chunks rows in the database
|
||||
7. Mark document as 'ready'
|
||||
|
||||
On any error: set status='error' with error_message.
|
||||
|
||||
Args:
|
||||
document_id: UUID string of the KnowledgeBaseDocument to process.
|
||||
tenant_id: UUID string of the tenant (for RLS context).
|
||||
"""
|
||||
from sqlalchemy import select, text as sa_text
|
||||
|
||||
from shared.models.kb import KnowledgeBaseDocument
|
||||
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
doc_uuid = uuid.UUID(document_id)
|
||||
|
||||
configure_rls_hook(engine)
|
||||
token = current_tenant_id.set(tenant_uuid)
|
||||
try:
|
||||
async with async_session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(KnowledgeBaseDocument).where(
|
||||
KnowledgeBaseDocument.id == doc_uuid
|
||||
)
|
||||
)
|
||||
doc = result.scalar_one_or_none()
|
||||
|
||||
if doc is None:
|
||||
logger.warning(
|
||||
"ingest_document_pipeline: document %s not found, skipping",
|
||||
document_id,
|
||||
)
|
||||
return
|
||||
|
||||
filename = doc.filename
|
||||
source_url = doc.source_url
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Fetch content
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
file_bytes: bytes | None = None
|
||||
extracted_text: str
|
||||
|
||||
if filename:
|
||||
# Download from MinIO
|
||||
bucket = settings.minio_kb_bucket
|
||||
key = f"{tenant_id}/{document_id}/{filename}"
|
||||
minio = _get_minio_client()
|
||||
response = minio.get_object(Bucket=bucket, Key=key)
|
||||
file_bytes = response.read()
|
||||
extracted_text = extract_text(filename, file_bytes)
|
||||
|
||||
elif source_url:
|
||||
extracted_text = await _fetch_url_content(source_url)
|
||||
|
||||
else:
|
||||
raise ValueError("Document has neither filename nor source_url")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3-4: Chunk text
|
||||
# ------------------------------------------------------------------
|
||||
chunks = chunk_text(extracted_text)
|
||||
if not chunks:
|
||||
raise ValueError("No text content could be extracted from this document")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5: Embed chunks
|
||||
# ------------------------------------------------------------------
|
||||
embeddings = embed_texts(chunks)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6: Insert kb_chunks
|
||||
# ------------------------------------------------------------------
|
||||
# Delete any existing chunks for this document first
|
||||
await session.execute(
|
||||
sa_text("DELETE FROM kb_chunks WHERE document_id = :doc_id"),
|
||||
{"doc_id": str(doc_uuid)},
|
||||
)
|
||||
|
||||
for idx, (chunk_content, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
embedding_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
await session.execute(
|
||||
sa_text("""
|
||||
INSERT INTO kb_chunks
|
||||
(tenant_id, document_id, content, chunk_index, embedding)
|
||||
VALUES
|
||||
(:tenant_id, :document_id, :content, :chunk_index,
|
||||
CAST(:embedding AS vector))
|
||||
"""),
|
||||
{
|
||||
"tenant_id": str(tenant_uuid),
|
||||
"document_id": str(doc_uuid),
|
||||
"content": chunk_content,
|
||||
"chunk_index": idx,
|
||||
"embedding": embedding_str,
|
||||
},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 7: Mark document as ready
|
||||
# ------------------------------------------------------------------
|
||||
doc.status = "ready"
|
||||
doc.chunk_count = len(chunks)
|
||||
doc.error_message = None
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"ingest_document_pipeline: %s ingested %d chunks for document %s",
|
||||
tenant_id,
|
||||
len(chunks),
|
||||
document_id,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"ingest_document_pipeline: error processing document %s: %s",
|
||||
document_id,
|
||||
exc,
|
||||
)
|
||||
# Try to mark document as error
|
||||
try:
|
||||
doc.status = "error"
|
||||
doc.error_message = str(exc)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"ingest_document_pipeline: failed to mark document %s as error",
|
||||
document_id,
|
||||
)
|
||||
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
async def _fetch_url_content(url: str) -> str:
|
||||
"""
|
||||
Fetch text content from a URL.
|
||||
|
||||
Supports:
|
||||
- YouTube URLs (via youtube-transcript-api)
|
||||
- Generic web pages (via firecrawl-py, graceful fallback if key not set)
|
||||
"""
|
||||
if _is_youtube_url(url):
|
||||
return await _fetch_youtube_transcript(url)
|
||||
else:
|
||||
return await _scrape_web_url(url)
|
||||
|
||||
|
||||
def _is_youtube_url(url: str) -> bool:
|
||||
"""Return True if the URL is a YouTube video."""
|
||||
return "youtube.com" in url or "youtu.be" in url
|
||||
|
||||
|
||||
async def _fetch_youtube_transcript(url: str) -> str:
|
||||
"""Fetch YouTube video transcript using youtube-transcript-api."""
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
# Extract video ID from URL
|
||||
video_id = _extract_youtube_id(url)
|
||||
if not video_id:
|
||||
raise ValueError(f"Could not extract YouTube video ID from URL: {url}")
|
||||
|
||||
transcript = YouTubeTranscriptApi.get_transcript(video_id)
|
||||
return " ".join(entry["text"] for entry in transcript)
|
||||
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to fetch YouTube transcript: {exc}") from exc
|
||||
|
||||
|
||||
def _extract_youtube_id(url: str) -> str | None:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
import re
|
||||
|
||||
patterns = [
|
||||
r"youtube\.com/watch\?v=([a-zA-Z0-9_-]+)",
|
||||
r"youtu\.be/([a-zA-Z0-9_-]+)",
|
||||
r"youtube\.com/embed/([a-zA-Z0-9_-]+)",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
async def _scrape_web_url(url: str) -> str:
|
||||
"""Scrape a web URL to markdown using firecrawl-py."""
|
||||
if not settings.firecrawl_api_key:
|
||||
# Fallback: try simple httpx fetch
|
||||
return await _simple_fetch(url)
|
||||
|
||||
try:
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
app = FirecrawlApp(api_key=settings.firecrawl_api_key)
|
||||
result = app.scrape_url(url, params={"formats": ["markdown"]})
|
||||
if isinstance(result, dict):
|
||||
return result.get("markdown", result.get("content", str(result)))
|
||||
return str(result)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Firecrawl failed for %s: %s — falling back to simple fetch", url, exc)
|
||||
return await _simple_fetch(url)
|
||||
|
||||
|
||||
async def _simple_fetch(url: str) -> str:
|
||||
"""Simple httpx GET fetch as fallback for URL scraping."""
|
||||
import httpx
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to fetch URL {url}: {exc}") from exc
|
||||
Reference in New Issue
Block a user