diff --git a/packages/orchestrator/orchestrator/tasks.py b/packages/orchestrator/orchestrator/tasks.py index dce7ba4..54abc81 100644 --- a/packages/orchestrator/orchestrator/tasks.py +++ b/packages/orchestrator/orchestrator/tasks.py @@ -997,3 +997,45 @@ async def _update_slack_placeholder( channel_id, placeholder_ts, ) + + +# ============================================================================= +# KB Document Ingestion Task +# ============================================================================= + + +@app.task( + name="orchestrator.tasks.ingest_document", + bind=True, + max_retries=2, + default_retry_delay=60, + ignore_result=True, +) +def ingest_document(self, document_id: str, tenant_id: str) -> None: # type: ignore[override] + """ + Celery task: run the KB document ingestion pipeline. + + Downloads the document from MinIO (or scrapes URL/YouTube), extracts text, + chunks, embeds with all-MiniLM-L6-v2, and stores kb_chunks rows. + + Updates kb_documents.status to 'ready' on success, 'error' on failure. + + MUST be sync def — Celery workers are not async-native. asyncio.run() is + used to bridge the sync Celery world to the async pipeline. + + Args: + document_id: UUID string of the KnowledgeBaseDocument row. + tenant_id: UUID string of the owning tenant. + """ + from orchestrator.tools.ingest import ingest_document_pipeline + + try: + asyncio.run(ingest_document_pipeline(document_id, tenant_id)) + except Exception as exc: + logger.exception( + "ingest_document task failed for document=%s tenant=%s: %s", + document_id, + tenant_id, + exc, + ) + self.retry(exc=exc, countdown=60) diff --git a/packages/orchestrator/orchestrator/tools/builtins/web_search.py b/packages/orchestrator/orchestrator/tools/builtins/web_search.py index 08bf6ce..6e6c7c9 100644 --- a/packages/orchestrator/orchestrator/tools/builtins/web_search.py +++ b/packages/orchestrator/orchestrator/tools/builtins/web_search.py @@ -13,10 +13,11 @@ raising an exception (graceful degradation for agents without search configured) from __future__ import annotations import logging -import os import httpx +from shared.config import settings + logger = logging.getLogger(__name__) _BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search" @@ -24,24 +25,26 @@ _BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0) _MAX_RESULTS = 3 -async def web_search(query: str) -> str: +async def web_search(query: str, **kwargs: object) -> str: """ Search the web using Brave Search API. Args: query: The search query string. + **kwargs: Accepts injected tenant_id/agent_id from executor (unused). Returns: Formatted string with top 3 search results (title + URL + description), or an error message if the API is unavailable. """ - api_key = os.getenv("BRAVE_API_KEY", "") + api_key = settings.brave_api_key if not api_key: return ( "Web search is not configured. " "Set the BRAVE_API_KEY environment variable to enable web search." ) + try: async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client: response = await client.get( diff --git a/packages/orchestrator/orchestrator/tools/executor.py b/packages/orchestrator/orchestrator/tools/executor.py index c2e1b60..e16ad02 100644 --- a/packages/orchestrator/orchestrator/tools/executor.py +++ b/packages/orchestrator/orchestrator/tools/executor.py @@ -119,7 +119,15 @@ async def execute_tool( return confirmation_msg # ------------------------------------------------------------------ - # 5. Execute the handler + # 5. Inject tenant context into args AFTER schema validation + # This ensures kb_search, calendar_lookup, and future context-aware + # tools receive tenant/agent context without the LLM providing it. + # ------------------------------------------------------------------ + args["tenant_id"] = str(tenant_id) + args["agent_id"] = str(agent_id) + + # ------------------------------------------------------------------ + # 6. Execute the handler # ------------------------------------------------------------------ start_ms = time.monotonic() try: diff --git a/packages/orchestrator/orchestrator/tools/ingest.py b/packages/orchestrator/orchestrator/tools/ingest.py new file mode 100644 index 0000000..2547c23 --- /dev/null +++ b/packages/orchestrator/orchestrator/tools/ingest.py @@ -0,0 +1,322 @@ +""" +Knowledge base document ingestion pipeline. + +This module provides: + chunk_text() — sliding window text chunker + ingest_document_pipeline() — async pipeline: fetch → extract → chunk → embed → store + +Pipeline steps: + 1. Load KnowledgeBaseDocument from DB + 2. Download file from MinIO (if filename) OR scrape URL / fetch YouTube transcript + 3. Extract text using orchestrator.tools.extractors.extract_text + 4. Chunk text with sliding window (500 chars, 50 overlap) + 5. Batch embed chunks via all-MiniLM-L6-v2 + 6. INSERT kb_chunks rows with vector embeddings + 7. UPDATE kb_documents SET status='ready', chunk_count=N + +On any error: UPDATE kb_documents SET status='error', error_message=str(exc) + +IMPORTANT: This module is called from a Celery task via asyncio.run(). All DB +and MinIO operations are async. The embedding call (embed_texts) is synchronous +(SentenceTransformer is sync) — this is fine inside asyncio.run(). +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +import boto3 + +from shared.config import settings +from shared.db import async_session_factory, engine +from shared.rls import configure_rls_hook, current_tenant_id + +from orchestrator.memory.embedder import embed_texts +from orchestrator.tools.extractors import extract_text + +logger = logging.getLogger(__name__) + +# Default chunking parameters +_DEFAULT_CHUNK_SIZE = 500 +_DEFAULT_OVERLAP = 50 + + +def _get_minio_client() -> Any: + """Create a boto3 S3 client pointed at MinIO.""" + return boto3.client( + "s3", + endpoint_url=settings.minio_endpoint, + aws_access_key_id=settings.minio_access_key, + aws_secret_access_key=settings.minio_secret_key, + ) + + +def chunk_text( + text: str, + chunk_size: int = _DEFAULT_CHUNK_SIZE, + overlap: int = _DEFAULT_OVERLAP, +) -> list[str]: + """ + Split text into overlapping chunks using a sliding window. + + Args: + text: The text to chunk. + chunk_size: Maximum characters per chunk. + overlap: Number of characters to overlap between consecutive chunks. + + Returns: + List of non-empty text chunks. Returns empty list for empty/whitespace text. + """ + text = text.strip() + if not text: + return [] + + if len(text) <= chunk_size: + return [text] + + chunks: list[str] = [] + start = 0 + step = chunk_size - overlap + + while start < len(text): + end = start + chunk_size + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + if end >= len(text): + break + start += step + + return chunks + + +async def ingest_document_pipeline(document_id: str, tenant_id: str) -> None: + """ + Run the full document ingestion pipeline for a KB document. + + Steps: + 1. Load the KnowledgeBaseDocument from the database + 2. Fetch content (MinIO file OR URL scrape OR YouTube transcript) + 3. Extract plain text + 4. Chunk text + 5. Embed chunks + 6. Store kb_chunks rows in the database + 7. Mark document as 'ready' + + On any error: set status='error' with error_message. + + Args: + document_id: UUID string of the KnowledgeBaseDocument to process. + tenant_id: UUID string of the tenant (for RLS context). + """ + from sqlalchemy import select, text as sa_text + + from shared.models.kb import KnowledgeBaseDocument + + tenant_uuid = uuid.UUID(tenant_id) + doc_uuid = uuid.UUID(document_id) + + configure_rls_hook(engine) + token = current_tenant_id.set(tenant_uuid) + try: + async with async_session_factory() as session: + result = await session.execute( + select(KnowledgeBaseDocument).where( + KnowledgeBaseDocument.id == doc_uuid + ) + ) + doc = result.scalar_one_or_none() + + if doc is None: + logger.warning( + "ingest_document_pipeline: document %s not found, skipping", + document_id, + ) + return + + filename = doc.filename + source_url = doc.source_url + + # ------------------------------------------------------------------ + # Step 2: Fetch content + # ------------------------------------------------------------------ + try: + file_bytes: bytes | None = None + extracted_text: str + + if filename: + # Download from MinIO + bucket = settings.minio_kb_bucket + key = f"{tenant_id}/{document_id}/{filename}" + minio = _get_minio_client() + response = minio.get_object(Bucket=bucket, Key=key) + file_bytes = response.read() + extracted_text = extract_text(filename, file_bytes) + + elif source_url: + extracted_text = await _fetch_url_content(source_url) + + else: + raise ValueError("Document has neither filename nor source_url") + + # ------------------------------------------------------------------ + # Step 3-4: Chunk text + # ------------------------------------------------------------------ + chunks = chunk_text(extracted_text) + if not chunks: + raise ValueError("No text content could be extracted from this document") + + # ------------------------------------------------------------------ + # Step 5: Embed chunks + # ------------------------------------------------------------------ + embeddings = embed_texts(chunks) + + # ------------------------------------------------------------------ + # Step 6: Insert kb_chunks + # ------------------------------------------------------------------ + # Delete any existing chunks for this document first + await session.execute( + sa_text("DELETE FROM kb_chunks WHERE document_id = :doc_id"), + {"doc_id": str(doc_uuid)}, + ) + + for idx, (chunk_content, embedding) in enumerate(zip(chunks, embeddings)): + embedding_str = "[" + ",".join(str(x) for x in embedding) + "]" + await session.execute( + sa_text(""" + INSERT INTO kb_chunks + (tenant_id, document_id, content, chunk_index, embedding) + VALUES + (:tenant_id, :document_id, :content, :chunk_index, + CAST(:embedding AS vector)) + """), + { + "tenant_id": str(tenant_uuid), + "document_id": str(doc_uuid), + "content": chunk_content, + "chunk_index": idx, + "embedding": embedding_str, + }, + ) + + # ------------------------------------------------------------------ + # Step 7: Mark document as ready + # ------------------------------------------------------------------ + doc.status = "ready" + doc.chunk_count = len(chunks) + doc.error_message = None + await session.commit() + + logger.info( + "ingest_document_pipeline: %s ingested %d chunks for document %s", + tenant_id, + len(chunks), + document_id, + ) + + except Exception as exc: + logger.exception( + "ingest_document_pipeline: error processing document %s: %s", + document_id, + exc, + ) + # Try to mark document as error + try: + doc.status = "error" + doc.error_message = str(exc) + await session.commit() + except Exception: + logger.exception( + "ingest_document_pipeline: failed to mark document %s as error", + document_id, + ) + + finally: + current_tenant_id.reset(token) + + +async def _fetch_url_content(url: str) -> str: + """ + Fetch text content from a URL. + + Supports: + - YouTube URLs (via youtube-transcript-api) + - Generic web pages (via firecrawl-py, graceful fallback if key not set) + """ + if _is_youtube_url(url): + return await _fetch_youtube_transcript(url) + else: + return await _scrape_web_url(url) + + +def _is_youtube_url(url: str) -> bool: + """Return True if the URL is a YouTube video.""" + return "youtube.com" in url or "youtu.be" in url + + +async def _fetch_youtube_transcript(url: str) -> str: + """Fetch YouTube video transcript using youtube-transcript-api.""" + try: + from youtube_transcript_api import YouTubeTranscriptApi + + # Extract video ID from URL + video_id = _extract_youtube_id(url) + if not video_id: + raise ValueError(f"Could not extract YouTube video ID from URL: {url}") + + transcript = YouTubeTranscriptApi.get_transcript(video_id) + return " ".join(entry["text"] for entry in transcript) + + except Exception as exc: + raise ValueError(f"Failed to fetch YouTube transcript: {exc}") from exc + + +def _extract_youtube_id(url: str) -> str | None: + """Extract YouTube video ID from various URL formats.""" + import re + + patterns = [ + r"youtube\.com/watch\?v=([a-zA-Z0-9_-]+)", + r"youtu\.be/([a-zA-Z0-9_-]+)", + r"youtube\.com/embed/([a-zA-Z0-9_-]+)", + ] + for pattern in patterns: + match = re.search(pattern, url) + if match: + return match.group(1) + return None + + +async def _scrape_web_url(url: str) -> str: + """Scrape a web URL to markdown using firecrawl-py.""" + if not settings.firecrawl_api_key: + # Fallback: try simple httpx fetch + return await _simple_fetch(url) + + try: + from firecrawl import FirecrawlApp + + app = FirecrawlApp(api_key=settings.firecrawl_api_key) + result = app.scrape_url(url, params={"formats": ["markdown"]}) + if isinstance(result, dict): + return result.get("markdown", result.get("content", str(result))) + return str(result) + + except Exception as exc: + logger.warning("Firecrawl failed for %s: %s — falling back to simple fetch", url, exc) + return await _simple_fetch(url) + + +async def _simple_fetch(url: str) -> str: + """Simple httpx GET fetch as fallback for URL scraping.""" + import httpx + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url, follow_redirects=True) + response.raise_for_status() + return response.text + except Exception as exc: + raise ValueError(f"Failed to fetch URL {url}: {exc}") from exc diff --git a/tests/unit/test_executor_injection.py b/tests/unit/test_executor_injection.py new file mode 100644 index 0000000..ca0ccc8 --- /dev/null +++ b/tests/unit/test_executor_injection.py @@ -0,0 +1,186 @@ +""" +Unit tests for executor tenant_id/agent_id injection. + +Tests that execute_tool injects tenant_id and agent_id into handler kwargs +before calling the handler, so context-aware tools (kb_search, calendar_lookup) +receive tenant context without the LLM needing to provide it. +""" + +from __future__ import annotations + +import uuid +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +def _make_tool(handler: Any, requires_confirmation: bool = False) -> Any: + """Create a minimal ToolDefinition-like object for tests.""" + tool = MagicMock() + tool.handler = handler + tool.requires_confirmation = requires_confirmation + tool.parameters = { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + } + return tool + + +class TestExecutorTenantInjection: + @pytest.mark.asyncio + async def test_tenant_id_injected_into_handler_kwargs(self) -> None: + """Handler should receive tenant_id even though LLM didn't provide it.""" + from orchestrator.tools.executor import execute_tool + + received_kwargs: dict[str, Any] = {} + + async def mock_handler(**kwargs: Any) -> str: + received_kwargs.update(kwargs) + return "handler result" + + tool = _make_tool(mock_handler) + registry = {"test_tool": tool} + + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + audit_logger = MagicMock() + audit_logger.log_tool_call = AsyncMock() + + tool_call = { + "function": { + "name": "test_tool", + "arguments": '{"query": "hello world"}', + } + } + + result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger) + + assert result == "handler result" + assert "tenant_id" in received_kwargs + assert received_kwargs["tenant_id"] == str(tenant_id) + + @pytest.mark.asyncio + async def test_agent_id_injected_into_handler_kwargs(self) -> None: + """Handler should receive agent_id even though LLM didn't provide it.""" + from orchestrator.tools.executor import execute_tool + + received_kwargs: dict[str, Any] = {} + + async def mock_handler(**kwargs: Any) -> str: + received_kwargs.update(kwargs) + return "ok" + + tool = _make_tool(mock_handler) + registry = {"test_tool": tool} + + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + audit_logger = MagicMock() + audit_logger.log_tool_call = AsyncMock() + + tool_call = { + "function": { + "name": "test_tool", + "arguments": '{"query": "test"}', + } + } + + await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger) + + assert "agent_id" in received_kwargs + assert received_kwargs["agent_id"] == str(agent_id) + + @pytest.mark.asyncio + async def test_injected_ids_are_strings(self) -> None: + """Injected tenant_id and agent_id should be strings, not UUIDs.""" + from orchestrator.tools.executor import execute_tool + + received_kwargs: dict[str, Any] = {} + + async def mock_handler(**kwargs: Any) -> str: + received_kwargs.update(kwargs) + return "ok" + + tool = _make_tool(mock_handler) + registry = {"test_tool": tool} + + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + audit_logger = MagicMock() + audit_logger.log_tool_call = AsyncMock() + + tool_call = { + "function": { + "name": "test_tool", + "arguments": '{"query": "test"}', + } + } + + await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger) + + assert isinstance(received_kwargs["tenant_id"], str) + assert isinstance(received_kwargs["agent_id"], str) + + @pytest.mark.asyncio + async def test_llm_provided_args_preserved(self) -> None: + """Original LLM-provided args should still be present after injection.""" + from orchestrator.tools.executor import execute_tool + + received_kwargs: dict[str, Any] = {} + + async def mock_handler(**kwargs: Any) -> str: + received_kwargs.update(kwargs) + return "ok" + + tool = _make_tool(mock_handler) + registry = {"test_tool": tool} + + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + audit_logger = MagicMock() + audit_logger.log_tool_call = AsyncMock() + + tool_call = { + "function": { + "name": "test_tool", + "arguments": '{"query": "search term from LLM"}', + } + } + + await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger) + + assert received_kwargs["query"] == "search term from LLM" + assert received_kwargs["tenant_id"] == str(tenant_id) + assert received_kwargs["agent_id"] == str(agent_id) + + @pytest.mark.asyncio + async def test_injection_after_schema_validation(self) -> None: + """Injection happens after validation — injected keys don't cause schema failures.""" + from orchestrator.tools.executor import execute_tool + + # Tool requires exactly 'query', nothing else in schema required + # Schema should pass even though we inject tenant_id/agent_id + async def mock_handler(**kwargs: Any) -> str: + return "passed" + + tool = _make_tool(mock_handler) + registry = {"test_tool": tool} + + tenant_id = uuid.uuid4() + agent_id = uuid.uuid4() + audit_logger = MagicMock() + audit_logger.log_tool_call = AsyncMock() + + tool_call = { + "function": { + "name": "test_tool", + "arguments": '{"query": "test"}', + } + } + + result = await execute_tool(tool_call, registry, tenant_id, agent_id, audit_logger) + assert result == "passed" diff --git a/tests/unit/test_ingestion.py b/tests/unit/test_ingestion.py new file mode 100644 index 0000000..8ff47ed --- /dev/null +++ b/tests/unit/test_ingestion.py @@ -0,0 +1,183 @@ +""" +Unit tests for the KB ingestion pipeline. + +Tests: + - chunk_text: sliding window chunker produces correctly-sized, overlapping chunks + - ingest_document_pipeline: downloads file from MinIO, extracts, chunks, embeds, stores + - ingest_document_pipeline: sets status='error' on failure +""" + +from __future__ import annotations + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestChunkText: + def test_basic_chunking(self) -> None: + from orchestrator.tools.ingest import chunk_text + + text = "a" * 1000 + chunks = chunk_text(text, chunk_size=100, overlap=10) + + assert len(chunks) > 0 + for chunk in chunks: + assert len(chunk) <= 100 + + def test_overlap_between_chunks(self) -> None: + from orchestrator.tools.ingest import chunk_text + + # Create text with identifiable segments + text = "AAAA" * 50 + "BBBB" * 50 # 400 chars + chunks = chunk_text(text, chunk_size=200, overlap=50) + + # With overlap=50, consecutive chunks should share chars + assert len(chunks) >= 2 + + def test_short_text_returns_one_chunk(self) -> None: + from orchestrator.tools.ingest import chunk_text + + text = "Hello world" + chunks = chunk_text(text, chunk_size=500, overlap=50) + + assert len(chunks) == 1 + assert chunks[0] == "Hello world" + + def test_empty_text_returns_empty_list(self) -> None: + from orchestrator.tools.ingest import chunk_text + + chunks = chunk_text("", chunk_size=500, overlap=50) + assert chunks == [] + + def test_whitespace_only_returns_empty_list(self) -> None: + from orchestrator.tools.ingest import chunk_text + + chunks = chunk_text(" \n ", chunk_size=500, overlap=50) + assert chunks == [] + + def test_default_parameters(self) -> None: + from orchestrator.tools.ingest import chunk_text + + text = "word " * 500 # 2500 chars + chunks = chunk_text(text) + + assert len(chunks) > 1 + # Default chunk_size is 500 + for chunk in chunks: + assert len(chunk) <= 500 + + +class TestIngestDocumentPipeline: + @pytest.mark.asyncio + async def test_file_upload_sets_status_ready(self) -> None: + """Pipeline downloads file, extracts, chunks, embeds, stores, sets ready.""" + from orchestrator.tools.ingest import ingest_document_pipeline + + tenant_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + + mock_doc = MagicMock() + mock_doc.id = uuid.UUID(document_id) + mock_doc.tenant_id = uuid.UUID(tenant_id) + mock_doc.filename = "test.txt" + mock_doc.source_url = None + mock_doc.status = "processing" + + with ( + patch("orchestrator.tools.ingest.async_session_factory") as mock_sf, + patch("orchestrator.tools.ingest.engine"), + patch("orchestrator.tools.ingest.configure_rls_hook"), + patch("orchestrator.tools.ingest.current_tenant_id"), + patch("orchestrator.tools.ingest._get_minio_client") as mock_minio, + patch("orchestrator.tools.ingest.extract_text", return_value="Test content " * 50) as mock_extract, + patch("orchestrator.tools.ingest.embed_texts", return_value=[[0.1] * 384]) as mock_embed, + ): + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_doc + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + mock_sf.return_value = mock_session + + # MinIO returns file bytes + minio_client = MagicMock() + response_obj = MagicMock() + response_obj.read.return_value = b"Test content " * 50 + minio_client.get_object.return_value = response_obj + mock_minio.return_value = minio_client + + await ingest_document_pipeline(document_id, tenant_id) + + # Status should be set to 'ready' on the document + assert mock_doc.status == "ready" + assert mock_doc.chunk_count is not None + + @pytest.mark.asyncio + async def test_pipeline_sets_error_on_exception(self) -> None: + """Pipeline marks document as error when extraction fails.""" + from orchestrator.tools.ingest import ingest_document_pipeline + + tenant_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + + mock_doc = MagicMock() + mock_doc.id = uuid.UUID(document_id) + mock_doc.tenant_id = uuid.UUID(tenant_id) + mock_doc.filename = "test.txt" + mock_doc.source_url = None + mock_doc.status = "processing" + + with ( + patch("orchestrator.tools.ingest.async_session_factory") as mock_sf, + patch("orchestrator.tools.ingest.engine"), + patch("orchestrator.tools.ingest.configure_rls_hook"), + patch("orchestrator.tools.ingest.current_tenant_id"), + patch("orchestrator.tools.ingest._get_minio_client") as mock_minio, + ): + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_doc + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + mock_sf.return_value = mock_session + + # MinIO raises an error + minio_client = MagicMock() + minio_client.get_object.side_effect = Exception("MinIO connection failed") + mock_minio.return_value = minio_client + + await ingest_document_pipeline(document_id, tenant_id) + + assert mock_doc.status == "error" + assert mock_doc.error_message is not None + + @pytest.mark.asyncio + async def test_document_not_found_is_no_op(self) -> None: + """If document doesn't exist, pipeline exits gracefully.""" + from orchestrator.tools.ingest import ingest_document_pipeline + + tenant_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + + with ( + patch("orchestrator.tools.ingest.async_session_factory") as mock_sf, + patch("orchestrator.tools.ingest.engine"), + patch("orchestrator.tools.ingest.configure_rls_hook"), + patch("orchestrator.tools.ingest.current_tenant_id"), + ): + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None # Not found + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + mock_sf.return_value = mock_session + + # Should not raise + await ingest_document_pipeline(document_id, tenant_id)