diff --git a/packages/orchestrator/orchestrator/agents/builder.py b/packages/orchestrator/orchestrator/agents/builder.py index 09fd98e..ec0c953 100644 --- a/packages/orchestrator/orchestrator/agents/builder.py +++ b/packages/orchestrator/orchestrator/agents/builder.py @@ -133,7 +133,7 @@ def generate_presigned_url(storage_key: str, expiry: int = 3600) -> str: return presigned_url -def build_system_prompt(agent: Agent) -> str: +def build_system_prompt(agent: Agent, channel: str = "") -> str: """ Assemble the full system prompt for an agent. @@ -142,9 +142,20 @@ def build_system_prompt(agent: Agent) -> str: - Identity section: name + role - Persona section (if agent.persona is non-empty) - AI transparency clause (always appended) + - WhatsApp tier-2 scoping (only when channel == "whatsapp" and tool_assignments set) + + WhatsApp Tier-2 Business-Function Scoping: + Per Meta 2026 policy, agents responding on WhatsApp must constrain themselves + to declared business functions. When channel == "whatsapp" and the agent has + non-empty tool_assignments, a scoping clause is appended that instructs the LLM + to handle only the declared topics and redirect off-topic requests. + This is the tier-2 gate (borderline messages reach the LLM with this constraint; + clearly off-topic messages are caught by the tier-1 keyword gate in the gateway). Args: - agent: ORM Agent instance. + agent: ORM Agent instance. + channel: Channel name (e.g. "slack", "whatsapp"). Used for tier-2 scoping. + Default empty string means no channel-specific scoping. Returns: A complete system prompt string ready to pass to the LLM. @@ -167,6 +178,17 @@ def build_system_prompt(agent: Agent) -> str: "If asked directly whether you are an AI, always respond honestly that you are an AI assistant." ) + # 5. WhatsApp tier-2 scoping — constrain LLM to declared business functions + if channel == "whatsapp": + functions: list[str] = getattr(agent, "tool_assignments", []) or [] + if functions: + topics = ", ".join(functions) + parts.append( + f"You are responding on WhatsApp. You only handle: {topics}. " + f"If the user asks about something outside these topics, " + f"politely redirect them to the allowed topics." + ) + return "\n\n".join(parts) @@ -204,12 +226,14 @@ def build_messages_with_memory( current_message: str, recent_messages: list[dict], relevant_context: list[str], + channel: str = "", ) -> list[dict]: """ Build an LLM messages array enriched with two-layer memory. Structure (in order): - 1. System message — agent identity, persona, AI transparency clause + 1. System message — agent identity, persona, AI transparency clause, + and optional WhatsApp tier-2 scoping (when channel == "whatsapp") 2. System message — long-term context from pgvector (ONLY if non-empty) Injected as a system message before the sliding window so the LLM has relevant background without it appearing in the conversation. @@ -226,11 +250,13 @@ def build_messages_with_memory( from Redis sliding window (oldest first). relevant_context: Long-term memory — list of content strings from pgvector similarity search (most relevant first). + channel: Channel name for tier-2 scoping (e.g. "whatsapp"). + Default empty string means no channel-specific scoping. Returns: List of message dicts suitable for an OpenAI-compatible API call. """ - system_prompt = build_system_prompt(agent) + system_prompt = build_system_prompt(agent, channel=channel) messages: list[dict] = [{"role": "system", "content": system_prompt}] # Inject long-term pgvector context as a system message BEFORE sliding window @@ -256,6 +282,7 @@ def build_messages_with_media( media_attachments: list[MediaAttachment], recent_messages: list[dict], relevant_context: list[str], + channel: str = "", ) -> list[dict]: """ Build an LLM messages array with memory enrichment AND multimodal media injection. @@ -277,7 +304,8 @@ def build_messages_with_media( gracefully rather than raising an error. Structure (in order): - 1. System message — agent identity, persona, AI transparency clause + 1. System message — agent identity, persona, AI transparency clause, + and optional WhatsApp tier-2 scoping (when channel == "whatsapp") 2. System message — long-term context (ONLY if non-empty) 3. Sliding window messages — recent history 4. Current user message (plain string or multipart content list) @@ -289,6 +317,8 @@ def build_messages_with_media( Empty list produces the same output as build_messages_with_memory(). recent_messages: Short-term memory — list of {"role", "content"} dicts. relevant_context: Long-term memory — list of content strings from pgvector. + channel: Channel name for tier-2 scoping (e.g. "whatsapp"). + Default empty string means no channel-specific scoping. Returns: List of message dicts suitable for an OpenAI/LiteLLM-compatible API call. @@ -301,6 +331,7 @@ def build_messages_with_media( current_message=current_message, recent_messages=recent_messages, relevant_context=relevant_context, + channel=channel, ) # If no media attachments, return the base messages unchanged diff --git a/packages/orchestrator/orchestrator/tasks.py b/packages/orchestrator/orchestrator/tasks.py index 84a6529..94a2616 100644 --- a/packages/orchestrator/orchestrator/tasks.py +++ b/packages/orchestrator/orchestrator/tasks.py @@ -35,6 +35,23 @@ Pending tool confirmation: On the next user message, if a pending confirmation exists: - "yes" → execute the pending tool and continue - "no" / anything else → cancel and inform the user + +Escalation pipeline (Phase 2 Plan 06): + Pre-check (before LLM call): + If Redis escalation_status_key == "escalated", return early assistant-mode reply. + This prevents the LLM from being called when a human has already taken over. + + Post-check (after LLM response): + check_escalation_rules() evaluates configured rules against conversation metadata. + If a rule matches AND agent.escalation_assignee is set, escalate_to_human() is + called and its return value replaces the LLM response. + +Outbound routing (Phase 2 Plan 06): + All response delivery goes through _send_response() which routes to: + - Slack: _update_slack_placeholder() via chat.update + - WhatsApp: send_whatsapp_message() via Meta Cloud API + handle_message now pops WhatsApp extras (phone_number_id, bot_token) and + passes them through to _process_message via the extras dict. """ from __future__ import annotations @@ -43,10 +60,25 @@ import asyncio import json import logging import uuid +from typing import Any + +import redis.asyncio as aioredis from gateway.channels.whatsapp import send_whatsapp_message +from orchestrator.agents.builder import build_messages_with_memory +from orchestrator.agents.runner import run_agent +from orchestrator.audit.logger import AuditLogger +from orchestrator.escalation.handler import check_escalation_rules, escalate_to_human from orchestrator.main import app +from orchestrator.memory.embedder import embed_text +from orchestrator.memory.long_term import retrieve_relevant +from orchestrator.memory.short_term import append_message, get_recent_messages +from orchestrator.tools.registry import get_tools_for_agent +from shared.config import settings +from shared.db import async_session_factory, engine from shared.models.message import KonstructMessage +from shared.redis_keys import escalation_status_key +from shared.rls import configure_rls_hook, current_tenant_id logger = logging.getLogger(__name__) @@ -152,23 +184,28 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped by the Channel Gateway after tenant resolution completes. The ``message_data`` dict MAY contain extra keys beyond KonstructMessage - fields. Specifically, the Slack handler injects: - - ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder - - ``channel_id``: Slack channel ID where the response should be posted + fields. Specifically: + - Slack handler injects: + ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder + ``channel_id``: Slack channel ID where the response should be posted + - WhatsApp gateway injects: + ``phone_number_id``: WhatsApp phone number ID for outbound messaging + ``bot_token``: WhatsApp access_token (injected as bot_token by gateway) - These are extracted before KonstructMessage validation and used to update - the placeholder with the real LLM response via chat.update. + These are extracted before KonstructMessage validation and used to route + outbound responses via _send_response(). Pipeline: - 1. Extract Slack reply metadata (placeholder_ts, channel_id) if present + 1. Extract channel reply metadata (Slack + WhatsApp extras) if present 2. Deserialize message_data -> KonstructMessage - 3. Run async agent pipeline via asyncio.run() - 4. If Slack metadata present: call chat.update to replace placeholder - 5. Return response dict + 3. Extract wa_id from sender.user_id for WhatsApp messages + 4. Build extras dict for channel-aware outbound routing + 5. Run async agent pipeline via asyncio.run() + 6. Return response dict Args: message_data: JSON-serializable dict. Must contain KonstructMessage - fields plus optional ``placeholder_ts`` and ``channel_id``. + fields plus optional channel-specific extras. Returns: Dict with keys: @@ -181,20 +218,39 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped placeholder_ts: str = message_data.pop("placeholder_ts", "") or "" channel_id: str = message_data.pop("channel_id", "") or "" + # Extract WhatsApp-specific extras before model validation + # The WhatsApp gateway injects these alongside the normalized KonstructMessage fields + phone_number_id: str = message_data.pop("phone_number_id", "") or "" + bot_token: str = message_data.pop("bot_token", "") or "" + try: msg = KonstructMessage.model_validate(message_data) except Exception as exc: logger.exception("Failed to deserialize KonstructMessage: %s", message_data) raise self.retry(exc=exc) - result = asyncio.run(_process_message(msg, placeholder_ts=placeholder_ts, channel_id=channel_id)) + # Extract wa_id from sender.user_id — WhatsApp normalizer sets sender.user_id + # to the wa_id (recipient phone number). This must happen AFTER model_validate. + wa_id: str = "" + if msg.channel == "whatsapp" and msg.sender and msg.sender.user_id: + wa_id = msg.sender.user_id + + # Build the unified extras dict for channel-aware outbound routing + extras: dict[str, Any] = { + "placeholder_ts": placeholder_ts, + "channel_id": channel_id, + "phone_number_id": phone_number_id, + "bot_token": bot_token, + "wa_id": wa_id, + } + + result = asyncio.run(_process_message(msg, extras=extras)) return result async def _process_message( msg: KonstructMessage, - placeholder_ts: str = "", - channel_id: str = "", + extras: dict[str, Any] | None = None, ) -> dict: """ Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool. @@ -215,24 +271,26 @@ async def _process_message( - Tool-call loop runs inside run_agent() — no separate Celery tasks - If run_agent returns a confirmation message: store pending action in Redis + Escalation pipeline (Phase 2 Plan 06 additions): + Pre-check: if Redis escalation_status_key == "escalated", return assistant-mode reply + Post-check: check_escalation_rules after LLM response; if triggered, escalate_to_human + + Outbound routing (Phase 2 Plan 06 additions): + All response delivery goes through _send_response() — never direct _update_slack_placeholder. + extras dict is passed through to _send_response for channel-aware routing. + Args: - msg: The deserialized KonstructMessage. - placeholder_ts: Slack message timestamp of the "Thinking..." placeholder. - channel_id: Slack channel ID for the chat.update call. + msg: The deserialized KonstructMessage. + extras: Channel-specific routing metadata. For Slack: placeholder_ts, channel_id, + bot_token. For WhatsApp: phone_number_id, bot_token, wa_id. Returns: Dict with message_id, response, and tenant_id. """ - from orchestrator.agents.builder import build_messages_with_memory - from orchestrator.agents.runner import run_agent - from orchestrator.audit.logger import AuditLogger - from orchestrator.memory.embedder import embed_text - from orchestrator.memory.long_term import retrieve_relevant - from orchestrator.memory.short_term import append_message, get_recent_messages - from orchestrator.tools.registry import get_tools_for_agent - from shared.db import async_session_factory, engine from shared.models.tenant import Agent - from shared.rls import configure_rls_hook, current_tenant_id + + if extras is None: + extras = {} if msg.tenant_id is None: logger.warning("Message %s has no tenant_id — cannot process", msg.id) @@ -265,8 +323,10 @@ async def _process_message( result = await session.execute(stmt) agent = result.scalars().first() - # Load the bot token for this tenant from channel_connections config - if agent is not None and placeholder_ts and channel_id: + # Load the Slack bot token for this tenant from channel_connections config. + # This is needed for escalation DM delivery even on WhatsApp messages — + # the escalation handler always sends via Slack DM to the assignee. + if agent is not None and (extras.get("placeholder_ts") and extras.get("channel_id")): from shared.models.tenant import ChannelConnection, ChannelTypeEnum conn_stmt = ( @@ -290,13 +350,9 @@ async def _process_message( msg.id, ) no_agent_response = "No active agent is configured for your workspace. Please contact your administrator." - if placeholder_ts and channel_id: - await _update_slack_placeholder( - bot_token=slack_bot_token, - channel_id=channel_id, - placeholder_ts=placeholder_ts, - text=no_agent_response, - ) + # Build response_extras for channel-aware delivery + response_extras = _build_response_extras(msg.channel, extras, slack_bot_token) + await _send_response(msg.channel, no_agent_response, response_extras) return { "message_id": msg.id, "response": no_agent_response, @@ -318,14 +374,30 @@ async def _process_message( # ------------------------------------------------------------------------- audit_logger = AuditLogger(session_factory=async_session_factory) + # Build response_extras dict used for all outbound delivery in this pipeline run. + # For Slack: merges DB-loaded slack_bot_token with incoming extras. + # For WhatsApp: extras already contain phone_number_id, bot_token, wa_id. + response_extras = _build_response_extras(msg.channel, extras, slack_bot_token) + + # ------------------------------------------------------------------------- + # Escalation pre-check — if conversation already escalated, reply in assistant mode + # ------------------------------------------------------------------------- + thread_key = msg.thread_id or user_id + esc_key = escalation_status_key(msg.tenant_id, thread_key) + pre_check_redis = aioredis.from_url(settings.redis_url) + try: + esc_status = await pre_check_redis.get(esc_key) + finally: + await pre_check_redis.aclose() + + if esc_status == b"escalated": + assistant_reply = "I've already connected you with a team member. They'll continue assisting you." + await _send_response(msg.channel, assistant_reply, response_extras) + return {"message_id": msg.id, "response": assistant_reply, "tenant_id": msg.tenant_id} + # ------------------------------------------------------------------------- # Pending tool confirmation check # ------------------------------------------------------------------------- - import redis.asyncio as aioredis - - from shared.config import settings - - redis_client = aioredis.from_url(settings.redis_url) pending_confirm_key = _PENDING_TOOL_KEY.format( tenant_id=msg.tenant_id, user_id=user_id, @@ -334,6 +406,7 @@ async def _process_message( response_text: str = "" handled_as_confirmation = False + redis_client = aioredis.from_url(settings.redis_url) try: pending_raw = await redis_client.get(pending_confirm_key) @@ -362,13 +435,7 @@ async def _process_message( await redis_client.aclose() if handled_as_confirmation: - if placeholder_ts and channel_id: - await _update_slack_placeholder( - bot_token=slack_bot_token, - channel_id=channel_id, - placeholder_ts=placeholder_ts, - text=response_text, - ) + await _send_response(msg.channel, response_text, response_extras) return { "message_id": msg.id, "response": response_text, @@ -412,6 +479,7 @@ async def _process_message( current_message=user_text, recent_messages=recent_messages, relevant_context=relevant_context, + channel=str(msg.channel) if msg.channel else "", ) # Build tool registry for this agent @@ -428,6 +496,37 @@ async def _process_message( tool_registry=tool_registry if tool_registry else None, ) + # ------------------------------------------------------------------------- + # Escalation post-check — evaluate rules against conversation metadata + # ------------------------------------------------------------------------- + conversation_metadata = _build_conversation_metadata(recent_messages, user_text) + + triggered_rule = check_escalation_rules( + agent=agent, + message_text=user_text, + conversation_metadata=conversation_metadata, + natural_lang_enabled=getattr(agent, "natural_language_escalation", False), + ) + + if triggered_rule and getattr(agent, "escalation_assignee", None): + escalation_redis = aioredis.from_url(settings.redis_url) + try: + response_text = await escalate_to_human( + tenant_id=msg.tenant_id, + agent=agent, + thread_id=thread_key, + trigger_reason=triggered_rule.get("condition", "rule triggered"), + recent_messages=recent_messages, + assignee_slack_user_id=agent.escalation_assignee, + bot_token=slack_bot_token, + redis=escalation_redis, + audit_logger=audit_logger, + user_id=user_id, + agent_id=agent_id_str, + ) + finally: + await escalation_redis.aclose() + # Check if the response is a tool confirmation request # The confirmation message template starts with a specific prefix is_confirmation_request = response_text.startswith("This action requires your approval") @@ -453,14 +552,8 @@ async def _process_message( len(relevant_context), ) - # Replace the "Thinking..." placeholder with the real response - if placeholder_ts and channel_id: - await _update_slack_placeholder( - bot_token=slack_bot_token, - channel_id=channel_id, - placeholder_ts=placeholder_ts, - text=response_text, - ) + # Send response via channel-aware routing + await _send_response(msg.channel, response_text, response_extras) # ------------------------------------------------------------------------- # Memory persistence (after LLM response) @@ -489,6 +582,74 @@ async def _process_message( } +def _build_response_extras( + channel: Any, + extras: dict[str, Any], + slack_bot_token: str, +) -> dict[str, Any]: + """ + Build the response_extras dict for channel-aware outbound delivery. + + For Slack: injects slack_bot_token into extras["bot_token"] so _send_response + can use it for chat.update calls. + For WhatsApp: extras already contain phone_number_id, bot_token (access_token), + and wa_id — no transformation needed. + + Args: + channel: Channel name from KonstructMessage.channel. + extras: Incoming extras from handle_message. + slack_bot_token: Bot token loaded from DB channel_connections. + + Returns: + Dict suitable for passing to _send_response. + """ + channel_str = str(channel) if channel else "" + if channel_str == "slack": + return { + "bot_token": slack_bot_token, + "channel_id": extras.get("channel_id", "") or "", + "placeholder_ts": extras.get("placeholder_ts", "") or "", + } + elif channel_str == "whatsapp": + return { + "phone_number_id": extras.get("phone_number_id", "") or "", + "bot_token": extras.get("bot_token", "") or "", + "wa_id": extras.get("wa_id", "") or "", + } + else: + return dict(extras) + + +def _build_conversation_metadata( + recent_messages: list[dict], + current_text: str, +) -> dict[str, Any]: + """ + Build conversation metadata dict for escalation rule evaluation. + + Scans recent messages and current_text for billing keywords and counts occurrences. + Returns a dict with: + - "billing_dispute" (bool): True if any billing keyword found in any message + - "attempts" (int): count of messages containing billing keywords + + This is the v1 keyword-based metadata detection (per STATE.md decisions). + + Args: + recent_messages: List of {"role", "content"} dicts from Redis sliding window. + current_text: The user's current message text. + + Returns: + Dict with billing_dispute and attempts keys. + """ + billing_keywords = {"billing", "invoice", "charge", "refund", "payment", "subscription"} + all_texts = [m.get("content", "") for m in recent_messages] + [current_text] + billing_count = sum(1 for t in all_texts if any(kw in t.lower() for kw in billing_keywords)) + return { + "billing_dispute": billing_count > 0, + "attempts": billing_count, + } + + async def _execute_pending_tool( pending_data: dict, tenant_uuid: uuid.UUID, @@ -526,7 +687,7 @@ def _extract_tool_name_from_confirmation(confirmation_message: str) -> str: async def _send_response( - channel: str, + channel: Any, text: str, extras: dict, ) -> None: @@ -545,7 +706,9 @@ async def _send_response( For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts`` For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id`` """ - if channel == "slack": + channel_str = str(channel) if channel else "" + + if channel_str == "slack": bot_token: str = extras.get("bot_token", "") or "" channel_id: str = extras.get("channel_id", "") or "" placeholder_ts: str = extras.get("placeholder_ts", "") or "" @@ -563,7 +726,7 @@ async def _send_response( text=text, ) - elif channel == "whatsapp": + elif channel_str == "whatsapp": phone_number_id: str = extras.get("phone_number_id", "") or "" access_token: str = extras.get("bot_token", "") or "" wa_id: str = extras.get("wa_id", "") or "" diff --git a/tests/unit/test_pipeline_wiring.py b/tests/unit/test_pipeline_wiring.py index 9eb6ddf..7add691 100644 --- a/tests/unit/test_pipeline_wiring.py +++ b/tests/unit/test_pipeline_wiring.py @@ -15,7 +15,6 @@ Tests verify: from __future__ import annotations import asyncio -import json import uuid from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -24,7 +23,7 @@ import pytest # --------------------------------------------------------------------------- -# Helpers +# Shared test helpers # --------------------------------------------------------------------------- @@ -93,6 +92,89 @@ def make_message_data( return data +def make_fake_redis(escalated: bool = False) -> AsyncMock: + """Build a fakeredis-like mock for the async redis client.""" + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None) + redis_mock.delete = AsyncMock() + redis_mock.set = AsyncMock() + redis_mock.setex = AsyncMock() + redis_mock.aclose = AsyncMock() + return redis_mock + + +def make_session_factory_mock(agent: Any) -> MagicMock: + """ + Build a mock for async_session_factory that returns the given agent. + + The returned mock is callable and its return value acts as an async + context manager that yields a session whose execute() returns the agent. + """ + # Build result chain: session.execute() -> result.scalars().first() -> agent + scalars_mock = MagicMock() + scalars_mock.first.return_value = agent + result_mock = MagicMock() + result_mock.scalars.return_value = scalars_mock + + session = AsyncMock() + session.execute = AsyncMock(return_value=result_mock) + + # Make session_factory() return an async context manager + cm = AsyncMock() + cm.__aenter__ = AsyncMock(return_value=session) + cm.__aexit__ = AsyncMock(return_value=False) + + factory = MagicMock() + factory.return_value = cm + return factory + + +def make_process_message_msg( + channel: str = "slack", + tenant_id: str | None = None, + user_id: str = "U123", + text: str = "hello", + thread_id: str = "thread_001", +) -> Any: + """Build a KonstructMessage for use with _process_message tests.""" + from shared.models.message import KonstructMessage + + tid = tenant_id or str(uuid.uuid4()) + data: dict[str, Any] = { + "id": str(uuid.uuid4()), + "tenant_id": tid, + "channel": channel, + "channel_metadata": {}, + "sender": {"user_id": user_id, "display_name": "Test User"}, + "content": {"text": text, "attachments": []}, + "timestamp": "2026-01-01T00:00:00Z", + "thread_id": thread_id, + "reply_to": None, + "context": {}, + } + return KonstructMessage.model_validate(data) + + +def _standard_process_patches( + agent: Any, + escalated: bool = False, + run_agent_return: str = "Agent reply", + escalation_rule: dict | None = None, + escalate_return: str = "I've brought in a team member", +) -> dict: + """ + Return a dict of patch targets and their corresponding mocks. + Callers use this with contextlib.ExitStack or individual patch() calls. + """ + return { + "redis": make_fake_redis(escalated=escalated), + "session_factory": make_session_factory_mock(agent), + "run_agent_return": run_agent_return, + "escalation_rule": escalation_rule, + "escalate_return": escalate_return, + } + + # --------------------------------------------------------------------------- # Task 1 Tests — handle_message extra-field extraction # --------------------------------------------------------------------------- @@ -242,44 +324,13 @@ class TestSendResponseRouting: class TestProcessMessageUsesSSendResponse: """_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder.""" - def _make_fake_redis(self, escalated: bool = False) -> AsyncMock: - """Build a fakeredis-like mock for the async client.""" - redis_mock = AsyncMock() - redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None) - redis_mock.delete = AsyncMock() - redis_mock.set = AsyncMock() - redis_mock.setex = AsyncMock() - redis_mock.aclose = AsyncMock() - return redis_mock - - def _make_session_mock(self, agent: Any) -> AsyncMock: - """Build a mock async DB session that returns the given agent.""" - session = AsyncMock() - # Scalars mock - scalars_mock = MagicMock() - scalars_mock.first.return_value = agent - result_mock = AsyncMock() - result_mock.scalars.return_value = scalars_mock - session.execute = AsyncMock(return_value=result_mock) - session.__aenter__ = AsyncMock(return_value=session) - session.__aexit__ = AsyncMock() - return session - @pytest.mark.asyncio async def test_process_message_calls_send_response_not_update_slack(self) -> None: """_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage agent = make_agent() - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data(channel="slack", tenant_id=tenant_id) - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) - + msg = make_process_message_msg(channel="slack") extras = { "placeholder_ts": "1234.5678", "channel_id": "C123", @@ -287,11 +338,14 @@ class TestProcessMessageUsesSSendResponse: "bot_token": "", } + redis_mock = make_fake_redis(escalated=False) + sf_mock = make_session_factory_mock(agent) + with ( patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send, patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update, patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -306,15 +360,8 @@ class TestProcessMessageUsesSSendResponse: patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.check_escalation_rules", return_value=None), ): - # Setup redis mock to return None for pre-check (not escalated) - redis_mock = self._make_fake_redis(escalated=False) mock_aioredis.from_url.return_value = redis_mock - # Setup session to return our agent - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() - await _process_message(msg, extras=extras) # _send_response MUST be called for delivery @@ -326,37 +373,30 @@ class TestProcessMessageUsesSSendResponse: async def test_whatsapp_extras_passed_to_send_response(self) -> None: """_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage agent = make_agent() - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data( - channel="whatsapp", - tenant_id=tenant_id, - user_id="15551234567", - ) - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) - + # user_id = "15551234567" -> becomes wa_id + msg = make_process_message_msg(channel="whatsapp", user_id="15551234567") extras = { "placeholder_ts": "", "channel_id": "", "phone_number_id": "9876543210", "bot_token": "EAAaccess", + "wa_id": "15551234567", } captured_extras: list[dict] = [] - async def capture_send(channel: str, text: str, ext: dict) -> None: - captured_extras.append({"channel": channel, "text": text, "extras": ext}) + async def capture_send(channel: Any, text: str, ext: dict) -> None: + captured_extras.append({"channel": str(channel), "text": text, "extras": ext}) + + redis_mock = make_fake_redis(escalated=False) + sf_mock = make_session_factory_mock(agent) with ( patch("orchestrator.tasks._send_response", side_effect=capture_send), patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -371,15 +411,7 @@ class TestProcessMessageUsesSSendResponse: patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.check_escalation_rules", return_value=None), ): - redis_mock = self._make_fake_redis(escalated=False) mock_aioredis.from_url.return_value = redis_mock - # Configure context manager for redis client used in finally blocks - redis_mock.__aenter__ = AsyncMock(return_value=redis_mock) - redis_mock.__aexit__ = AsyncMock() - - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() await _process_message(msg, extras=extras) @@ -389,7 +421,7 @@ class TestProcessMessageUsesSSendResponse: wa_extras = wa_calls[0]["extras"] assert wa_extras.get("phone_number_id") == "9876543210" assert wa_extras.get("bot_token") == "EAAaccess" - # wa_id must come from sender.user_id + # wa_id must come from extras (which was populated from sender.user_id in handle_message) assert wa_extras.get("wa_id") == "15551234567" @@ -401,42 +433,22 @@ class TestProcessMessageUsesSSendResponse: class TestEscalationPreCheck: """When escalation status is 'escalated' in Redis, _process_message must return early without LLM call.""" - def _make_session_mock(self, agent: Any) -> AsyncMock: - session = AsyncMock() - scalars_mock = MagicMock() - scalars_mock.first.return_value = agent - result_mock = AsyncMock() - result_mock.scalars.return_value = scalars_mock - session.execute = AsyncMock(return_value=result_mock) - session.__aenter__ = AsyncMock(return_value=session) - session.__aexit__ = AsyncMock() - return session - @pytest.mark.asyncio async def test_escalated_conversation_skips_run_agent(self) -> None: """When Redis shows 'escalated', run_agent must NOT be called.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage agent = make_agent() - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data(channel="slack", tenant_id=tenant_id) - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) - + msg = make_process_message_msg(channel="slack") extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} - redis_mock = AsyncMock() - redis_mock.get = AsyncMock(return_value=b"escalated") - redis_mock.aclose = AsyncMock() + redis_mock = make_fake_redis(escalated=True) + sf_mock = make_session_factory_mock(agent) with ( patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -444,12 +456,6 @@ class TestEscalationPreCheck: patch("orchestrator.tasks.AuditLogger"), ): mock_aioredis.from_url.return_value = redis_mock - redis_mock.__aenter__ = AsyncMock(return_value=redis_mock) - redis_mock.__aexit__ = AsyncMock() - - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() result = await _process_message(msg, extras=extras) @@ -468,42 +474,24 @@ class TestEscalationPreCheck: class TestEscalationPostCheck: """check_escalation_rules called after run_agent; escalate_to_human called when rule matches.""" - def _make_session_mock(self, agent: Any) -> AsyncMock: - session = AsyncMock() - scalars_mock = MagicMock() - scalars_mock.first.return_value = agent - result_mock = AsyncMock() - result_mock.scalars.return_value = scalars_mock - session.execute = AsyncMock(return_value=result_mock) - session.__aenter__ = AsyncMock(return_value=session) - session.__aexit__ = AsyncMock() - return session - @pytest.mark.asyncio async def test_check_escalation_rules_called_after_run_agent(self) -> None: """check_escalation_rules must be called after run_agent returns.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage - - agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]) - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data(channel="slack", tenant_id=tenant_id) - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) + agent = make_agent( + escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}] + ) + msg = make_process_message_msg(channel="slack") extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} - redis_mock = AsyncMock() - redis_mock.get = AsyncMock(return_value=None) - redis_mock.aclose = AsyncMock() + redis_mock = make_fake_redis(escalated=False) + sf_mock = make_session_factory_mock(agent) with ( patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -520,10 +508,6 @@ class TestEscalationPostCheck: ): mock_aioredis.from_url.return_value = redis_mock - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() - await _process_message(msg, extras=extras) # check_escalation_rules must have been called @@ -533,32 +517,22 @@ class TestEscalationPostCheck: async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None: """When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage agent = make_agent( escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], escalation_assignee="U_MANAGER", ) - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again") - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) - + msg = make_process_message_msg(channel="slack", text="refund issue again") extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} - - redis_mock = AsyncMock() - redis_mock.get = AsyncMock(return_value=None) - redis_mock.aclose = AsyncMock() + redis_mock = make_fake_redis(escalated=False) + sf_mock = make_session_factory_mock(agent) with ( patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -572,51 +546,41 @@ class TestEscalationPostCheck: patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule), - patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate, + patch( + "orchestrator.tasks.escalate_to_human", + new_callable=AsyncMock, + return_value="I've brought in a team member", + ) as mock_escalate, ): mock_aioredis.from_url.return_value = redis_mock - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() - result = await _process_message(msg, extras=extras) # escalate_to_human must be called mock_escalate.assert_called_once() # Response should be the escalation confirmation - assert "team member" in result["response"] or mock_escalate.return_value in result["response"] + assert "team member" in result["response"] @pytest.mark.asyncio async def test_escalate_to_human_not_called_when_no_assignee(self) -> None: """When rule matches but escalation_assignee is None, escalate_to_human must NOT be called.""" from orchestrator.tasks import _process_message - from shared.models.message import KonstructMessage agent = make_agent( escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], escalation_assignee=None, ) - tenant_id = str(uuid.uuid4()) - msg_data = make_message_data(channel="slack", tenant_id=tenant_id) - msg_data.pop("placeholder_ts", None) - msg_data.pop("channel_id", None) - msg_data.pop("phone_number_id", None) - msg_data.pop("bot_token", None) - msg = KonstructMessage.model_validate(msg_data) - + msg = make_process_message_msg(channel="slack") extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} - - redis_mock = AsyncMock() - redis_mock.get = AsyncMock(return_value=None) - redis_mock.aclose = AsyncMock() + redis_mock = make_fake_redis(escalated=False) + sf_mock = make_session_factory_mock(agent) with ( patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks.aioredis") as mock_aioredis, - patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.async_session_factory", sf_mock), patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.engine"), @@ -634,10 +598,6 @@ class TestEscalationPostCheck: ): mock_aioredis.from_url.return_value = redis_mock - session_mock = self._make_session_mock(agent) - mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) - mock_session_factory.return_value.__aexit__ = AsyncMock() - await _process_message(msg, extras=extras) # escalate_to_human must NOT be called — no assignee configured