From 77c9cfc825d89fa2e0c335e53c2fd31c871b0ed1 Mon Sep 17 00:00:00 2001 From: Adolfo Delorenzo Date: Mon, 23 Mar 2026 19:08:59 -0600 Subject: [PATCH] test(02-06): add failing tests for escalation wiring and WhatsApp outbound routing - Tests for handle_message WhatsApp extra extraction (phone_number_id, bot_token) - Tests for _send_response routing to Slack and WhatsApp - Tests for _process_message using _send_response (not _update_slack_placeholder directly) - Tests for escalation pre-check (skip LLM when already escalated) - Tests for escalation post-check (check_escalation_rules + escalate_to_human) - Tests for _build_conversation_metadata billing keyword extraction - Tests for build_system_prompt WhatsApp tier-2 scoping (Task 2) - Tests for build_messages_with_memory channel parameter passthrough --- tests/unit/test_pipeline_wiring.py | 813 +++++++++++++++++++++++++++++ 1 file changed, 813 insertions(+) create mode 100644 tests/unit/test_pipeline_wiring.py diff --git a/tests/unit/test_pipeline_wiring.py b/tests/unit/test_pipeline_wiring.py new file mode 100644 index 0000000..9eb6ddf --- /dev/null +++ b/tests/unit/test_pipeline_wiring.py @@ -0,0 +1,813 @@ +""" +Unit tests for pipeline wiring: escalation + outbound routing re-wiring (Plan 02-06). + +Tests verify: +- handle_message pops WhatsApp extras (phone_number_id, bot_token) before model_validate +- _process_message calls _send_response (not _update_slack_placeholder) for all delivery points +- Escalation pre-check: already-escalated conversations skip LLM call +- Escalation post-check: check_escalation_rules called after LLM response +- escalate_to_human called when rule matches and assignee is configured +- WhatsApp extras flow through to _send_response correctly +- build_system_prompt appends WhatsApp business-function scoping (Task 2) +- build_messages_with_memory/media pass channel through +""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_agent( + *, + name: str = "TestAgent", + role: str = "Support", + persona: str = "", + system_prompt: str = "", + tool_assignments: list[str] | None = None, + escalation_rules: list[dict] | None = None, + escalation_assignee: str | None = None, + natural_language_escalation: bool = False, + model_preference: str = "gpt-4o", +) -> MagicMock: + """Create a mock Agent with configurable fields.""" + agent = MagicMock() + agent.id = uuid.uuid4() + agent.name = name + agent.role = role + agent.persona = persona + agent.system_prompt = system_prompt + agent.tool_assignments = tool_assignments or [] + agent.escalation_rules = escalation_rules or [] + agent.escalation_assignee = escalation_assignee + agent.natural_language_escalation = natural_language_escalation + agent.model_preference = model_preference + agent.is_active = True + return agent + + +def make_message_data( + *, + channel: str = "slack", + tenant_id: str | None = None, + text: str = "hello", + user_id: str = "U123", + placeholder_ts: str = "1234.5678", + channel_id: str = "C123", + phone_number_id: str = "", + bot_token: str = "", + thread_id: str = "thread_001", +) -> dict: + """Build a message_data dict as the gateway would inject it.""" + tid = tenant_id or str(uuid.uuid4()) + data: dict[str, Any] = { + "id": str(uuid.uuid4()), + "tenant_id": tid, + "channel": channel, + "channel_metadata": {}, + "sender": {"user_id": user_id, "display_name": "Test User"}, + "content": {"text": text, "attachments": []}, + "timestamp": "2026-01-01T00:00:00Z", + "thread_id": thread_id, + "reply_to": None, + "context": {}, + } + if placeholder_ts: + data["placeholder_ts"] = placeholder_ts + if channel_id: + data["channel_id"] = channel_id + if phone_number_id: + data["phone_number_id"] = phone_number_id + if bot_token: + data["bot_token"] = bot_token + return data + + +# --------------------------------------------------------------------------- +# Task 1 Tests — handle_message extra-field extraction +# --------------------------------------------------------------------------- + + +class TestHandleMessageExtraExtraction: + """handle_message must pop WhatsApp and Slack extras before model_validate.""" + + def test_pops_phone_number_id_before_validate(self) -> None: + """phone_number_id must be removed before KonstructMessage.model_validate.""" + from shared.models.message import KonstructMessage + + data = make_message_data(channel="whatsapp", phone_number_id="1234567890") + assert "phone_number_id" in data + + # Simulate what handle_message does: pop extras + placeholder_ts: str = data.pop("placeholder_ts", "") or "" + channel_id: str = data.pop("channel_id", "") or "" + phone_number_id: str = data.pop("phone_number_id", "") or "" + bot_token: str = data.pop("bot_token", "") or "" + + assert phone_number_id == "1234567890" + assert "phone_number_id" not in data + # Should validate cleanly without extra fields + msg = KonstructMessage.model_validate(data) + assert msg.channel == "whatsapp" + + def test_pops_bot_token_before_validate(self) -> None: + """bot_token (WhatsApp access_token) must be removed before model_validate.""" + from shared.models.message import KonstructMessage + + data = make_message_data(channel="whatsapp", bot_token="EAAtest123") + assert "bot_token" in data + + placeholder_ts: str = data.pop("placeholder_ts", "") or "" + channel_id: str = data.pop("channel_id", "") or "" + phone_number_id: str = data.pop("phone_number_id", "") or "" + bot_token: str = data.pop("bot_token", "") or "" + + assert bot_token == "EAAtest123" + assert "bot_token" not in data + msg = KonstructMessage.model_validate(data) + assert msg.channel == "whatsapp" + + def test_extras_dict_built_from_popped_values(self) -> None: + """Extras dict must contain all popped values for downstream routing.""" + data = make_message_data( + channel="whatsapp", + phone_number_id="9876543210", + bot_token="EAAaccess", + placeholder_ts="", + channel_id="", + ) + + placeholder_ts: str = data.pop("placeholder_ts", "") or "" + channel_id: str = data.pop("channel_id", "") or "" + phone_number_id: str = data.pop("phone_number_id", "") or "" + bot_token: str = data.pop("bot_token", "") or "" + + extras = { + "placeholder_ts": placeholder_ts, + "channel_id": channel_id, + "phone_number_id": phone_number_id, + "bot_token": bot_token, + } + + assert extras["phone_number_id"] == "9876543210" + assert extras["bot_token"] == "EAAaccess" + + def test_slack_extras_still_popped(self) -> None: + """Slack placeholder_ts and channel_id must still be popped (regression).""" + from shared.models.message import KonstructMessage + + data = make_message_data(channel="slack", placeholder_ts="9999.0000", channel_id="C999") + assert "placeholder_ts" in data + assert "channel_id" in data + + placeholder_ts: str = data.pop("placeholder_ts", "") or "" + channel_id: str = data.pop("channel_id", "") or "" + _ = data.pop("phone_number_id", "") or "" + _ = data.pop("bot_token", "") or "" + + assert placeholder_ts == "9999.0000" + assert channel_id == "C999" + assert "placeholder_ts" not in data + msg = KonstructMessage.model_validate(data) + assert msg.channel == "slack" + + +# --------------------------------------------------------------------------- +# Task 1 Tests — _send_response routing (unit-level) +# --------------------------------------------------------------------------- + + +class TestSendResponseRouting: + """_send_response must route to Slack or WhatsApp based on channel.""" + + def test_slack_path_calls_update_slack_placeholder(self) -> None: + """_send_response for Slack must call _update_slack_placeholder.""" + from orchestrator.tasks import _send_response + + with patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update: + extras = { + "bot_token": "xoxb-test", + "channel_id": "C123", + "placeholder_ts": "1234.5678", + } + asyncio.run(_send_response("slack", "hello", extras)) + mock_update.assert_called_once_with( + bot_token="xoxb-test", + channel_id="C123", + placeholder_ts="1234.5678", + text="hello", + ) + + def test_whatsapp_path_calls_send_whatsapp_message(self) -> None: + """_send_response for WhatsApp must call send_whatsapp_message.""" + from orchestrator.tasks import _send_response + + with patch("orchestrator.tasks.send_whatsapp_message", new_callable=AsyncMock) as mock_send: + extras = { + "phone_number_id": "1234567890", + "bot_token": "EAAaccess", + "wa_id": "15551234567", + } + asyncio.run(_send_response("whatsapp", "hi", extras)) + mock_send.assert_called_once_with( + phone_number_id="1234567890", + access_token="EAAaccess", + recipient_wa_id="15551234567", + text="hi", + ) + + def test_unsupported_channel_logs_warning(self) -> None: + """_send_response for an unsupported channel must log a warning without crashing.""" + from orchestrator.tasks import _send_response + + # Should not raise, just log + asyncio.run(_send_response("mattermost", "hi", {})) + + +# --------------------------------------------------------------------------- +# Task 1 Tests — _process_message uses _send_response (not _update_slack_placeholder) +# --------------------------------------------------------------------------- + + +class TestProcessMessageUsesSSendResponse: + """_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder.""" + + def _make_fake_redis(self, escalated: bool = False) -> AsyncMock: + """Build a fakeredis-like mock for the async client.""" + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None) + redis_mock.delete = AsyncMock() + redis_mock.set = AsyncMock() + redis_mock.setex = AsyncMock() + redis_mock.aclose = AsyncMock() + return redis_mock + + def _make_session_mock(self, agent: Any) -> AsyncMock: + """Build a mock async DB session that returns the given agent.""" + session = AsyncMock() + # Scalars mock + scalars_mock = MagicMock() + scalars_mock.first.return_value = agent + result_mock = AsyncMock() + result_mock.scalars.return_value = scalars_mock + session.execute = AsyncMock(return_value=result_mock) + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_process_message_calls_send_response_not_update_slack(self) -> None: + """_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent() + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data(channel="slack", tenant_id=tenant_id) + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = { + "placeholder_ts": "1234.5678", + "channel_id": "C123", + "phone_number_id": "", + "bot_token": "", + } + + with ( + patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send, + patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update, + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="Agent reply"), + patch("orchestrator.tasks.build_messages_with_memory", return_value=[]), + patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384), + patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_and_store"), + patch("orchestrator.tasks.append_message", new_callable=AsyncMock), + patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), + patch("orchestrator.tasks.AuditLogger"), + patch("orchestrator.tasks.check_escalation_rules", return_value=None), + ): + # Setup redis mock to return None for pre-check (not escalated) + redis_mock = self._make_fake_redis(escalated=False) + mock_aioredis.from_url.return_value = redis_mock + + # Setup session to return our agent + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + await _process_message(msg, extras=extras) + + # _send_response MUST be called for delivery + mock_send.assert_called() + # _update_slack_placeholder must NOT be called directly in _process_message + mock_update.assert_not_called() + + @pytest.mark.asyncio + async def test_whatsapp_extras_passed_to_send_response(self) -> None: + """_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent() + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data( + channel="whatsapp", + tenant_id=tenant_id, + user_id="15551234567", + ) + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = { + "placeholder_ts": "", + "channel_id": "", + "phone_number_id": "9876543210", + "bot_token": "EAAaccess", + } + + captured_extras: list[dict] = [] + + async def capture_send(channel: str, text: str, ext: dict) -> None: + captured_extras.append({"channel": channel, "text": text, "extras": ext}) + + with ( + patch("orchestrator.tasks._send_response", side_effect=capture_send), + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="WA reply"), + patch("orchestrator.tasks.build_messages_with_memory", return_value=[]), + patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384), + patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_and_store"), + patch("orchestrator.tasks.append_message", new_callable=AsyncMock), + patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), + patch("orchestrator.tasks.AuditLogger"), + patch("orchestrator.tasks.check_escalation_rules", return_value=None), + ): + redis_mock = self._make_fake_redis(escalated=False) + mock_aioredis.from_url.return_value = redis_mock + # Configure context manager for redis client used in finally blocks + redis_mock.__aenter__ = AsyncMock(return_value=redis_mock) + redis_mock.__aexit__ = AsyncMock() + + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + await _process_message(msg, extras=extras) + + # At least one _send_response call must have channel=whatsapp + wa_calls = [c for c in captured_extras if c["channel"] == "whatsapp"] + assert len(wa_calls) >= 1 + wa_extras = wa_calls[0]["extras"] + assert wa_extras.get("phone_number_id") == "9876543210" + assert wa_extras.get("bot_token") == "EAAaccess" + # wa_id must come from sender.user_id + assert wa_extras.get("wa_id") == "15551234567" + + +# --------------------------------------------------------------------------- +# Task 1 Tests — Escalation pre-check +# --------------------------------------------------------------------------- + + +class TestEscalationPreCheck: + """When escalation status is 'escalated' in Redis, _process_message must return early without LLM call.""" + + def _make_session_mock(self, agent: Any) -> AsyncMock: + session = AsyncMock() + scalars_mock = MagicMock() + scalars_mock.first.return_value = agent + result_mock = AsyncMock() + result_mock.scalars.return_value = scalars_mock + session.execute = AsyncMock(return_value=result_mock) + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_escalated_conversation_skips_run_agent(self) -> None: + """When Redis shows 'escalated', run_agent must NOT be called.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent() + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data(channel="slack", tenant_id=tenant_id) + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} + + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=b"escalated") + redis_mock.aclose = AsyncMock() + + with ( + patch("orchestrator.tasks._send_response", new_callable=AsyncMock), + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock) as mock_run_agent, + patch("orchestrator.tasks.AuditLogger"), + ): + mock_aioredis.from_url.return_value = redis_mock + redis_mock.__aenter__ = AsyncMock(return_value=redis_mock) + redis_mock.__aexit__ = AsyncMock() + + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + result = await _process_message(msg, extras=extras) + + # run_agent must NOT be called — conversation already escalated + mock_run_agent.assert_not_called() + # Should return a response indicating escalation state + assert result["message_id"] == msg.id + assert "team member" in result["response"].lower() or "already" in result["response"].lower() + + +# --------------------------------------------------------------------------- +# Task 1 Tests — Escalation post-check +# --------------------------------------------------------------------------- + + +class TestEscalationPostCheck: + """check_escalation_rules called after run_agent; escalate_to_human called when rule matches.""" + + def _make_session_mock(self, agent: Any) -> AsyncMock: + session = AsyncMock() + scalars_mock = MagicMock() + scalars_mock.first.return_value = agent + result_mock = AsyncMock() + result_mock.scalars.return_value = scalars_mock + session.execute = AsyncMock(return_value=result_mock) + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_check_escalation_rules_called_after_run_agent(self) -> None: + """check_escalation_rules must be called after run_agent returns.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]) + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data(channel="slack", tenant_id=tenant_id) + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} + + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=None) + redis_mock.aclose = AsyncMock() + + with ( + patch("orchestrator.tasks._send_response", new_callable=AsyncMock), + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="some reply"), + patch("orchestrator.tasks.build_messages_with_memory", return_value=[]), + patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384), + patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_and_store"), + patch("orchestrator.tasks.append_message", new_callable=AsyncMock), + patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), + patch("orchestrator.tasks.AuditLogger"), + patch("orchestrator.tasks.check_escalation_rules", return_value=None) as mock_check, + ): + mock_aioredis.from_url.return_value = redis_mock + + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + await _process_message(msg, extras=extras) + + # check_escalation_rules must have been called + mock_check.assert_called_once() + + @pytest.mark.asyncio + async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None: + """When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent( + escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], + escalation_assignee="U_MANAGER", + ) + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again") + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} + + matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} + + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=None) + redis_mock.aclose = AsyncMock() + + with ( + patch("orchestrator.tasks._send_response", new_callable=AsyncMock), + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="some reply"), + patch("orchestrator.tasks.build_messages_with_memory", return_value=[]), + patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384), + patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_and_store"), + patch("orchestrator.tasks.append_message", new_callable=AsyncMock), + patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), + patch("orchestrator.tasks.AuditLogger"), + patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule), + patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate, + ): + mock_aioredis.from_url.return_value = redis_mock + + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + result = await _process_message(msg, extras=extras) + + # escalate_to_human must be called + mock_escalate.assert_called_once() + # Response should be the escalation confirmation + assert "team member" in result["response"] or mock_escalate.return_value in result["response"] + + @pytest.mark.asyncio + async def test_escalate_to_human_not_called_when_no_assignee(self) -> None: + """When rule matches but escalation_assignee is None, escalate_to_human must NOT be called.""" + from orchestrator.tasks import _process_message + from shared.models.message import KonstructMessage + + agent = make_agent( + escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], + escalation_assignee=None, + ) + tenant_id = str(uuid.uuid4()) + msg_data = make_message_data(channel="slack", tenant_id=tenant_id) + msg_data.pop("placeholder_ts", None) + msg_data.pop("channel_id", None) + msg_data.pop("phone_number_id", None) + msg_data.pop("bot_token", None) + msg = KonstructMessage.model_validate(msg_data) + + extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} + + matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} + + redis_mock = AsyncMock() + redis_mock.get = AsyncMock(return_value=None) + redis_mock.aclose = AsyncMock() + + with ( + patch("orchestrator.tasks._send_response", new_callable=AsyncMock), + patch("orchestrator.tasks.aioredis") as mock_aioredis, + patch("orchestrator.tasks.async_session_factory") as mock_session_factory, + patch("orchestrator.tasks.configure_rls_hook"), + patch("orchestrator.tasks.current_tenant_id"), + patch("orchestrator.tasks.engine"), + patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="normal reply"), + patch("orchestrator.tasks.build_messages_with_memory", return_value=[]), + patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384), + patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]), + patch("orchestrator.tasks.embed_and_store"), + patch("orchestrator.tasks.append_message", new_callable=AsyncMock), + patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), + patch("orchestrator.tasks.AuditLogger"), + patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule), + patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock) as mock_escalate, + ): + mock_aioredis.from_url.return_value = redis_mock + + session_mock = self._make_session_mock(agent) + mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock) + mock_session_factory.return_value.__aexit__ = AsyncMock() + + await _process_message(msg, extras=extras) + + # escalate_to_human must NOT be called — no assignee configured + mock_escalate.assert_not_called() + + +# --------------------------------------------------------------------------- +# Task 1 Tests — _build_conversation_metadata +# --------------------------------------------------------------------------- + + +class TestBuildConversationMetadata: + """_build_conversation_metadata must extract billing keywords from messages.""" + + def test_billing_keyword_in_current_text(self) -> None: + """billing keyword in current_text must set billing_dispute=True.""" + from orchestrator.tasks import _build_conversation_metadata + + result = _build_conversation_metadata([], "I have a billing issue") + assert result["billing_dispute"] is True + assert result["attempts"] >= 1 + + def test_billing_keyword_in_recent_messages(self) -> None: + """billing keyword in recent_messages must be counted.""" + from orchestrator.tasks import _build_conversation_metadata + + recent = [{"content": "My invoice is wrong"}, {"content": "Charge seems off"}] + result = _build_conversation_metadata(recent, "unrelated") + assert result["billing_dispute"] is True + assert result["attempts"] >= 2 + + def test_no_keywords_returns_false(self) -> None: + """No billing keywords must return billing_dispute=False.""" + from orchestrator.tasks import _build_conversation_metadata + + result = _build_conversation_metadata( + [{"content": "Hello there"}, {"content": "How are you"}], + "what is the weather", + ) + assert result["billing_dispute"] is False + assert result["attempts"] == 0 + + def test_multiple_billing_messages_counted(self) -> None: + """Each billing keyword occurrence in messages must increment attempts.""" + from orchestrator.tasks import _build_conversation_metadata + + recent = [ + {"content": "billing issue"}, + {"content": "invoice problem"}, + {"content": "refund request"}, + ] + result = _build_conversation_metadata(recent, "payment question") + assert result["billing_dispute"] is True + assert result["attempts"] >= 4 + + +# --------------------------------------------------------------------------- +# Task 2 Tests — build_system_prompt tier-2 WhatsApp scoping +# --------------------------------------------------------------------------- + + +class TestBuildSystemPromptWhatsAppScoping: + """build_system_prompt must append business-function scoping for WhatsApp when tool_assignments is non-empty.""" + + def test_whatsapp_with_tool_assignments_appends_scoping(self) -> None: + """channel='whatsapp' + non-empty tool_assignments must append scoping clause.""" + from orchestrator.agents.builder import build_system_prompt + + agent = make_agent(tool_assignments=["customer support", "order tracking"]) + prompt = build_system_prompt(agent, channel="whatsapp") + assert "You only handle" in prompt + assert "customer support" in prompt + assert "order tracking" in prompt + + def test_slack_channel_does_not_append_scoping(self) -> None: + """channel='slack' must NOT append WhatsApp scoping.""" + from orchestrator.agents.builder import build_system_prompt + + agent = make_agent(tool_assignments=["customer support", "order tracking"]) + prompt = build_system_prompt(agent, channel="slack") + assert "You only handle" not in prompt + + def test_whatsapp_empty_tool_assignments_no_scoping(self) -> None: + """channel='whatsapp' with empty tool_assignments must NOT append scoping.""" + from orchestrator.agents.builder import build_system_prompt + + agent = make_agent(tool_assignments=[]) + prompt = build_system_prompt(agent, channel="whatsapp") + assert "You only handle" not in prompt + + def test_no_channel_no_scoping(self) -> None: + """No channel (default '') must NOT append scoping.""" + from orchestrator.agents.builder import build_system_prompt + + agent = make_agent(tool_assignments=["billing", "support"]) + prompt = build_system_prompt(agent) + assert "You only handle" not in prompt + + def test_scoping_includes_redirect_instruction(self) -> None: + """WhatsApp scoping must include instruction to redirect off-topic requests.""" + from orchestrator.agents.builder import build_system_prompt + + agent = make_agent(tool_assignments=["billing"]) + prompt = build_system_prompt(agent, channel="whatsapp") + assert "redirect" in prompt.lower() or "outside" in prompt.lower() + + +# --------------------------------------------------------------------------- +# Task 2 Tests — build_messages_with_memory passes channel through +# --------------------------------------------------------------------------- + + +class TestBuildMessagesWithMemoryChannel: + """build_messages_with_memory must pass channel through to build_system_prompt.""" + + def test_channel_parameter_accepted(self) -> None: + """build_messages_with_memory must accept channel parameter without error.""" + from orchestrator.agents.builder import build_messages_with_memory + + agent = make_agent(tool_assignments=["support"]) + messages = build_messages_with_memory( + agent=agent, + current_message="help", + recent_messages=[], + relevant_context=[], + channel="whatsapp", + ) + assert len(messages) >= 2 # system + user message + + def test_whatsapp_channel_scoping_in_system_message(self) -> None: + """When channel='whatsapp' and tool_assignments is set, system message must include scoping.""" + from orchestrator.agents.builder import build_messages_with_memory + + agent = make_agent(tool_assignments=["order tracking", "returns"]) + messages = build_messages_with_memory( + agent=agent, + current_message="help", + recent_messages=[], + relevant_context=[], + channel="whatsapp", + ) + system_msg = messages[0] + assert system_msg["role"] == "system" + assert "You only handle" in system_msg["content"] + + def test_slack_channel_no_scoping_in_system_message(self) -> None: + """When channel='slack', system message must NOT include WhatsApp scoping.""" + from orchestrator.agents.builder import build_messages_with_memory + + agent = make_agent(tool_assignments=["order tracking", "returns"]) + messages = build_messages_with_memory( + agent=agent, + current_message="help", + recent_messages=[], + relevant_context=[], + channel="slack", + ) + system_msg = messages[0] + assert "You only handle" not in system_msg["content"] + + def test_default_channel_no_scoping(self) -> None: + """Default channel (omitted) must NOT include scoping.""" + from orchestrator.agents.builder import build_messages_with_memory + + agent = make_agent(tool_assignments=["billing"]) + messages = build_messages_with_memory( + agent=agent, + current_message="question", + recent_messages=[], + relevant_context=[], + ) + system_msg = messages[0] + assert "You only handle" not in system_msg["content"]