feat(02-06): re-wire escalation and WhatsApp outbound routing in pipeline

- Move key imports to module level in tasks.py for testability and clarity
- Pop WhatsApp extras (phone_number_id, bot_token) in handle_message before model_validate
- Build unified extras dict and extract wa_id from sender.user_id
- Change _process_message signature to accept extras dict
- Add _build_response_extras() helper for channel-aware extras assembly
- Replace all _update_slack_placeholder calls in _process_message with _send_response()
- Add escalation pre-check: skip LLM when Redis escalation_status_key == 'escalated'
- Add escalation post-check: check_escalation_rules after run_agent; call escalate_to_human
  when rule matches and agent.escalation_assignee is set
- Add _build_conversation_metadata() helper (billing keyword v1 detection)
- Add channel parameter to build_system_prompt(), build_messages_with_memory(),
  build_messages_with_media() for WhatsApp tier-2 business-function scoping
- WhatsApp scoping appends 'You only handle: {topics}' when tool_assignments non-empty
- Pass msg.channel to build_messages_with_memory() in _process_message
- All 26 new tests pass; all existing escalation/WhatsApp tests pass (no regressions)
This commit is contained in:
2026-03-23 19:15:20 -06:00
parent 77c9cfc825
commit bd217a4113
3 changed files with 380 additions and 226 deletions

View File

@@ -15,7 +15,6 @@ Tests verify:
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@@ -24,7 +23,7 @@ import pytest
# ---------------------------------------------------------------------------
# Helpers
# Shared test helpers
# ---------------------------------------------------------------------------
@@ -93,6 +92,89 @@ def make_message_data(
return data
def make_fake_redis(escalated: bool = False) -> AsyncMock:
"""Build a fakeredis-like mock for the async redis client."""
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
redis_mock.delete = AsyncMock()
redis_mock.set = AsyncMock()
redis_mock.setex = AsyncMock()
redis_mock.aclose = AsyncMock()
return redis_mock
def make_session_factory_mock(agent: Any) -> MagicMock:
"""
Build a mock for async_session_factory that returns the given agent.
The returned mock is callable and its return value acts as an async
context manager that yields a session whose execute() returns the agent.
"""
# Build result chain: session.execute() -> result.scalars().first() -> agent
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = MagicMock()
result_mock.scalars.return_value = scalars_mock
session = AsyncMock()
session.execute = AsyncMock(return_value=result_mock)
# Make session_factory() return an async context manager
cm = AsyncMock()
cm.__aenter__ = AsyncMock(return_value=session)
cm.__aexit__ = AsyncMock(return_value=False)
factory = MagicMock()
factory.return_value = cm
return factory
def make_process_message_msg(
channel: str = "slack",
tenant_id: str | None = None,
user_id: str = "U123",
text: str = "hello",
thread_id: str = "thread_001",
) -> Any:
"""Build a KonstructMessage for use with _process_message tests."""
from shared.models.message import KonstructMessage
tid = tenant_id or str(uuid.uuid4())
data: dict[str, Any] = {
"id": str(uuid.uuid4()),
"tenant_id": tid,
"channel": channel,
"channel_metadata": {},
"sender": {"user_id": user_id, "display_name": "Test User"},
"content": {"text": text, "attachments": []},
"timestamp": "2026-01-01T00:00:00Z",
"thread_id": thread_id,
"reply_to": None,
"context": {},
}
return KonstructMessage.model_validate(data)
def _standard_process_patches(
agent: Any,
escalated: bool = False,
run_agent_return: str = "Agent reply",
escalation_rule: dict | None = None,
escalate_return: str = "I've brought in a team member",
) -> dict:
"""
Return a dict of patch targets and their corresponding mocks.
Callers use this with contextlib.ExitStack or individual patch() calls.
"""
return {
"redis": make_fake_redis(escalated=escalated),
"session_factory": make_session_factory_mock(agent),
"run_agent_return": run_agent_return,
"escalation_rule": escalation_rule,
"escalate_return": escalate_return,
}
# ---------------------------------------------------------------------------
# Task 1 Tests — handle_message extra-field extraction
# ---------------------------------------------------------------------------
@@ -242,44 +324,13 @@ class TestSendResponseRouting:
class TestProcessMessageUsesSSendResponse:
"""_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""
def _make_fake_redis(self, escalated: bool = False) -> AsyncMock:
"""Build a fakeredis-like mock for the async client."""
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
redis_mock.delete = AsyncMock()
redis_mock.set = AsyncMock()
redis_mock.setex = AsyncMock()
redis_mock.aclose = AsyncMock()
return redis_mock
def _make_session_mock(self, agent: Any) -> AsyncMock:
"""Build a mock async DB session that returns the given agent."""
session = AsyncMock()
# Scalars mock
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio
async def test_process_message_calls_send_response_not_update_slack(self) -> None:
"""_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent()
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
msg = make_process_message_msg(channel="slack")
extras = {
"placeholder_ts": "1234.5678",
"channel_id": "C123",
@@ -287,11 +338,14 @@ class TestProcessMessageUsesSSendResponse:
"bot_token": "",
}
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -306,15 +360,8 @@ class TestProcessMessageUsesSSendResponse:
patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
):
# Setup redis mock to return None for pre-check (not escalated)
redis_mock = self._make_fake_redis(escalated=False)
mock_aioredis.from_url.return_value = redis_mock
# Setup session to return our agent
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras)
# _send_response MUST be called for delivery
@@ -326,37 +373,30 @@ class TestProcessMessageUsesSSendResponse:
async def test_whatsapp_extras_passed_to_send_response(self) -> None:
"""_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent()
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(
channel="whatsapp",
tenant_id=tenant_id,
user_id="15551234567",
)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
# user_id = "15551234567" -> becomes wa_id
msg = make_process_message_msg(channel="whatsapp", user_id="15551234567")
extras = {
"placeholder_ts": "",
"channel_id": "",
"phone_number_id": "9876543210",
"bot_token": "EAAaccess",
"wa_id": "15551234567",
}
captured_extras: list[dict] = []
async def capture_send(channel: str, text: str, ext: dict) -> None:
captured_extras.append({"channel": channel, "text": text, "extras": ext})
async def capture_send(channel: Any, text: str, ext: dict) -> None:
captured_extras.append({"channel": str(channel), "text": text, "extras": ext})
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", side_effect=capture_send),
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -371,15 +411,7 @@ class TestProcessMessageUsesSSendResponse:
patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
):
redis_mock = self._make_fake_redis(escalated=False)
mock_aioredis.from_url.return_value = redis_mock
# Configure context manager for redis client used in finally blocks
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
redis_mock.__aexit__ = AsyncMock()
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras)
@@ -389,7 +421,7 @@ class TestProcessMessageUsesSSendResponse:
wa_extras = wa_calls[0]["extras"]
assert wa_extras.get("phone_number_id") == "9876543210"
assert wa_extras.get("bot_token") == "EAAaccess"
# wa_id must come from sender.user_id
# wa_id must come from extras (which was populated from sender.user_id in handle_message)
assert wa_extras.get("wa_id") == "15551234567"
@@ -401,42 +433,22 @@ class TestProcessMessageUsesSSendResponse:
class TestEscalationPreCheck:
"""When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""
def _make_session_mock(self, agent: Any) -> AsyncMock:
session = AsyncMock()
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio
async def test_escalated_conversation_skips_run_agent(self) -> None:
"""When Redis shows 'escalated', run_agent must NOT be called."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent()
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
msg = make_process_message_msg(channel="slack")
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=b"escalated")
redis_mock.aclose = AsyncMock()
redis_mock = make_fake_redis(escalated=True)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -444,12 +456,6 @@ class TestEscalationPreCheck:
patch("orchestrator.tasks.AuditLogger"),
):
mock_aioredis.from_url.return_value = redis_mock
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
redis_mock.__aexit__ = AsyncMock()
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
result = await _process_message(msg, extras=extras)
@@ -468,42 +474,24 @@ class TestEscalationPreCheck:
class TestEscalationPostCheck:
"""check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""
def _make_session_mock(self, agent: Any) -> AsyncMock:
session = AsyncMock()
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio
async def test_check_escalation_rules_called_after_run_agent(self) -> None:
"""check_escalation_rules must be called after run_agent returns."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}])
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]
)
msg = make_process_message_msg(channel="slack")
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=None)
redis_mock.aclose = AsyncMock()
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -520,10 +508,6 @@ class TestEscalationPostCheck:
):
mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras)
# check_escalation_rules must have been called
@@ -533,32 +517,22 @@ class TestEscalationPostCheck:
async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
"""When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
escalation_assignee="U_MANAGER",
)
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again")
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
msg = make_process_message_msg(channel="slack", text="refund issue again")
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=None)
redis_mock.aclose = AsyncMock()
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -572,51 +546,41 @@ class TestEscalationPostCheck:
patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate,
patch(
"orchestrator.tasks.escalate_to_human",
new_callable=AsyncMock,
return_value="I've brought in a team member",
) as mock_escalate,
):
mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
result = await _process_message(msg, extras=extras)
# escalate_to_human must be called
mock_escalate.assert_called_once()
# Response should be the escalation confirmation
assert "team member" in result["response"] or mock_escalate.return_value in result["response"]
assert "team member" in result["response"]
@pytest.mark.asyncio
async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
"""When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
escalation_assignee=None,
)
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
msg = make_process_message_msg(channel="slack")
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=None)
redis_mock.aclose = AsyncMock()
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"),
@@ -634,10 +598,6 @@ class TestEscalationPostCheck:
):
mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras)
# escalate_to_human must NOT be called — no assignee configured