# Test coverage summary:
# - Tests for handle_message WhatsApp extra extraction (phone_number_id, bot_token)
# - Tests for _send_response routing to Slack and WhatsApp
# - Tests for _process_message using _send_response (not _update_slack_placeholder directly)
# - Tests for the escalation pre-check (skip the LLM when already escalated)
# - Tests for the escalation post-check (check_escalation_rules + escalate_to_human)
# - Tests for _build_conversation_metadata billing-keyword extraction
# - Tests for build_system_prompt WhatsApp tier-2 scoping (Task 2)
# - Tests for build_messages_with_memory channel-parameter passthrough
# (Original listing: 814 lines, 36 KiB, Python.)
"""
|
|
Unit tests for pipeline wiring: escalation + outbound routing re-wiring (Plan 02-06).
|
|
|
|
Tests verify:
|
|
- handle_message pops WhatsApp extras (phone_number_id, bot_token) before model_validate
|
|
- _process_message calls _send_response (not _update_slack_placeholder) for all delivery points
|
|
- Escalation pre-check: already-escalated conversations skip LLM call
|
|
- Escalation post-check: check_escalation_rules called after LLM response
|
|
- escalate_to_human called when rule matches and assignee is configured
|
|
- WhatsApp extras flow through to _send_response correctly
|
|
- build_system_prompt appends WhatsApp business-function scoping (Task 2)
|
|
- build_messages_with_memory/media pass channel through
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def make_agent(
|
|
*,
|
|
name: str = "TestAgent",
|
|
role: str = "Support",
|
|
persona: str = "",
|
|
system_prompt: str = "",
|
|
tool_assignments: list[str] | None = None,
|
|
escalation_rules: list[dict] | None = None,
|
|
escalation_assignee: str | None = None,
|
|
natural_language_escalation: bool = False,
|
|
model_preference: str = "gpt-4o",
|
|
) -> MagicMock:
|
|
"""Create a mock Agent with configurable fields."""
|
|
agent = MagicMock()
|
|
agent.id = uuid.uuid4()
|
|
agent.name = name
|
|
agent.role = role
|
|
agent.persona = persona
|
|
agent.system_prompt = system_prompt
|
|
agent.tool_assignments = tool_assignments or []
|
|
agent.escalation_rules = escalation_rules or []
|
|
agent.escalation_assignee = escalation_assignee
|
|
agent.natural_language_escalation = natural_language_escalation
|
|
agent.model_preference = model_preference
|
|
agent.is_active = True
|
|
return agent
|
|
|
|
|
|
def make_message_data(
|
|
*,
|
|
channel: str = "slack",
|
|
tenant_id: str | None = None,
|
|
text: str = "hello",
|
|
user_id: str = "U123",
|
|
placeholder_ts: str = "1234.5678",
|
|
channel_id: str = "C123",
|
|
phone_number_id: str = "",
|
|
bot_token: str = "",
|
|
thread_id: str = "thread_001",
|
|
) -> dict:
|
|
"""Build a message_data dict as the gateway would inject it."""
|
|
tid = tenant_id or str(uuid.uuid4())
|
|
data: dict[str, Any] = {
|
|
"id": str(uuid.uuid4()),
|
|
"tenant_id": tid,
|
|
"channel": channel,
|
|
"channel_metadata": {},
|
|
"sender": {"user_id": user_id, "display_name": "Test User"},
|
|
"content": {"text": text, "attachments": []},
|
|
"timestamp": "2026-01-01T00:00:00Z",
|
|
"thread_id": thread_id,
|
|
"reply_to": None,
|
|
"context": {},
|
|
}
|
|
if placeholder_ts:
|
|
data["placeholder_ts"] = placeholder_ts
|
|
if channel_id:
|
|
data["channel_id"] = channel_id
|
|
if phone_number_id:
|
|
data["phone_number_id"] = phone_number_id
|
|
if bot_token:
|
|
data["bot_token"] = bot_token
|
|
return data
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — handle_message extra-field extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHandleMessageExtraExtraction:
|
|
"""handle_message must pop WhatsApp and Slack extras before model_validate."""
|
|
|
|
def test_pops_phone_number_id_before_validate(self) -> None:
|
|
"""phone_number_id must be removed before KonstructMessage.model_validate."""
|
|
from shared.models.message import KonstructMessage
|
|
|
|
data = make_message_data(channel="whatsapp", phone_number_id="1234567890")
|
|
assert "phone_number_id" in data
|
|
|
|
# Simulate what handle_message does: pop extras
|
|
placeholder_ts: str = data.pop("placeholder_ts", "") or ""
|
|
channel_id: str = data.pop("channel_id", "") or ""
|
|
phone_number_id: str = data.pop("phone_number_id", "") or ""
|
|
bot_token: str = data.pop("bot_token", "") or ""
|
|
|
|
assert phone_number_id == "1234567890"
|
|
assert "phone_number_id" not in data
|
|
# Should validate cleanly without extra fields
|
|
msg = KonstructMessage.model_validate(data)
|
|
assert msg.channel == "whatsapp"
|
|
|
|
def test_pops_bot_token_before_validate(self) -> None:
|
|
"""bot_token (WhatsApp access_token) must be removed before model_validate."""
|
|
from shared.models.message import KonstructMessage
|
|
|
|
data = make_message_data(channel="whatsapp", bot_token="EAAtest123")
|
|
assert "bot_token" in data
|
|
|
|
placeholder_ts: str = data.pop("placeholder_ts", "") or ""
|
|
channel_id: str = data.pop("channel_id", "") or ""
|
|
phone_number_id: str = data.pop("phone_number_id", "") or ""
|
|
bot_token: str = data.pop("bot_token", "") or ""
|
|
|
|
assert bot_token == "EAAtest123"
|
|
assert "bot_token" not in data
|
|
msg = KonstructMessage.model_validate(data)
|
|
assert msg.channel == "whatsapp"
|
|
|
|
def test_extras_dict_built_from_popped_values(self) -> None:
|
|
"""Extras dict must contain all popped values for downstream routing."""
|
|
data = make_message_data(
|
|
channel="whatsapp",
|
|
phone_number_id="9876543210",
|
|
bot_token="EAAaccess",
|
|
placeholder_ts="",
|
|
channel_id="",
|
|
)
|
|
|
|
placeholder_ts: str = data.pop("placeholder_ts", "") or ""
|
|
channel_id: str = data.pop("channel_id", "") or ""
|
|
phone_number_id: str = data.pop("phone_number_id", "") or ""
|
|
bot_token: str = data.pop("bot_token", "") or ""
|
|
|
|
extras = {
|
|
"placeholder_ts": placeholder_ts,
|
|
"channel_id": channel_id,
|
|
"phone_number_id": phone_number_id,
|
|
"bot_token": bot_token,
|
|
}
|
|
|
|
assert extras["phone_number_id"] == "9876543210"
|
|
assert extras["bot_token"] == "EAAaccess"
|
|
|
|
def test_slack_extras_still_popped(self) -> None:
|
|
"""Slack placeholder_ts and channel_id must still be popped (regression)."""
|
|
from shared.models.message import KonstructMessage
|
|
|
|
data = make_message_data(channel="slack", placeholder_ts="9999.0000", channel_id="C999")
|
|
assert "placeholder_ts" in data
|
|
assert "channel_id" in data
|
|
|
|
placeholder_ts: str = data.pop("placeholder_ts", "") or ""
|
|
channel_id: str = data.pop("channel_id", "") or ""
|
|
_ = data.pop("phone_number_id", "") or ""
|
|
_ = data.pop("bot_token", "") or ""
|
|
|
|
assert placeholder_ts == "9999.0000"
|
|
assert channel_id == "C999"
|
|
assert "placeholder_ts" not in data
|
|
msg = KonstructMessage.model_validate(data)
|
|
assert msg.channel == "slack"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — _send_response routing (unit-level)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSendResponseRouting:
|
|
"""_send_response must route to Slack or WhatsApp based on channel."""
|
|
|
|
def test_slack_path_calls_update_slack_placeholder(self) -> None:
|
|
"""_send_response for Slack must call _update_slack_placeholder."""
|
|
from orchestrator.tasks import _send_response
|
|
|
|
with patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update:
|
|
extras = {
|
|
"bot_token": "xoxb-test",
|
|
"channel_id": "C123",
|
|
"placeholder_ts": "1234.5678",
|
|
}
|
|
asyncio.run(_send_response("slack", "hello", extras))
|
|
mock_update.assert_called_once_with(
|
|
bot_token="xoxb-test",
|
|
channel_id="C123",
|
|
placeholder_ts="1234.5678",
|
|
text="hello",
|
|
)
|
|
|
|
def test_whatsapp_path_calls_send_whatsapp_message(self) -> None:
|
|
"""_send_response for WhatsApp must call send_whatsapp_message."""
|
|
from orchestrator.tasks import _send_response
|
|
|
|
with patch("orchestrator.tasks.send_whatsapp_message", new_callable=AsyncMock) as mock_send:
|
|
extras = {
|
|
"phone_number_id": "1234567890",
|
|
"bot_token": "EAAaccess",
|
|
"wa_id": "15551234567",
|
|
}
|
|
asyncio.run(_send_response("whatsapp", "hi", extras))
|
|
mock_send.assert_called_once_with(
|
|
phone_number_id="1234567890",
|
|
access_token="EAAaccess",
|
|
recipient_wa_id="15551234567",
|
|
text="hi",
|
|
)
|
|
|
|
def test_unsupported_channel_logs_warning(self) -> None:
|
|
"""_send_response for an unsupported channel must log a warning without crashing."""
|
|
from orchestrator.tasks import _send_response
|
|
|
|
# Should not raise, just log
|
|
asyncio.run(_send_response("mattermost", "hi", {}))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — _process_message uses _send_response (not _update_slack_placeholder)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestProcessMessageUsesSSendResponse:
    """_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""

    def _make_fake_redis(self, escalated: bool = False) -> AsyncMock:
        """Build a mock of the async Redis client returned by ``aioredis.from_url``.

        ``get`` returns ``b"escalated"`` when ``escalated`` is true, else ``None``.
        The mock also acts as an async context manager that yields itself.
        Its ``__aexit__`` returns ``False``: a bare ``AsyncMock()`` would resolve
        to a truthy mock and silently swallow exceptions raised inside
        ``async with``.
        """
        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
        redis_mock.delete = AsyncMock()
        redis_mock.set = AsyncMock()
        redis_mock.setex = AsyncMock()
        redis_mock.aclose = AsyncMock()
        redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
        redis_mock.__aexit__ = AsyncMock(return_value=False)
        return redis_mock

    def _make_session_mock(self, agent: Any) -> AsyncMock:
        """Build a mock async DB session whose ``execute`` result yields ``agent``.

        Only ``execute`` is awaited. The returned result's ``scalars().first()``
        calls are synchronous (SQLAlchemy-style ``Result``; assumed from the
        mock's shape). The result must therefore be a plain ``MagicMock``:
        attributes of an ``AsyncMock`` are themselves async, so ``scalars()``
        would return a coroutine instead of ``scalars_mock``.
        """
        session = AsyncMock()
        scalars_mock = MagicMock()
        scalars_mock.first.return_value = agent
        result_mock = MagicMock()
        result_mock.scalars.return_value = scalars_mock
        session.execute = AsyncMock(return_value=result_mock)
        session.__aenter__ = AsyncMock(return_value=session)
        # False so exceptions inside ``async with session`` are not suppressed.
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    @pytest.mark.asyncio
    async def test_process_message_calls_send_response_not_update_slack(self) -> None:
        """_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent()
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
        # Strip gateway extras so the payload validates as a KonstructMessage.
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {
            "placeholder_ts": "1234.5678",
            "channel_id": "C123",
            "phone_number_id": "",
            "bot_token": "",
        }

        with (
            patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
            patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="Agent reply"),
            patch("orchestrator.tasks.build_messages_with_memory", return_value=[]),
            patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384),
            patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_and_store"),
            patch("orchestrator.tasks.append_message", new_callable=AsyncMock),
            patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
            patch("orchestrator.tasks.AuditLogger"),
            patch("orchestrator.tasks.check_escalation_rules", return_value=None),
        ):
            # Redis pre-check returns None: conversation is not escalated.
            mock_aioredis.from_url.return_value = self._make_fake_redis(escalated=False)

            # Session factory yields a session that returns our agent.
            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            # False so errors inside the session block fail the test instead of vanishing.
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            await _process_message(msg, extras=extras)

            # _send_response MUST be called for delivery
            mock_send.assert_called()
            # _update_slack_placeholder must NOT be called directly in _process_message
            mock_update.assert_not_called()

    @pytest.mark.asyncio
    async def test_whatsapp_extras_passed_to_send_response(self) -> None:
        """_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent()
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(
            channel="whatsapp",
            tenant_id=tenant_id,
            user_id="15551234567",
        )
        # Strip gateway extras so the payload validates as a KonstructMessage.
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {
            "placeholder_ts": "",
            "channel_id": "",
            "phone_number_id": "9876543210",
            "bot_token": "EAAaccess",
        }

        captured_extras: list[dict] = []

        async def capture_send(channel: str, text: str, ext: dict) -> None:
            # Record every delivery attempt so the routing extras can be inspected.
            captured_extras.append({"channel": channel, "text": text, "extras": ext})

        with (
            patch("orchestrator.tasks._send_response", side_effect=capture_send),
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="WA reply"),
            patch("orchestrator.tasks.build_messages_with_memory", return_value=[]),
            patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384),
            patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_and_store"),
            patch("orchestrator.tasks.append_message", new_callable=AsyncMock),
            patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
            patch("orchestrator.tasks.AuditLogger"),
            patch("orchestrator.tasks.check_escalation_rules", return_value=None),
        ):
            # The fake client is already its own async context manager.
            mock_aioredis.from_url.return_value = self._make_fake_redis(escalated=False)

            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            await _process_message(msg, extras=extras)

        # At least one _send_response call must have channel=whatsapp
        wa_calls = [c for c in captured_extras if c["channel"] == "whatsapp"]
        assert len(wa_calls) >= 1
        wa_extras = wa_calls[0]["extras"]
        assert wa_extras.get("phone_number_id") == "9876543210"
        assert wa_extras.get("bot_token") == "EAAaccess"
        # wa_id must come from sender.user_id
        assert wa_extras.get("wa_id") == "15551234567"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — Escalation pre-check
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEscalationPreCheck:
    """When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""

    def _make_session_mock(self, agent: Any) -> AsyncMock:
        """Build a mock async DB session whose ``execute`` result yields ``agent``.

        ``execute`` is awaited, but ``scalars().first()`` on its result are
        synchronous calls. The result is therefore a ``MagicMock``: children of
        an ``AsyncMock`` are async and would return coroutines.
        """
        session = AsyncMock()
        scalars_mock = MagicMock()
        scalars_mock.first.return_value = agent
        result_mock = MagicMock()
        result_mock.scalars.return_value = scalars_mock
        session.execute = AsyncMock(return_value=result_mock)
        session.__aenter__ = AsyncMock(return_value=session)
        # A bare AsyncMock() resolves truthy and would swallow exceptions.
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    @pytest.mark.asyncio
    async def test_escalated_conversation_skips_run_agent(self) -> None:
        """When Redis shows 'escalated', run_agent must NOT be called."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent()
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
        # Strip gateway extras so the payload validates as a KonstructMessage.
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}

        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=b"escalated")
        redis_mock.aclose = AsyncMock()

        with (
            patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock) as mock_run_agent,
            patch("orchestrator.tasks.AuditLogger"),
        ):
            mock_aioredis.from_url.return_value = redis_mock
            redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
            # False so exceptions inside ``async with redis`` propagate.
            redis_mock.__aexit__ = AsyncMock(return_value=False)

            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            result = await _process_message(msg, extras=extras)

            # run_agent must NOT be called — conversation already escalated
            mock_run_agent.assert_not_called()
            # Should return a response indicating escalation state
            assert result["message_id"] == msg.id
            assert "team member" in result["response"].lower() or "already" in result["response"].lower()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — Escalation post-check
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEscalationPostCheck:
    """check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""

    def _make_session_mock(self, agent: Any) -> AsyncMock:
        """Build a mock async DB session whose ``execute`` result yields ``agent``.

        ``execute`` is awaited, but ``scalars().first()`` on its result are
        synchronous calls. The result is therefore a ``MagicMock``: children of
        an ``AsyncMock`` are async and would return coroutines.
        """
        session = AsyncMock()
        scalars_mock = MagicMock()
        scalars_mock.first.return_value = agent
        result_mock = MagicMock()
        result_mock.scalars.return_value = scalars_mock
        session.execute = AsyncMock(return_value=result_mock)
        session.__aenter__ = AsyncMock(return_value=session)
        # A bare AsyncMock() resolves truthy and would swallow exceptions.
        session.__aexit__ = AsyncMock(return_value=False)
        return session

    def _make_redis_mock(self) -> AsyncMock:
        """Build an async Redis mock whose ``get`` returns ``None`` (not escalated).

        It also acts as an async context manager yielding itself. Its
        ``__aexit__`` returns ``False`` so errors inside ``async with`` surface.
        """
        redis_mock = AsyncMock()
        redis_mock.get = AsyncMock(return_value=None)
        redis_mock.aclose = AsyncMock()
        redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
        redis_mock.__aexit__ = AsyncMock(return_value=False)
        return redis_mock

    @pytest.mark.asyncio
    async def test_check_escalation_rules_called_after_run_agent(self) -> None:
        """check_escalation_rules must be called after run_agent returns."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}])
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
        # Strip gateway extras so the payload validates as a KonstructMessage.
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}

        redis_mock = self._make_redis_mock()

        with (
            patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="some reply"),
            patch("orchestrator.tasks.build_messages_with_memory", return_value=[]),
            patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384),
            patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_and_store"),
            patch("orchestrator.tasks.append_message", new_callable=AsyncMock),
            patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
            patch("orchestrator.tasks.AuditLogger"),
            patch("orchestrator.tasks.check_escalation_rules", return_value=None) as mock_check,
        ):
            mock_aioredis.from_url.return_value = redis_mock

            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            await _process_message(msg, extras=extras)

            # check_escalation_rules must have been called
            mock_check.assert_called_once()

    @pytest.mark.asyncio
    async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
        """When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent(
            escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
            escalation_assignee="U_MANAGER",
        )
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again")
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}

        matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}

        redis_mock = self._make_redis_mock()

        with (
            patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="some reply"),
            patch("orchestrator.tasks.build_messages_with_memory", return_value=[]),
            patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384),
            patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_and_store"),
            patch("orchestrator.tasks.append_message", new_callable=AsyncMock),
            patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
            patch("orchestrator.tasks.AuditLogger"),
            patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
            patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate,
        ):
            mock_aioredis.from_url.return_value = redis_mock

            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            result = await _process_message(msg, extras=extras)

            # escalate_to_human must be called
            mock_escalate.assert_called_once()
            # Response should be the escalation confirmation
            assert "team member" in result["response"] or mock_escalate.return_value in result["response"]

    @pytest.mark.asyncio
    async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
        """When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
        from orchestrator.tasks import _process_message
        from shared.models.message import KonstructMessage

        agent = make_agent(
            escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
            escalation_assignee=None,
        )
        tenant_id = str(uuid.uuid4())
        msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
        msg_data.pop("placeholder_ts", None)
        msg_data.pop("channel_id", None)
        msg_data.pop("phone_number_id", None)
        msg_data.pop("bot_token", None)
        msg = KonstructMessage.model_validate(msg_data)

        extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}

        matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}

        redis_mock = self._make_redis_mock()

        with (
            patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
            patch("orchestrator.tasks.aioredis") as mock_aioredis,
            patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
            patch("orchestrator.tasks.configure_rls_hook"),
            patch("orchestrator.tasks.current_tenant_id"),
            patch("orchestrator.tasks.engine"),
            patch("orchestrator.tasks.run_agent", new_callable=AsyncMock, return_value="normal reply"),
            patch("orchestrator.tasks.build_messages_with_memory", return_value=[]),
            patch("orchestrator.tasks.get_recent_messages", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_text", return_value=[0.1] * 384),
            patch("orchestrator.tasks.retrieve_relevant", new_callable=AsyncMock, return_value=[]),
            patch("orchestrator.tasks.embed_and_store"),
            patch("orchestrator.tasks.append_message", new_callable=AsyncMock),
            patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
            patch("orchestrator.tasks.AuditLogger"),
            patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
            patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock) as mock_escalate,
        ):
            mock_aioredis.from_url.return_value = redis_mock

            session_mock = self._make_session_mock(agent)
            mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
            mock_session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

            await _process_message(msg, extras=extras)

            # escalate_to_human must NOT be called — no assignee configured
            mock_escalate.assert_not_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 1 Tests — _build_conversation_metadata
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildConversationMetadata:
|
|
"""_build_conversation_metadata must extract billing keywords from messages."""
|
|
|
|
def test_billing_keyword_in_current_text(self) -> None:
|
|
"""billing keyword in current_text must set billing_dispute=True."""
|
|
from orchestrator.tasks import _build_conversation_metadata
|
|
|
|
result = _build_conversation_metadata([], "I have a billing issue")
|
|
assert result["billing_dispute"] is True
|
|
assert result["attempts"] >= 1
|
|
|
|
def test_billing_keyword_in_recent_messages(self) -> None:
|
|
"""billing keyword in recent_messages must be counted."""
|
|
from orchestrator.tasks import _build_conversation_metadata
|
|
|
|
recent = [{"content": "My invoice is wrong"}, {"content": "Charge seems off"}]
|
|
result = _build_conversation_metadata(recent, "unrelated")
|
|
assert result["billing_dispute"] is True
|
|
assert result["attempts"] >= 2
|
|
|
|
def test_no_keywords_returns_false(self) -> None:
|
|
"""No billing keywords must return billing_dispute=False."""
|
|
from orchestrator.tasks import _build_conversation_metadata
|
|
|
|
result = _build_conversation_metadata(
|
|
[{"content": "Hello there"}, {"content": "How are you"}],
|
|
"what is the weather",
|
|
)
|
|
assert result["billing_dispute"] is False
|
|
assert result["attempts"] == 0
|
|
|
|
def test_multiple_billing_messages_counted(self) -> None:
|
|
"""Each billing keyword occurrence in messages must increment attempts."""
|
|
from orchestrator.tasks import _build_conversation_metadata
|
|
|
|
recent = [
|
|
{"content": "billing issue"},
|
|
{"content": "invoice problem"},
|
|
{"content": "refund request"},
|
|
]
|
|
result = _build_conversation_metadata(recent, "payment question")
|
|
assert result["billing_dispute"] is True
|
|
assert result["attempts"] >= 4
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 2 Tests — build_system_prompt tier-2 WhatsApp scoping
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildSystemPromptWhatsAppScoping:
    """WhatsApp prompts gain a business-function scoping clause only when tool_assignments is non-empty."""

    # Text that signals the scoping clause is present in a prompt.
    _SCOPE_MARKER = "You only handle"

    @staticmethod
    def _prompt(tool_assignments: list[str], **channel_kwarg: str) -> str:
        """Build a system prompt for a mock agent.

        ``channel`` is forwarded only when the caller passes it, so the
        no-channel case still exercises build_system_prompt's own default.
        """
        from orchestrator.agents.builder import build_system_prompt

        return build_system_prompt(make_agent(tool_assignments=tool_assignments), **channel_kwarg)

    def test_whatsapp_with_tool_assignments_appends_scoping(self) -> None:
        """WhatsApp plus assigned functions yields a scoping clause naming each function."""
        prompt = self._prompt(["customer support", "order tracking"], channel="whatsapp")
        for expected in (self._SCOPE_MARKER, "customer support", "order tracking"):
            assert expected in prompt

    def test_slack_channel_does_not_append_scoping(self) -> None:
        """Slack prompts never carry the WhatsApp scoping clause."""
        prompt = self._prompt(["customer support", "order tracking"], channel="slack")
        assert self._SCOPE_MARKER not in prompt

    def test_whatsapp_empty_tool_assignments_no_scoping(self) -> None:
        """WhatsApp with nothing assigned gets no scoping clause."""
        prompt = self._prompt([], channel="whatsapp")
        assert self._SCOPE_MARKER not in prompt

    def test_no_channel_no_scoping(self) -> None:
        """Omitting channel entirely leaves the prompt unscoped."""
        prompt = self._prompt(["billing", "support"])
        assert self._SCOPE_MARKER not in prompt

    def test_scoping_includes_redirect_instruction(self) -> None:
        """The WhatsApp scoping clause tells the agent how to treat off-topic asks."""
        lowered = self._prompt(["billing"], channel="whatsapp").lower()
        assert "redirect" in lowered or "outside" in lowered
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 2 Tests — build_messages_with_memory passes channel through
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildMessagesWithMemoryChannel:
    """The channel argument of build_messages_with_memory reaches build_system_prompt."""

    @staticmethod
    def _build(tool_assignments: list[str], current_message: str = "help", **channel_kwarg: str) -> Any:
        """Call build_messages_with_memory for a mock agent with empty memory.

        ``channel`` is forwarded only when supplied, so the default path stays
        covered.
        """
        from orchestrator.agents.builder import build_messages_with_memory

        return build_messages_with_memory(
            agent=make_agent(tool_assignments=tool_assignments),
            current_message=current_message,
            recent_messages=[],
            relevant_context=[],
            **channel_kwarg,
        )

    def test_channel_parameter_accepted(self) -> None:
        """Passing channel is accepted and still yields at least system + user messages."""
        built = self._build(["support"], channel="whatsapp")
        assert len(built) >= 2  # system + user message

    def test_whatsapp_channel_scoping_in_system_message(self) -> None:
        """For WhatsApp with assigned functions, the leading system message is scoped."""
        first = self._build(["order tracking", "returns"], channel="whatsapp")[0]
        assert first["role"] == "system"
        assert "You only handle" in first["content"]

    def test_slack_channel_no_scoping_in_system_message(self) -> None:
        """For Slack, the leading system message carries no WhatsApp scoping."""
        first = self._build(["order tracking", "returns"], channel="slack")[0]
        assert "You only handle" not in first["content"]

    def test_default_channel_no_scoping(self) -> None:
        """When channel is omitted, the leading system message is unscoped."""
        first = self._build(["billing"], current_message="question")[0]
        assert "You only handle" not in first["content"]
|