feat(02-06): re-wire escalation and WhatsApp outbound routing in pipeline

- Move key imports to module level in tasks.py for testability and clarity
- Pop WhatsApp extras (phone_number_id, bot_token) in handle_message before model_validate
- Build unified extras dict and extract wa_id from sender.user_id
- Change _process_message signature to accept extras dict
- Add _build_response_extras() helper for channel-aware extras assembly
- Replace all _update_slack_placeholder calls in _process_message with _send_response()
- Add escalation pre-check: skip LLM when Redis escalation_status_key == 'escalated'
- Add escalation post-check: check_escalation_rules after run_agent; call escalate_to_human
  when rule matches and agent.escalation_assignee is set
- Add _build_conversation_metadata() helper (billing keyword v1 detection)
- Add channel parameter to build_system_prompt(), build_messages_with_memory(),
  build_messages_with_media() for WhatsApp tier-2 business-function scoping
- WhatsApp scoping appends 'You only handle: {topics}' when tool_assignments non-empty
- Pass msg.channel to build_messages_with_memory() in _process_message
- All 26 new tests pass; all existing escalation/WhatsApp tests pass (no regressions)
This commit is contained in:
2026-03-23 19:15:20 -06:00
parent 77c9cfc825
commit bd217a4113
3 changed files with 380 additions and 226 deletions

View File

@@ -133,7 +133,7 @@ def generate_presigned_url(storage_key: str, expiry: int = 3600) -> str:
return presigned_url return presigned_url
def build_system_prompt(agent: Agent) -> str: def build_system_prompt(agent: Agent, channel: str = "") -> str:
""" """
Assemble the full system prompt for an agent. Assemble the full system prompt for an agent.
@@ -142,9 +142,20 @@ def build_system_prompt(agent: Agent) -> str:
- Identity section: name + role - Identity section: name + role
- Persona section (if agent.persona is non-empty) - Persona section (if agent.persona is non-empty)
- AI transparency clause (always appended) - AI transparency clause (always appended)
- WhatsApp tier-2 scoping (only when channel == "whatsapp" and tool_assignments set)
WhatsApp Tier-2 Business-Function Scoping:
Per Meta 2026 policy, agents responding on WhatsApp must constrain themselves
to declared business functions. When channel == "whatsapp" and the agent has
non-empty tool_assignments, a scoping clause is appended that instructs the LLM
to handle only the declared topics and redirect off-topic requests.
This is the tier-2 gate (borderline messages reach the LLM with this constraint;
clearly off-topic messages are caught by the tier-1 keyword gate in the gateway).
Args: Args:
agent: ORM Agent instance. agent: ORM Agent instance.
channel: Channel name (e.g. "slack", "whatsapp"). Used for tier-2 scoping.
Default empty string means no channel-specific scoping.
Returns: Returns:
A complete system prompt string ready to pass to the LLM. A complete system prompt string ready to pass to the LLM.
@@ -167,6 +178,17 @@ def build_system_prompt(agent: Agent) -> str:
"If asked directly whether you are an AI, always respond honestly that you are an AI assistant." "If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
) )
# 5. WhatsApp tier-2 scoping — constrain LLM to declared business functions
if channel == "whatsapp":
functions: list[str] = getattr(agent, "tool_assignments", []) or []
if functions:
topics = ", ".join(functions)
parts.append(
f"You are responding on WhatsApp. You only handle: {topics}. "
f"If the user asks about something outside these topics, "
f"politely redirect them to the allowed topics."
)
return "\n\n".join(parts) return "\n\n".join(parts)
@@ -204,12 +226,14 @@ def build_messages_with_memory(
current_message: str, current_message: str,
recent_messages: list[dict], recent_messages: list[dict],
relevant_context: list[str], relevant_context: list[str],
channel: str = "",
) -> list[dict]: ) -> list[dict]:
""" """
Build an LLM messages array enriched with two-layer memory. Build an LLM messages array enriched with two-layer memory.
Structure (in order): Structure (in order):
1. System message — agent identity, persona, AI transparency clause 1. System message — agent identity, persona, AI transparency clause,
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
2. System message — long-term context from pgvector (ONLY if non-empty) 2. System message — long-term context from pgvector (ONLY if non-empty)
Injected as a system message before the sliding window so the LLM Injected as a system message before the sliding window so the LLM
has relevant background without it appearing in the conversation. has relevant background without it appearing in the conversation.
@@ -226,11 +250,13 @@ def build_messages_with_memory(
from Redis sliding window (oldest first). from Redis sliding window (oldest first).
relevant_context: Long-term memory — list of content strings from relevant_context: Long-term memory — list of content strings from
pgvector similarity search (most relevant first). pgvector similarity search (most relevant first).
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
Default empty string means no channel-specific scoping.
Returns: Returns:
List of message dicts suitable for an OpenAI-compatible API call. List of message dicts suitable for an OpenAI-compatible API call.
""" """
system_prompt = build_system_prompt(agent) system_prompt = build_system_prompt(agent, channel=channel)
messages: list[dict] = [{"role": "system", "content": system_prompt}] messages: list[dict] = [{"role": "system", "content": system_prompt}]
# Inject long-term pgvector context as a system message BEFORE sliding window # Inject long-term pgvector context as a system message BEFORE sliding window
@@ -256,6 +282,7 @@ def build_messages_with_media(
media_attachments: list[MediaAttachment], media_attachments: list[MediaAttachment],
recent_messages: list[dict], recent_messages: list[dict],
relevant_context: list[str], relevant_context: list[str],
channel: str = "",
) -> list[dict]: ) -> list[dict]:
""" """
Build an LLM messages array with memory enrichment AND multimodal media injection. Build an LLM messages array with memory enrichment AND multimodal media injection.
@@ -277,7 +304,8 @@ def build_messages_with_media(
gracefully rather than raising an error. gracefully rather than raising an error.
Structure (in order): Structure (in order):
1. System message — agent identity, persona, AI transparency clause 1. System message — agent identity, persona, AI transparency clause,
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
2. System message — long-term context (ONLY if non-empty) 2. System message — long-term context (ONLY if non-empty)
3. Sliding window messages — recent history 3. Sliding window messages — recent history
4. Current user message (plain string or multipart content list) 4. Current user message (plain string or multipart content list)
@@ -289,6 +317,8 @@ def build_messages_with_media(
Empty list produces the same output as build_messages_with_memory(). Empty list produces the same output as build_messages_with_memory().
recent_messages: Short-term memory — list of {"role", "content"} dicts. recent_messages: Short-term memory — list of {"role", "content"} dicts.
relevant_context: Long-term memory — list of content strings from pgvector. relevant_context: Long-term memory — list of content strings from pgvector.
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
Default empty string means no channel-specific scoping.
Returns: Returns:
List of message dicts suitable for an OpenAI/LiteLLM-compatible API call. List of message dicts suitable for an OpenAI/LiteLLM-compatible API call.
@@ -301,6 +331,7 @@ def build_messages_with_media(
current_message=current_message, current_message=current_message,
recent_messages=recent_messages, recent_messages=recent_messages,
relevant_context=relevant_context, relevant_context=relevant_context,
channel=channel,
) )
# If no media attachments, return the base messages unchanged # If no media attachments, return the base messages unchanged

View File

@@ -35,6 +35,23 @@ Pending tool confirmation:
On the next user message, if a pending confirmation exists: On the next user message, if a pending confirmation exists:
- "yes" → execute the pending tool and continue - "yes" → execute the pending tool and continue
- "no" / anything else → cancel and inform the user - "no" / anything else → cancel and inform the user
Escalation pipeline (Phase 2 Plan 06):
Pre-check (before LLM call):
If Redis escalation_status_key == "escalated", return early assistant-mode reply.
This prevents the LLM from being called when a human has already taken over.
Post-check (after LLM response):
check_escalation_rules() evaluates configured rules against conversation metadata.
If a rule matches AND agent.escalation_assignee is set, escalate_to_human() is
called and its return value replaces the LLM response.
Outbound routing (Phase 2 Plan 06):
All response delivery goes through _send_response() which routes to:
- Slack: _update_slack_placeholder() via chat.update
- WhatsApp: send_whatsapp_message() via Meta Cloud API
handle_message now pops WhatsApp extras (phone_number_id, bot_token) and
passes them through to _process_message via the extras dict.
""" """
from __future__ import annotations from __future__ import annotations
@@ -43,10 +60,25 @@ import asyncio
import json import json
import logging import logging
import uuid import uuid
from typing import Any
import redis.asyncio as aioredis
from gateway.channels.whatsapp import send_whatsapp_message from gateway.channels.whatsapp import send_whatsapp_message
from orchestrator.agents.builder import build_messages_with_memory
from orchestrator.agents.runner import run_agent
from orchestrator.audit.logger import AuditLogger
from orchestrator.escalation.handler import check_escalation_rules, escalate_to_human
from orchestrator.main import app from orchestrator.main import app
from orchestrator.memory.embedder import embed_text
from orchestrator.memory.long_term import retrieve_relevant
from orchestrator.memory.short_term import append_message, get_recent_messages
from orchestrator.tools.registry import get_tools_for_agent
from shared.config import settings
from shared.db import async_session_factory, engine
from shared.models.message import KonstructMessage from shared.models.message import KonstructMessage
from shared.redis_keys import escalation_status_key
from shared.rls import configure_rls_hook, current_tenant_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -152,23 +184,28 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
by the Channel Gateway after tenant resolution completes. by the Channel Gateway after tenant resolution completes.
The ``message_data`` dict MAY contain extra keys beyond KonstructMessage The ``message_data`` dict MAY contain extra keys beyond KonstructMessage
fields. Specifically, the Slack handler injects: fields. Specifically:
- ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder - Slack handler injects:
- ``channel_id``: Slack channel ID where the response should be posted ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder
``channel_id``: Slack channel ID where the response should be posted
- WhatsApp gateway injects:
``phone_number_id``: WhatsApp phone number ID for outbound messaging
``bot_token``: WhatsApp access_token (injected as bot_token by gateway)
These are extracted before KonstructMessage validation and used to update These are extracted before KonstructMessage validation and used to route
the placeholder with the real LLM response via chat.update. outbound responses via _send_response().
Pipeline: Pipeline:
1. Extract Slack reply metadata (placeholder_ts, channel_id) if present 1. Extract channel reply metadata (Slack + WhatsApp extras) if present
2. Deserialize message_data -> KonstructMessage 2. Deserialize message_data -> KonstructMessage
3. Run async agent pipeline via asyncio.run() 3. Extract wa_id from sender.user_id for WhatsApp messages
4. If Slack metadata present: call chat.update to replace placeholder 4. Build extras dict for channel-aware outbound routing
5. Return response dict 5. Run async agent pipeline via asyncio.run()
6. Return response dict
Args: Args:
message_data: JSON-serializable dict. Must contain KonstructMessage message_data: JSON-serializable dict. Must contain KonstructMessage
fields plus optional ``placeholder_ts`` and ``channel_id``. fields plus optional channel-specific extras.
Returns: Returns:
Dict with keys: Dict with keys:
@@ -181,20 +218,39 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
placeholder_ts: str = message_data.pop("placeholder_ts", "") or "" placeholder_ts: str = message_data.pop("placeholder_ts", "") or ""
channel_id: str = message_data.pop("channel_id", "") or "" channel_id: str = message_data.pop("channel_id", "") or ""
# Extract WhatsApp-specific extras before model validation
# The WhatsApp gateway injects these alongside the normalized KonstructMessage fields
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
bot_token: str = message_data.pop("bot_token", "") or ""
try: try:
msg = KonstructMessage.model_validate(message_data) msg = KonstructMessage.model_validate(message_data)
except Exception as exc: except Exception as exc:
logger.exception("Failed to deserialize KonstructMessage: %s", message_data) logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
raise self.retry(exc=exc) raise self.retry(exc=exc)
result = asyncio.run(_process_message(msg, placeholder_ts=placeholder_ts, channel_id=channel_id)) # Extract wa_id from sender.user_id — WhatsApp normalizer sets sender.user_id
# to the wa_id (recipient phone number). This must happen AFTER model_validate.
wa_id: str = ""
if msg.channel == "whatsapp" and msg.sender and msg.sender.user_id:
wa_id = msg.sender.user_id
# Build the unified extras dict for channel-aware outbound routing
extras: dict[str, Any] = {
"placeholder_ts": placeholder_ts,
"channel_id": channel_id,
"phone_number_id": phone_number_id,
"bot_token": bot_token,
"wa_id": wa_id,
}
result = asyncio.run(_process_message(msg, extras=extras))
return result return result
async def _process_message( async def _process_message(
msg: KonstructMessage, msg: KonstructMessage,
placeholder_ts: str = "", extras: dict[str, Any] | None = None,
channel_id: str = "",
) -> dict: ) -> dict:
""" """
Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool. Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool.
@@ -215,24 +271,26 @@ async def _process_message(
- Tool-call loop runs inside run_agent() — no separate Celery tasks - Tool-call loop runs inside run_agent() — no separate Celery tasks
- If run_agent returns a confirmation message: store pending action in Redis - If run_agent returns a confirmation message: store pending action in Redis
Escalation pipeline (Phase 2 Plan 06 additions):
Pre-check: if Redis escalation_status_key == "escalated", return assistant-mode reply
Post-check: check_escalation_rules after LLM response; if triggered, escalate_to_human
Outbound routing (Phase 2 Plan 06 additions):
All response delivery goes through _send_response() — never direct _update_slack_placeholder.
extras dict is passed through to _send_response for channel-aware routing.
Args: Args:
msg: The deserialized KonstructMessage. msg: The deserialized KonstructMessage.
placeholder_ts: Slack message timestamp of the "Thinking..." placeholder. extras: Channel-specific routing metadata. For Slack: placeholder_ts, channel_id,
channel_id: Slack channel ID for the chat.update call. bot_token. For WhatsApp: phone_number_id, bot_token, wa_id.
Returns: Returns:
Dict with message_id, response, and tenant_id. Dict with message_id, response, and tenant_id.
""" """
from orchestrator.agents.builder import build_messages_with_memory
from orchestrator.agents.runner import run_agent
from orchestrator.audit.logger import AuditLogger
from orchestrator.memory.embedder import embed_text
from orchestrator.memory.long_term import retrieve_relevant
from orchestrator.memory.short_term import append_message, get_recent_messages
from orchestrator.tools.registry import get_tools_for_agent
from shared.db import async_session_factory, engine
from shared.models.tenant import Agent from shared.models.tenant import Agent
from shared.rls import configure_rls_hook, current_tenant_id
if extras is None:
extras = {}
if msg.tenant_id is None: if msg.tenant_id is None:
logger.warning("Message %s has no tenant_id — cannot process", msg.id) logger.warning("Message %s has no tenant_id — cannot process", msg.id)
@@ -265,8 +323,10 @@ async def _process_message(
result = await session.execute(stmt) result = await session.execute(stmt)
agent = result.scalars().first() agent = result.scalars().first()
# Load the bot token for this tenant from channel_connections config # Load the Slack bot token for this tenant from channel_connections config.
if agent is not None and placeholder_ts and channel_id: # This is needed for escalation DM delivery even on WhatsApp messages —
# the escalation handler always sends via Slack DM to the assignee.
if agent is not None and (extras.get("placeholder_ts") and extras.get("channel_id")):
from shared.models.tenant import ChannelConnection, ChannelTypeEnum from shared.models.tenant import ChannelConnection, ChannelTypeEnum
conn_stmt = ( conn_stmt = (
@@ -290,13 +350,9 @@ async def _process_message(
msg.id, msg.id,
) )
no_agent_response = "No active agent is configured for your workspace. Please contact your administrator." no_agent_response = "No active agent is configured for your workspace. Please contact your administrator."
if placeholder_ts and channel_id: # Build response_extras for channel-aware delivery
await _update_slack_placeholder( response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
bot_token=slack_bot_token, await _send_response(msg.channel, no_agent_response, response_extras)
channel_id=channel_id,
placeholder_ts=placeholder_ts,
text=no_agent_response,
)
return { return {
"message_id": msg.id, "message_id": msg.id,
"response": no_agent_response, "response": no_agent_response,
@@ -318,14 +374,30 @@ async def _process_message(
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
audit_logger = AuditLogger(session_factory=async_session_factory) audit_logger = AuditLogger(session_factory=async_session_factory)
# Build response_extras dict used for all outbound delivery in this pipeline run.
# For Slack: merges DB-loaded slack_bot_token with incoming extras.
# For WhatsApp: extras already contain phone_number_id, bot_token, wa_id.
response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
# -------------------------------------------------------------------------
# Escalation pre-check — if conversation already escalated, reply in assistant mode
# -------------------------------------------------------------------------
thread_key = msg.thread_id or user_id
esc_key = escalation_status_key(msg.tenant_id, thread_key)
pre_check_redis = aioredis.from_url(settings.redis_url)
try:
esc_status = await pre_check_redis.get(esc_key)
finally:
await pre_check_redis.aclose()
if esc_status == b"escalated":
assistant_reply = "I've already connected you with a team member. They'll continue assisting you."
await _send_response(msg.channel, assistant_reply, response_extras)
return {"message_id": msg.id, "response": assistant_reply, "tenant_id": msg.tenant_id}
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Pending tool confirmation check # Pending tool confirmation check
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
import redis.asyncio as aioredis
from shared.config import settings
redis_client = aioredis.from_url(settings.redis_url)
pending_confirm_key = _PENDING_TOOL_KEY.format( pending_confirm_key = _PENDING_TOOL_KEY.format(
tenant_id=msg.tenant_id, tenant_id=msg.tenant_id,
user_id=user_id, user_id=user_id,
@@ -334,6 +406,7 @@ async def _process_message(
response_text: str = "" response_text: str = ""
handled_as_confirmation = False handled_as_confirmation = False
redis_client = aioredis.from_url(settings.redis_url)
try: try:
pending_raw = await redis_client.get(pending_confirm_key) pending_raw = await redis_client.get(pending_confirm_key)
@@ -362,13 +435,7 @@ async def _process_message(
await redis_client.aclose() await redis_client.aclose()
if handled_as_confirmation: if handled_as_confirmation:
if placeholder_ts and channel_id: await _send_response(msg.channel, response_text, response_extras)
await _update_slack_placeholder(
bot_token=slack_bot_token,
channel_id=channel_id,
placeholder_ts=placeholder_ts,
text=response_text,
)
return { return {
"message_id": msg.id, "message_id": msg.id,
"response": response_text, "response": response_text,
@@ -412,6 +479,7 @@ async def _process_message(
current_message=user_text, current_message=user_text,
recent_messages=recent_messages, recent_messages=recent_messages,
relevant_context=relevant_context, relevant_context=relevant_context,
channel=str(msg.channel) if msg.channel else "",
) )
# Build tool registry for this agent # Build tool registry for this agent
@@ -428,6 +496,37 @@ async def _process_message(
tool_registry=tool_registry if tool_registry else None, tool_registry=tool_registry if tool_registry else None,
) )
# -------------------------------------------------------------------------
# Escalation post-check — evaluate rules against conversation metadata
# -------------------------------------------------------------------------
conversation_metadata = _build_conversation_metadata(recent_messages, user_text)
triggered_rule = check_escalation_rules(
agent=agent,
message_text=user_text,
conversation_metadata=conversation_metadata,
natural_lang_enabled=getattr(agent, "natural_language_escalation", False),
)
if triggered_rule and getattr(agent, "escalation_assignee", None):
escalation_redis = aioredis.from_url(settings.redis_url)
try:
response_text = await escalate_to_human(
tenant_id=msg.tenant_id,
agent=agent,
thread_id=thread_key,
trigger_reason=triggered_rule.get("condition", "rule triggered"),
recent_messages=recent_messages,
assignee_slack_user_id=agent.escalation_assignee,
bot_token=slack_bot_token,
redis=escalation_redis,
audit_logger=audit_logger,
user_id=user_id,
agent_id=agent_id_str,
)
finally:
await escalation_redis.aclose()
# Check if the response is a tool confirmation request # Check if the response is a tool confirmation request
# The confirmation message template starts with a specific prefix # The confirmation message template starts with a specific prefix
is_confirmation_request = response_text.startswith("This action requires your approval") is_confirmation_request = response_text.startswith("This action requires your approval")
@@ -453,14 +552,8 @@ async def _process_message(
len(relevant_context), len(relevant_context),
) )
# Replace the "Thinking..." placeholder with the real response # Send response via channel-aware routing
if placeholder_ts and channel_id: await _send_response(msg.channel, response_text, response_extras)
await _update_slack_placeholder(
bot_token=slack_bot_token,
channel_id=channel_id,
placeholder_ts=placeholder_ts,
text=response_text,
)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Memory persistence (after LLM response) # Memory persistence (after LLM response)
@@ -489,6 +582,74 @@ async def _process_message(
} }
def _build_response_extras(
channel: Any,
extras: dict[str, Any],
slack_bot_token: str,
) -> dict[str, Any]:
"""
Build the response_extras dict for channel-aware outbound delivery.
For Slack: injects slack_bot_token into extras["bot_token"] so _send_response
can use it for chat.update calls.
For WhatsApp: extras already contain phone_number_id, bot_token (access_token),
and wa_id — no transformation needed.
Args:
channel: Channel name from KonstructMessage.channel.
extras: Incoming extras from handle_message.
slack_bot_token: Bot token loaded from DB channel_connections.
Returns:
Dict suitable for passing to _send_response.
"""
channel_str = str(channel) if channel else ""
if channel_str == "slack":
return {
"bot_token": slack_bot_token,
"channel_id": extras.get("channel_id", "") or "",
"placeholder_ts": extras.get("placeholder_ts", "") or "",
}
elif channel_str == "whatsapp":
return {
"phone_number_id": extras.get("phone_number_id", "") or "",
"bot_token": extras.get("bot_token", "") or "",
"wa_id": extras.get("wa_id", "") or "",
}
else:
return dict(extras)
def _build_conversation_metadata(
recent_messages: list[dict],
current_text: str,
) -> dict[str, Any]:
"""
Build conversation metadata dict for escalation rule evaluation.
Scans recent messages and current_text for billing keywords and counts occurrences.
Returns a dict with:
- "billing_dispute" (bool): True if any billing keyword found in any message
- "attempts" (int): count of messages containing billing keywords
This is the v1 keyword-based metadata detection (per STATE.md decisions).
Args:
recent_messages: List of {"role", "content"} dicts from Redis sliding window.
current_text: The user's current message text.
Returns:
Dict with billing_dispute and attempts keys.
"""
billing_keywords = {"billing", "invoice", "charge", "refund", "payment", "subscription"}
all_texts = [m.get("content", "") for m in recent_messages] + [current_text]
billing_count = sum(1 for t in all_texts if any(kw in t.lower() for kw in billing_keywords))
return {
"billing_dispute": billing_count > 0,
"attempts": billing_count,
}
async def _execute_pending_tool( async def _execute_pending_tool(
pending_data: dict, pending_data: dict,
tenant_uuid: uuid.UUID, tenant_uuid: uuid.UUID,
@@ -526,7 +687,7 @@ def _extract_tool_name_from_confirmation(confirmation_message: str) -> str:
async def _send_response( async def _send_response(
channel: str, channel: Any,
text: str, text: str,
extras: dict, extras: dict,
) -> None: ) -> None:
@@ -545,7 +706,9 @@ async def _send_response(
For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts`` For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts``
For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id`` For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id``
""" """
if channel == "slack": channel_str = str(channel) if channel else ""
if channel_str == "slack":
bot_token: str = extras.get("bot_token", "") or "" bot_token: str = extras.get("bot_token", "") or ""
channel_id: str = extras.get("channel_id", "") or "" channel_id: str = extras.get("channel_id", "") or ""
placeholder_ts: str = extras.get("placeholder_ts", "") or "" placeholder_ts: str = extras.get("placeholder_ts", "") or ""
@@ -563,7 +726,7 @@ async def _send_response(
text=text, text=text,
) )
elif channel == "whatsapp": elif channel_str == "whatsapp":
phone_number_id: str = extras.get("phone_number_id", "") or "" phone_number_id: str = extras.get("phone_number_id", "") or ""
access_token: str = extras.get("bot_token", "") or "" access_token: str = extras.get("bot_token", "") or ""
wa_id: str = extras.get("wa_id", "") or "" wa_id: str = extras.get("wa_id", "") or ""

View File

@@ -15,7 +15,6 @@ Tests verify:
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json
import uuid import uuid
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -24,7 +23,7 @@ import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Shared test helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -93,6 +92,89 @@ def make_message_data(
return data return data
def make_fake_redis(escalated: bool = False) -> AsyncMock:
    """Return an AsyncMock that mimics the async redis client used by tasks.

    When *escalated* is True, get() resolves to b"escalated" (the value the
    escalation pre-check looks for); otherwise it resolves to None. The write
    and close methods (delete/set/setex/aclose) are plain no-op async mocks.
    """
    fake = AsyncMock()
    status = b"escalated" if escalated else None
    fake.get = AsyncMock(return_value=status)
    for method_name in ("delete", "set", "setex", "aclose"):
        setattr(fake, method_name, AsyncMock())
    return fake
def make_session_factory_mock(agent: Any) -> MagicMock:
    """
    Create a mock async_session_factory whose session resolves queries to *agent*.

    Calling the returned factory yields an async context manager; entering it
    yields a session whose execute() result supports
    ``result.scalars().first() -> agent``.
    """
    # Innermost link first: result.scalars().first() -> agent
    scalars = MagicMock()
    scalars.first.return_value = agent
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars

    session = AsyncMock()
    session.execute = AsyncMock(return_value=exec_result)

    # factory() -> async context manager -> session
    context = AsyncMock()
    context.__aenter__ = AsyncMock(return_value=session)
    context.__aexit__ = AsyncMock(return_value=False)

    factory = MagicMock()
    factory.return_value = context
    return factory
def make_process_message_msg(
    channel: str = "slack",
    tenant_id: str | None = None,
    user_id: str = "U123",
    text: str = "hello",
    thread_id: str = "thread_001",
) -> Any:
    """Construct a validated KonstructMessage for _process_message tests.

    A fresh tenant_id UUID is generated when none is supplied.
    """
    from shared.models.message import KonstructMessage

    payload: dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "tenant_id": tenant_id or str(uuid.uuid4()),
        "channel": channel,
        "channel_metadata": {},
        "sender": {"user_id": user_id, "display_name": "Test User"},
        "content": {"text": text, "attachments": []},
        "timestamp": "2026-01-01T00:00:00Z",
        "thread_id": thread_id,
        "reply_to": None,
        "context": {},
    }
    return KonstructMessage.model_validate(payload)
def _standard_process_patches(
    agent: Any,
    escalated: bool = False,
    run_agent_return: str = "Agent reply",
    escalation_rule: dict | None = None,
    escalate_return: str = "I've brought in a team member",
) -> dict:
    """
    Bundle the standard mocks for _process_message tests into one dict.

    Callers wire the entries up themselves with contextlib.ExitStack or
    individual patch() calls; this helper only constructs the shared pieces.
    """
    mocks: dict = {
        "redis": make_fake_redis(escalated=escalated),
        "session_factory": make_session_factory_mock(agent),
        "run_agent_return": run_agent_return,
        "escalation_rule": escalation_rule,
        "escalate_return": escalate_return,
    }
    return mocks
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Task 1 Tests — handle_message extra-field extraction # Task 1 Tests — handle_message extra-field extraction
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -242,44 +324,13 @@ class TestSendResponseRouting:
class TestProcessMessageUsesSSendResponse: class TestProcessMessageUsesSSendResponse:
"""_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder.""" """_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""
def _make_fake_redis(self, escalated: bool = False) -> AsyncMock:
"""Build a fakeredis-like mock for the async client."""
redis_mock = AsyncMock()
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
redis_mock.delete = AsyncMock()
redis_mock.set = AsyncMock()
redis_mock.setex = AsyncMock()
redis_mock.aclose = AsyncMock()
return redis_mock
def _make_session_mock(self, agent: Any) -> AsyncMock:
"""Build a mock async DB session that returns the given agent."""
session = AsyncMock()
# Scalars mock
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_message_calls_send_response_not_update_slack(self) -> None: async def test_process_message_calls_send_response_not_update_slack(self) -> None:
"""_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly.""" """_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent() agent = make_agent()
tenant_id = str(uuid.uuid4()) msg = make_process_message_msg(channel="slack")
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
extras = { extras = {
"placeholder_ts": "1234.5678", "placeholder_ts": "1234.5678",
"channel_id": "C123", "channel_id": "C123",
@@ -287,11 +338,14 @@ class TestProcessMessageUsesSSendResponse:
"bot_token": "", "bot_token": "",
} }
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with ( with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send, patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update, patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -306,15 +360,8 @@ class TestProcessMessageUsesSSendResponse:
patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=None), patch("orchestrator.tasks.check_escalation_rules", return_value=None),
): ):
# Setup redis mock to return None for pre-check (not escalated)
redis_mock = self._make_fake_redis(escalated=False)
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
# Setup session to return our agent
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras) await _process_message(msg, extras=extras)
# _send_response MUST be called for delivery # _send_response MUST be called for delivery
@@ -326,37 +373,30 @@ class TestProcessMessageUsesSSendResponse:
async def test_whatsapp_extras_passed_to_send_response(self) -> None: async def test_whatsapp_extras_passed_to_send_response(self) -> None:
"""_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp.""" """_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent() agent = make_agent()
tenant_id = str(uuid.uuid4()) # user_id = "15551234567" -> becomes wa_id
msg_data = make_message_data( msg = make_process_message_msg(channel="whatsapp", user_id="15551234567")
channel="whatsapp",
tenant_id=tenant_id,
user_id="15551234567",
)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
extras = { extras = {
"placeholder_ts": "", "placeholder_ts": "",
"channel_id": "", "channel_id": "",
"phone_number_id": "9876543210", "phone_number_id": "9876543210",
"bot_token": "EAAaccess", "bot_token": "EAAaccess",
"wa_id": "15551234567",
} }
captured_extras: list[dict] = [] captured_extras: list[dict] = []
async def capture_send(channel: str, text: str, ext: dict) -> None: async def capture_send(channel: Any, text: str, ext: dict) -> None:
captured_extras.append({"channel": channel, "text": text, "extras": ext}) captured_extras.append({"channel": str(channel), "text": text, "extras": ext})
redis_mock = make_fake_redis(escalated=False)
sf_mock = make_session_factory_mock(agent)
with ( with (
patch("orchestrator.tasks._send_response", side_effect=capture_send), patch("orchestrator.tasks._send_response", side_effect=capture_send),
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -371,15 +411,7 @@ class TestProcessMessageUsesSSendResponse:
patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=None), patch("orchestrator.tasks.check_escalation_rules", return_value=None),
): ):
redis_mock = self._make_fake_redis(escalated=False)
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
# Configure context manager for redis client used in finally blocks
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
redis_mock.__aexit__ = AsyncMock()
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras) await _process_message(msg, extras=extras)
@@ -389,7 +421,7 @@ class TestProcessMessageUsesSSendResponse:
wa_extras = wa_calls[0]["extras"] wa_extras = wa_calls[0]["extras"]
assert wa_extras.get("phone_number_id") == "9876543210" assert wa_extras.get("phone_number_id") == "9876543210"
assert wa_extras.get("bot_token") == "EAAaccess" assert wa_extras.get("bot_token") == "EAAaccess"
# wa_id must come from sender.user_id # wa_id must come from extras (which was populated from sender.user_id in handle_message)
assert wa_extras.get("wa_id") == "15551234567" assert wa_extras.get("wa_id") == "15551234567"
@@ -401,42 +433,22 @@ class TestProcessMessageUsesSSendResponse:
class TestEscalationPreCheck: class TestEscalationPreCheck:
"""When escalation status is 'escalated' in Redis, _process_message must return early without LLM call.""" """When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""
def _make_session_mock(self, agent: Any) -> AsyncMock:
session = AsyncMock()
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_escalated_conversation_skips_run_agent(self) -> None: async def test_escalated_conversation_skips_run_agent(self) -> None:
"""When Redis shows 'escalated', run_agent must NOT be called.""" """When Redis shows 'escalated', run_agent must NOT be called."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent() agent = make_agent()
tenant_id = str(uuid.uuid4()) msg = make_process_message_msg(channel="slack")
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
redis_mock = AsyncMock() redis_mock = make_fake_redis(escalated=True)
redis_mock.get = AsyncMock(return_value=b"escalated") sf_mock = make_session_factory_mock(agent)
redis_mock.aclose = AsyncMock()
with ( with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -444,12 +456,6 @@ class TestEscalationPreCheck:
patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.AuditLogger"),
): ):
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
redis_mock.__aexit__ = AsyncMock()
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
result = await _process_message(msg, extras=extras) result = await _process_message(msg, extras=extras)
@@ -468,42 +474,24 @@ class TestEscalationPreCheck:
class TestEscalationPostCheck: class TestEscalationPostCheck:
"""check_escalation_rules called after run_agent; escalate_to_human called when rule matches.""" """check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""
def _make_session_mock(self, agent: Any) -> AsyncMock:
session = AsyncMock()
scalars_mock = MagicMock()
scalars_mock.first.return_value = agent
result_mock = AsyncMock()
result_mock.scalars.return_value = scalars_mock
session.execute = AsyncMock(return_value=result_mock)
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock()
return session
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_escalation_rules_called_after_run_agent(self) -> None: async def test_check_escalation_rules_called_after_run_agent(self) -> None:
"""check_escalation_rules must be called after run_agent returns.""" """check_escalation_rules must be called after run_agent returns."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}])
tenant_id = str(uuid.uuid4())
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]
)
msg = make_process_message_msg(channel="slack")
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
redis_mock = AsyncMock() redis_mock = make_fake_redis(escalated=False)
redis_mock.get = AsyncMock(return_value=None) sf_mock = make_session_factory_mock(agent)
redis_mock.aclose = AsyncMock()
with ( with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -520,10 +508,6 @@ class TestEscalationPostCheck:
): ):
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras) await _process_message(msg, extras=extras)
# check_escalation_rules must have been called # check_escalation_rules must have been called
@@ -533,32 +517,22 @@ class TestEscalationPostCheck:
async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None: async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
"""When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called.""" """When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent( agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
escalation_assignee="U_MANAGER", escalation_assignee="U_MANAGER",
) )
tenant_id = str(uuid.uuid4()) msg = make_process_message_msg(channel="slack", text="refund issue again")
msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again")
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
redis_mock = make_fake_redis(escalated=False)
redis_mock = AsyncMock() sf_mock = make_session_factory_mock(agent)
redis_mock.get = AsyncMock(return_value=None)
redis_mock.aclose = AsyncMock()
with ( with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -572,51 +546,41 @@ class TestEscalationPostCheck:
patch("orchestrator.tasks.get_tools_for_agent", return_value=[]), patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
patch("orchestrator.tasks.AuditLogger"), patch("orchestrator.tasks.AuditLogger"),
patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule), patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate, patch(
"orchestrator.tasks.escalate_to_human",
new_callable=AsyncMock,
return_value="I've brought in a team member",
) as mock_escalate,
): ):
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
result = await _process_message(msg, extras=extras) result = await _process_message(msg, extras=extras)
# escalate_to_human must be called # escalate_to_human must be called
mock_escalate.assert_called_once() mock_escalate.assert_called_once()
# Response should be the escalation confirmation # Response should be the escalation confirmation
assert "team member" in result["response"] or mock_escalate.return_value in result["response"] assert "team member" in result["response"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_escalate_to_human_not_called_when_no_assignee(self) -> None: async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
"""When rule matches but escalation_assignee is None, escalate_to_human must NOT be called.""" """When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
from orchestrator.tasks import _process_message from orchestrator.tasks import _process_message
from shared.models.message import KonstructMessage
agent = make_agent( agent = make_agent(
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}], escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
escalation_assignee=None, escalation_assignee=None,
) )
tenant_id = str(uuid.uuid4()) msg = make_process_message_msg(channel="slack")
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
msg_data.pop("placeholder_ts", None)
msg_data.pop("channel_id", None)
msg_data.pop("phone_number_id", None)
msg_data.pop("bot_token", None)
msg = KonstructMessage.model_validate(msg_data)
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""} extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"} matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
redis_mock = make_fake_redis(escalated=False)
redis_mock = AsyncMock() sf_mock = make_session_factory_mock(agent)
redis_mock.get = AsyncMock(return_value=None)
redis_mock.aclose = AsyncMock()
with ( with (
patch("orchestrator.tasks._send_response", new_callable=AsyncMock), patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
patch("orchestrator.tasks.aioredis") as mock_aioredis, patch("orchestrator.tasks.aioredis") as mock_aioredis,
patch("orchestrator.tasks.async_session_factory") as mock_session_factory, patch("orchestrator.tasks.async_session_factory", sf_mock),
patch("orchestrator.tasks.configure_rls_hook"), patch("orchestrator.tasks.configure_rls_hook"),
patch("orchestrator.tasks.current_tenant_id"), patch("orchestrator.tasks.current_tenant_id"),
patch("orchestrator.tasks.engine"), patch("orchestrator.tasks.engine"),
@@ -634,10 +598,6 @@ class TestEscalationPostCheck:
): ):
mock_aioredis.from_url.return_value = redis_mock mock_aioredis.from_url.return_value = redis_mock
session_mock = self._make_session_mock(agent)
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
mock_session_factory.return_value.__aexit__ = AsyncMock()
await _process_message(msg, extras=extras) await _process_message(msg, extras=extras)
# escalate_to_human must NOT be called — no assignee configured # escalate_to_human must NOT be called — no assignee configured