feat(02-06): re-wire escalation and WhatsApp outbound routing in pipeline
- Move key imports to module level in tasks.py for testability and clarity
- Pop WhatsApp extras (phone_number_id, bot_token) in handle_message before model_validate
- Build unified extras dict and extract wa_id from sender.user_id
- Change _process_message signature to accept extras dict
- Add _build_response_extras() helper for channel-aware extras assembly
- Replace all _update_slack_placeholder calls in _process_message with _send_response()
- Add escalation pre-check: skip LLM when Redis escalation_status_key == 'escalated'
- Add escalation post-check: check_escalation_rules after run_agent; call escalate_to_human
when rule matches and agent.escalation_assignee is set
- Add _build_conversation_metadata() helper (billing keyword v1 detection)
- Add channel parameter to build_system_prompt(), build_messages_with_memory(),
build_messages_with_media() for WhatsApp tier-2 business-function scoping
- WhatsApp scoping appends 'You only handle: {topics}' when tool_assignments non-empty
- Pass msg.channel to build_messages_with_memory() in _process_message
- All 26 new tests pass; all existing escalation/WhatsApp tests pass (no regressions)
This commit is contained in:
@@ -133,7 +133,7 @@ def generate_presigned_url(storage_key: str, expiry: int = 3600) -> str:
|
||||
return presigned_url
|
||||
|
||||
|
||||
def build_system_prompt(agent: Agent) -> str:
|
||||
def build_system_prompt(agent: Agent, channel: str = "") -> str:
|
||||
"""
|
||||
Assemble the full system prompt for an agent.
|
||||
|
||||
@@ -142,9 +142,20 @@ def build_system_prompt(agent: Agent) -> str:
|
||||
- Identity section: name + role
|
||||
- Persona section (if agent.persona is non-empty)
|
||||
- AI transparency clause (always appended)
|
||||
- WhatsApp tier-2 scoping (only when channel == "whatsapp" and tool_assignments set)
|
||||
|
||||
WhatsApp Tier-2 Business-Function Scoping:
|
||||
Per Meta 2026 policy, agents responding on WhatsApp must constrain themselves
|
||||
to declared business functions. When channel == "whatsapp" and the agent has
|
||||
non-empty tool_assignments, a scoping clause is appended that instructs the LLM
|
||||
to handle only the declared topics and redirect off-topic requests.
|
||||
This is the tier-2 gate (borderline messages reach the LLM with this constraint;
|
||||
clearly off-topic messages are caught by the tier-1 keyword gate in the gateway).
|
||||
|
||||
Args:
|
||||
agent: ORM Agent instance.
|
||||
agent: ORM Agent instance.
|
||||
channel: Channel name (e.g. "slack", "whatsapp"). Used for tier-2 scoping.
|
||||
Default empty string means no channel-specific scoping.
|
||||
|
||||
Returns:
|
||||
A complete system prompt string ready to pass to the LLM.
|
||||
@@ -167,6 +178,17 @@ def build_system_prompt(agent: Agent) -> str:
|
||||
"If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
|
||||
)
|
||||
|
||||
# 5. WhatsApp tier-2 scoping — constrain LLM to declared business functions
|
||||
if channel == "whatsapp":
|
||||
functions: list[str] = getattr(agent, "tool_assignments", []) or []
|
||||
if functions:
|
||||
topics = ", ".join(functions)
|
||||
parts.append(
|
||||
f"You are responding on WhatsApp. You only handle: {topics}. "
|
||||
f"If the user asks about something outside these topics, "
|
||||
f"politely redirect them to the allowed topics."
|
||||
)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
@@ -204,12 +226,14 @@ def build_messages_with_memory(
|
||||
current_message: str,
|
||||
recent_messages: list[dict],
|
||||
relevant_context: list[str],
|
||||
channel: str = "",
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Build an LLM messages array enriched with two-layer memory.
|
||||
|
||||
Structure (in order):
|
||||
1. System message — agent identity, persona, AI transparency clause
|
||||
1. System message — agent identity, persona, AI transparency clause,
|
||||
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
|
||||
2. System message — long-term context from pgvector (ONLY if non-empty)
|
||||
Injected as a system message before the sliding window so the LLM
|
||||
has relevant background without it appearing in the conversation.
|
||||
@@ -226,11 +250,13 @@ def build_messages_with_memory(
|
||||
from Redis sliding window (oldest first).
|
||||
relevant_context: Long-term memory — list of content strings from
|
||||
pgvector similarity search (most relevant first).
|
||||
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
|
||||
Default empty string means no channel-specific scoping.
|
||||
|
||||
Returns:
|
||||
List of message dicts suitable for an OpenAI-compatible API call.
|
||||
"""
|
||||
system_prompt = build_system_prompt(agent)
|
||||
system_prompt = build_system_prompt(agent, channel=channel)
|
||||
messages: list[dict] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Inject long-term pgvector context as a system message BEFORE sliding window
|
||||
@@ -256,6 +282,7 @@ def build_messages_with_media(
|
||||
media_attachments: list[MediaAttachment],
|
||||
recent_messages: list[dict],
|
||||
relevant_context: list[str],
|
||||
channel: str = "",
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Build an LLM messages array with memory enrichment AND multimodal media injection.
|
||||
@@ -277,7 +304,8 @@ def build_messages_with_media(
|
||||
gracefully rather than raising an error.
|
||||
|
||||
Structure (in order):
|
||||
1. System message — agent identity, persona, AI transparency clause
|
||||
1. System message — agent identity, persona, AI transparency clause,
|
||||
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
|
||||
2. System message — long-term context (ONLY if non-empty)
|
||||
3. Sliding window messages — recent history
|
||||
4. Current user message (plain string or multipart content list)
|
||||
@@ -289,6 +317,8 @@ def build_messages_with_media(
|
||||
Empty list produces the same output as build_messages_with_memory().
|
||||
recent_messages: Short-term memory — list of {"role", "content"} dicts.
|
||||
relevant_context: Long-term memory — list of content strings from pgvector.
|
||||
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
|
||||
Default empty string means no channel-specific scoping.
|
||||
|
||||
Returns:
|
||||
List of message dicts suitable for an OpenAI/LiteLLM-compatible API call.
|
||||
@@ -301,6 +331,7 @@ def build_messages_with_media(
|
||||
current_message=current_message,
|
||||
recent_messages=recent_messages,
|
||||
relevant_context=relevant_context,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
# If no media attachments, return the base messages unchanged
|
||||
|
||||
@@ -35,6 +35,23 @@ Pending tool confirmation:
|
||||
On the next user message, if a pending confirmation exists:
|
||||
- "yes" → execute the pending tool and continue
|
||||
- "no" / anything else → cancel and inform the user
|
||||
|
||||
Escalation pipeline (Phase 2 Plan 06):
|
||||
Pre-check (before LLM call):
|
||||
If Redis escalation_status_key == "escalated", return early assistant-mode reply.
|
||||
This prevents the LLM from being called when a human has already taken over.
|
||||
|
||||
Post-check (after LLM response):
|
||||
check_escalation_rules() evaluates configured rules against conversation metadata.
|
||||
If a rule matches AND agent.escalation_assignee is set, escalate_to_human() is
|
||||
called and its return value replaces the LLM response.
|
||||
|
||||
Outbound routing (Phase 2 Plan 06):
|
||||
All response delivery goes through _send_response() which routes to:
|
||||
- Slack: _update_slack_placeholder() via chat.update
|
||||
- WhatsApp: send_whatsapp_message() via Meta Cloud API
|
||||
handle_message now pops WhatsApp extras (phone_number_id, bot_token) and
|
||||
passes them through to _process_message via the extras dict.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -43,10 +60,25 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from gateway.channels.whatsapp import send_whatsapp_message
|
||||
from orchestrator.agents.builder import build_messages_with_memory
|
||||
from orchestrator.agents.runner import run_agent
|
||||
from orchestrator.audit.logger import AuditLogger
|
||||
from orchestrator.escalation.handler import check_escalation_rules, escalate_to_human
|
||||
from orchestrator.main import app
|
||||
from orchestrator.memory.embedder import embed_text
|
||||
from orchestrator.memory.long_term import retrieve_relevant
|
||||
from orchestrator.memory.short_term import append_message, get_recent_messages
|
||||
from orchestrator.tools.registry import get_tools_for_agent
|
||||
from shared.config import settings
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.models.message import KonstructMessage
|
||||
from shared.redis_keys import escalation_status_key
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -152,23 +184,28 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
||||
by the Channel Gateway after tenant resolution completes.
|
||||
|
||||
The ``message_data`` dict MAY contain extra keys beyond KonstructMessage
|
||||
fields. Specifically, the Slack handler injects:
|
||||
- ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder
|
||||
- ``channel_id``: Slack channel ID where the response should be posted
|
||||
fields. Specifically:
|
||||
- Slack handler injects:
|
||||
``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder
|
||||
``channel_id``: Slack channel ID where the response should be posted
|
||||
- WhatsApp gateway injects:
|
||||
``phone_number_id``: WhatsApp phone number ID for outbound messaging
|
||||
``bot_token``: WhatsApp access_token (injected as bot_token by gateway)
|
||||
|
||||
These are extracted before KonstructMessage validation and used to update
|
||||
the placeholder with the real LLM response via chat.update.
|
||||
These are extracted before KonstructMessage validation and used to route
|
||||
outbound responses via _send_response().
|
||||
|
||||
Pipeline:
|
||||
1. Extract Slack reply metadata (placeholder_ts, channel_id) if present
|
||||
1. Extract channel reply metadata (Slack + WhatsApp extras) if present
|
||||
2. Deserialize message_data -> KonstructMessage
|
||||
3. Run async agent pipeline via asyncio.run()
|
||||
4. If Slack metadata present: call chat.update to replace placeholder
|
||||
5. Return response dict
|
||||
3. Extract wa_id from sender.user_id for WhatsApp messages
|
||||
4. Build extras dict for channel-aware outbound routing
|
||||
5. Run async agent pipeline via asyncio.run()
|
||||
6. Return response dict
|
||||
|
||||
Args:
|
||||
message_data: JSON-serializable dict. Must contain KonstructMessage
|
||||
fields plus optional ``placeholder_ts`` and ``channel_id``.
|
||||
fields plus optional channel-specific extras.
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
@@ -181,20 +218,39 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
||||
placeholder_ts: str = message_data.pop("placeholder_ts", "") or ""
|
||||
channel_id: str = message_data.pop("channel_id", "") or ""
|
||||
|
||||
# Extract WhatsApp-specific extras before model validation
|
||||
# The WhatsApp gateway injects these alongside the normalized KonstructMessage fields
|
||||
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
|
||||
bot_token: str = message_data.pop("bot_token", "") or ""
|
||||
|
||||
try:
|
||||
msg = KonstructMessage.model_validate(message_data)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
|
||||
raise self.retry(exc=exc)
|
||||
|
||||
result = asyncio.run(_process_message(msg, placeholder_ts=placeholder_ts, channel_id=channel_id))
|
||||
# Extract wa_id from sender.user_id — WhatsApp normalizer sets sender.user_id
|
||||
# to the wa_id (recipient phone number). This must happen AFTER model_validate.
|
||||
wa_id: str = ""
|
||||
if msg.channel == "whatsapp" and msg.sender and msg.sender.user_id:
|
||||
wa_id = msg.sender.user_id
|
||||
|
||||
# Build the unified extras dict for channel-aware outbound routing
|
||||
extras: dict[str, Any] = {
|
||||
"placeholder_ts": placeholder_ts,
|
||||
"channel_id": channel_id,
|
||||
"phone_number_id": phone_number_id,
|
||||
"bot_token": bot_token,
|
||||
"wa_id": wa_id,
|
||||
}
|
||||
|
||||
result = asyncio.run(_process_message(msg, extras=extras))
|
||||
return result
|
||||
|
||||
|
||||
async def _process_message(
|
||||
msg: KonstructMessage,
|
||||
placeholder_ts: str = "",
|
||||
channel_id: str = "",
|
||||
extras: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool.
|
||||
@@ -215,24 +271,26 @@ async def _process_message(
|
||||
- Tool-call loop runs inside run_agent() — no separate Celery tasks
|
||||
- If run_agent returns a confirmation message: store pending action in Redis
|
||||
|
||||
Escalation pipeline (Phase 2 Plan 06 additions):
|
||||
Pre-check: if Redis escalation_status_key == "escalated", return assistant-mode reply
|
||||
Post-check: check_escalation_rules after LLM response; if triggered, escalate_to_human
|
||||
|
||||
Outbound routing (Phase 2 Plan 06 additions):
|
||||
All response delivery goes through _send_response() — never direct _update_slack_placeholder.
|
||||
extras dict is passed through to _send_response for channel-aware routing.
|
||||
|
||||
Args:
|
||||
msg: The deserialized KonstructMessage.
|
||||
placeholder_ts: Slack message timestamp of the "Thinking..." placeholder.
|
||||
channel_id: Slack channel ID for the chat.update call.
|
||||
msg: The deserialized KonstructMessage.
|
||||
extras: Channel-specific routing metadata. For Slack: placeholder_ts, channel_id,
|
||||
bot_token. For WhatsApp: phone_number_id, bot_token, wa_id.
|
||||
|
||||
Returns:
|
||||
Dict with message_id, response, and tenant_id.
|
||||
"""
|
||||
from orchestrator.agents.builder import build_messages_with_memory
|
||||
from orchestrator.agents.runner import run_agent
|
||||
from orchestrator.audit.logger import AuditLogger
|
||||
from orchestrator.memory.embedder import embed_text
|
||||
from orchestrator.memory.long_term import retrieve_relevant
|
||||
from orchestrator.memory.short_term import append_message, get_recent_messages
|
||||
from orchestrator.tools.registry import get_tools_for_agent
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.models.tenant import Agent
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
if extras is None:
|
||||
extras = {}
|
||||
|
||||
if msg.tenant_id is None:
|
||||
logger.warning("Message %s has no tenant_id — cannot process", msg.id)
|
||||
@@ -265,8 +323,10 @@ async def _process_message(
|
||||
result = await session.execute(stmt)
|
||||
agent = result.scalars().first()
|
||||
|
||||
# Load the bot token for this tenant from channel_connections config
|
||||
if agent is not None and placeholder_ts and channel_id:
|
||||
# Load the Slack bot token for this tenant from channel_connections config.
|
||||
# This is needed for escalation DM delivery even on WhatsApp messages —
|
||||
# the escalation handler always sends via Slack DM to the assignee.
|
||||
if agent is not None and (extras.get("placeholder_ts") and extras.get("channel_id")):
|
||||
from shared.models.tenant import ChannelConnection, ChannelTypeEnum
|
||||
|
||||
conn_stmt = (
|
||||
@@ -290,13 +350,9 @@ async def _process_message(
|
||||
msg.id,
|
||||
)
|
||||
no_agent_response = "No active agent is configured for your workspace. Please contact your administrator."
|
||||
if placeholder_ts and channel_id:
|
||||
await _update_slack_placeholder(
|
||||
bot_token=slack_bot_token,
|
||||
channel_id=channel_id,
|
||||
placeholder_ts=placeholder_ts,
|
||||
text=no_agent_response,
|
||||
)
|
||||
# Build response_extras for channel-aware delivery
|
||||
response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
|
||||
await _send_response(msg.channel, no_agent_response, response_extras)
|
||||
return {
|
||||
"message_id": msg.id,
|
||||
"response": no_agent_response,
|
||||
@@ -318,14 +374,30 @@ async def _process_message(
|
||||
# -------------------------------------------------------------------------
|
||||
audit_logger = AuditLogger(session_factory=async_session_factory)
|
||||
|
||||
# Build response_extras dict used for all outbound delivery in this pipeline run.
|
||||
# For Slack: merges DB-loaded slack_bot_token with incoming extras.
|
||||
# For WhatsApp: extras already contain phone_number_id, bot_token, wa_id.
|
||||
response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Escalation pre-check — if conversation already escalated, reply in assistant mode
|
||||
# -------------------------------------------------------------------------
|
||||
thread_key = msg.thread_id or user_id
|
||||
esc_key = escalation_status_key(msg.tenant_id, thread_key)
|
||||
pre_check_redis = aioredis.from_url(settings.redis_url)
|
||||
try:
|
||||
esc_status = await pre_check_redis.get(esc_key)
|
||||
finally:
|
||||
await pre_check_redis.aclose()
|
||||
|
||||
if esc_status == b"escalated":
|
||||
assistant_reply = "I've already connected you with a team member. They'll continue assisting you."
|
||||
await _send_response(msg.channel, assistant_reply, response_extras)
|
||||
return {"message_id": msg.id, "response": assistant_reply, "tenant_id": msg.tenant_id}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Pending tool confirmation check
|
||||
# -------------------------------------------------------------------------
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
redis_client = aioredis.from_url(settings.redis_url)
|
||||
pending_confirm_key = _PENDING_TOOL_KEY.format(
|
||||
tenant_id=msg.tenant_id,
|
||||
user_id=user_id,
|
||||
@@ -334,6 +406,7 @@ async def _process_message(
|
||||
response_text: str = ""
|
||||
handled_as_confirmation = False
|
||||
|
||||
redis_client = aioredis.from_url(settings.redis_url)
|
||||
try:
|
||||
pending_raw = await redis_client.get(pending_confirm_key)
|
||||
|
||||
@@ -362,13 +435,7 @@ async def _process_message(
|
||||
await redis_client.aclose()
|
||||
|
||||
if handled_as_confirmation:
|
||||
if placeholder_ts and channel_id:
|
||||
await _update_slack_placeholder(
|
||||
bot_token=slack_bot_token,
|
||||
channel_id=channel_id,
|
||||
placeholder_ts=placeholder_ts,
|
||||
text=response_text,
|
||||
)
|
||||
await _send_response(msg.channel, response_text, response_extras)
|
||||
return {
|
||||
"message_id": msg.id,
|
||||
"response": response_text,
|
||||
@@ -412,6 +479,7 @@ async def _process_message(
|
||||
current_message=user_text,
|
||||
recent_messages=recent_messages,
|
||||
relevant_context=relevant_context,
|
||||
channel=str(msg.channel) if msg.channel else "",
|
||||
)
|
||||
|
||||
# Build tool registry for this agent
|
||||
@@ -428,6 +496,37 @@ async def _process_message(
|
||||
tool_registry=tool_registry if tool_registry else None,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Escalation post-check — evaluate rules against conversation metadata
|
||||
# -------------------------------------------------------------------------
|
||||
conversation_metadata = _build_conversation_metadata(recent_messages, user_text)
|
||||
|
||||
triggered_rule = check_escalation_rules(
|
||||
agent=agent,
|
||||
message_text=user_text,
|
||||
conversation_metadata=conversation_metadata,
|
||||
natural_lang_enabled=getattr(agent, "natural_language_escalation", False),
|
||||
)
|
||||
|
||||
if triggered_rule and getattr(agent, "escalation_assignee", None):
|
||||
escalation_redis = aioredis.from_url(settings.redis_url)
|
||||
try:
|
||||
response_text = await escalate_to_human(
|
||||
tenant_id=msg.tenant_id,
|
||||
agent=agent,
|
||||
thread_id=thread_key,
|
||||
trigger_reason=triggered_rule.get("condition", "rule triggered"),
|
||||
recent_messages=recent_messages,
|
||||
assignee_slack_user_id=agent.escalation_assignee,
|
||||
bot_token=slack_bot_token,
|
||||
redis=escalation_redis,
|
||||
audit_logger=audit_logger,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id_str,
|
||||
)
|
||||
finally:
|
||||
await escalation_redis.aclose()
|
||||
|
||||
# Check if the response is a tool confirmation request
|
||||
# The confirmation message template starts with a specific prefix
|
||||
is_confirmation_request = response_text.startswith("This action requires your approval")
|
||||
@@ -453,14 +552,8 @@ async def _process_message(
|
||||
len(relevant_context),
|
||||
)
|
||||
|
||||
# Replace the "Thinking..." placeholder with the real response
|
||||
if placeholder_ts and channel_id:
|
||||
await _update_slack_placeholder(
|
||||
bot_token=slack_bot_token,
|
||||
channel_id=channel_id,
|
||||
placeholder_ts=placeholder_ts,
|
||||
text=response_text,
|
||||
)
|
||||
# Send response via channel-aware routing
|
||||
await _send_response(msg.channel, response_text, response_extras)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Memory persistence (after LLM response)
|
||||
@@ -489,6 +582,74 @@ async def _process_message(
|
||||
}
|
||||
|
||||
|
||||
def _build_response_extras(
|
||||
channel: Any,
|
||||
extras: dict[str, Any],
|
||||
slack_bot_token: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the response_extras dict for channel-aware outbound delivery.
|
||||
|
||||
For Slack: injects slack_bot_token into extras["bot_token"] so _send_response
|
||||
can use it for chat.update calls.
|
||||
For WhatsApp: extras already contain phone_number_id, bot_token (access_token),
|
||||
and wa_id — no transformation needed.
|
||||
|
||||
Args:
|
||||
channel: Channel name from KonstructMessage.channel.
|
||||
extras: Incoming extras from handle_message.
|
||||
slack_bot_token: Bot token loaded from DB channel_connections.
|
||||
|
||||
Returns:
|
||||
Dict suitable for passing to _send_response.
|
||||
"""
|
||||
channel_str = str(channel) if channel else ""
|
||||
if channel_str == "slack":
|
||||
return {
|
||||
"bot_token": slack_bot_token,
|
||||
"channel_id": extras.get("channel_id", "") or "",
|
||||
"placeholder_ts": extras.get("placeholder_ts", "") or "",
|
||||
}
|
||||
elif channel_str == "whatsapp":
|
||||
return {
|
||||
"phone_number_id": extras.get("phone_number_id", "") or "",
|
||||
"bot_token": extras.get("bot_token", "") or "",
|
||||
"wa_id": extras.get("wa_id", "") or "",
|
||||
}
|
||||
else:
|
||||
return dict(extras)
|
||||
|
||||
|
||||
def _build_conversation_metadata(
|
||||
recent_messages: list[dict],
|
||||
current_text: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build conversation metadata dict for escalation rule evaluation.
|
||||
|
||||
Scans recent messages and current_text for billing keywords and counts occurrences.
|
||||
Returns a dict with:
|
||||
- "billing_dispute" (bool): True if any billing keyword found in any message
|
||||
- "attempts" (int): count of messages containing billing keywords
|
||||
|
||||
This is the v1 keyword-based metadata detection (per STATE.md decisions).
|
||||
|
||||
Args:
|
||||
recent_messages: List of {"role", "content"} dicts from Redis sliding window.
|
||||
current_text: The user's current message text.
|
||||
|
||||
Returns:
|
||||
Dict with billing_dispute and attempts keys.
|
||||
"""
|
||||
billing_keywords = {"billing", "invoice", "charge", "refund", "payment", "subscription"}
|
||||
all_texts = [m.get("content", "") for m in recent_messages] + [current_text]
|
||||
billing_count = sum(1 for t in all_texts if any(kw in t.lower() for kw in billing_keywords))
|
||||
return {
|
||||
"billing_dispute": billing_count > 0,
|
||||
"attempts": billing_count,
|
||||
}
|
||||
|
||||
|
||||
async def _execute_pending_tool(
|
||||
pending_data: dict,
|
||||
tenant_uuid: uuid.UUID,
|
||||
@@ -526,7 +687,7 @@ def _extract_tool_name_from_confirmation(confirmation_message: str) -> str:
|
||||
|
||||
|
||||
async def _send_response(
|
||||
channel: str,
|
||||
channel: Any,
|
||||
text: str,
|
||||
extras: dict,
|
||||
) -> None:
|
||||
@@ -545,7 +706,9 @@ async def _send_response(
|
||||
For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts``
|
||||
For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id``
|
||||
"""
|
||||
if channel == "slack":
|
||||
channel_str = str(channel) if channel else ""
|
||||
|
||||
if channel_str == "slack":
|
||||
bot_token: str = extras.get("bot_token", "") or ""
|
||||
channel_id: str = extras.get("channel_id", "") or ""
|
||||
placeholder_ts: str = extras.get("placeholder_ts", "") or ""
|
||||
@@ -563,7 +726,7 @@ async def _send_response(
|
||||
text=text,
|
||||
)
|
||||
|
||||
elif channel == "whatsapp":
|
||||
elif channel_str == "whatsapp":
|
||||
phone_number_id: str = extras.get("phone_number_id", "") or ""
|
||||
access_token: str = extras.get("bot_token", "") or ""
|
||||
wa_id: str = extras.get("wa_id", "") or ""
|
||||
|
||||
@@ -15,7 +15,6 @@ Tests verify:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -24,7 +23,7 @@ import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# Shared test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -93,6 +92,89 @@ def make_message_data(
|
||||
return data
|
||||
|
||||
|
||||
def make_fake_redis(escalated: bool = False) -> AsyncMock:
|
||||
"""Build a fakeredis-like mock for the async redis client."""
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
|
||||
redis_mock.delete = AsyncMock()
|
||||
redis_mock.set = AsyncMock()
|
||||
redis_mock.setex = AsyncMock()
|
||||
redis_mock.aclose = AsyncMock()
|
||||
return redis_mock
|
||||
|
||||
|
||||
def make_session_factory_mock(agent: Any) -> MagicMock:
|
||||
"""
|
||||
Build a mock for async_session_factory that returns the given agent.
|
||||
|
||||
The returned mock is callable and its return value acts as an async
|
||||
context manager that yields a session whose execute() returns the agent.
|
||||
"""
|
||||
# Build result chain: session.execute() -> result.scalars().first() -> agent
|
||||
scalars_mock = MagicMock()
|
||||
scalars_mock.first.return_value = agent
|
||||
result_mock = MagicMock()
|
||||
result_mock.scalars.return_value = scalars_mock
|
||||
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock(return_value=result_mock)
|
||||
|
||||
# Make session_factory() return an async context manager
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__ = AsyncMock(return_value=session)
|
||||
cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
factory = MagicMock()
|
||||
factory.return_value = cm
|
||||
return factory
|
||||
|
||||
|
||||
def make_process_message_msg(
|
||||
channel: str = "slack",
|
||||
tenant_id: str | None = None,
|
||||
user_id: str = "U123",
|
||||
text: str = "hello",
|
||||
thread_id: str = "thread_001",
|
||||
) -> Any:
|
||||
"""Build a KonstructMessage for use with _process_message tests."""
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
tid = tenant_id or str(uuid.uuid4())
|
||||
data: dict[str, Any] = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"tenant_id": tid,
|
||||
"channel": channel,
|
||||
"channel_metadata": {},
|
||||
"sender": {"user_id": user_id, "display_name": "Test User"},
|
||||
"content": {"text": text, "attachments": []},
|
||||
"timestamp": "2026-01-01T00:00:00Z",
|
||||
"thread_id": thread_id,
|
||||
"reply_to": None,
|
||||
"context": {},
|
||||
}
|
||||
return KonstructMessage.model_validate(data)
|
||||
|
||||
|
||||
def _standard_process_patches(
|
||||
agent: Any,
|
||||
escalated: bool = False,
|
||||
run_agent_return: str = "Agent reply",
|
||||
escalation_rule: dict | None = None,
|
||||
escalate_return: str = "I've brought in a team member",
|
||||
) -> dict:
|
||||
"""
|
||||
Return a dict of patch targets and their corresponding mocks.
|
||||
Callers use this with contextlib.ExitStack or individual patch() calls.
|
||||
"""
|
||||
return {
|
||||
"redis": make_fake_redis(escalated=escalated),
|
||||
"session_factory": make_session_factory_mock(agent),
|
||||
"run_agent_return": run_agent_return,
|
||||
"escalation_rule": escalation_rule,
|
||||
"escalate_return": escalate_return,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 1 Tests — handle_message extra-field extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -242,44 +324,13 @@ class TestSendResponseRouting:
|
||||
class TestProcessMessageUsesSSendResponse:
|
||||
"""_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""
|
||||
|
||||
def _make_fake_redis(self, escalated: bool = False) -> AsyncMock:
|
||||
"""Build a fakeredis-like mock for the async client."""
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
|
||||
redis_mock.delete = AsyncMock()
|
||||
redis_mock.set = AsyncMock()
|
||||
redis_mock.setex = AsyncMock()
|
||||
redis_mock.aclose = AsyncMock()
|
||||
return redis_mock
|
||||
|
||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
||||
"""Build a mock async DB session that returns the given agent."""
|
||||
session = AsyncMock()
|
||||
# Scalars mock
|
||||
scalars_mock = MagicMock()
|
||||
scalars_mock.first.return_value = agent
|
||||
result_mock = AsyncMock()
|
||||
result_mock.scalars.return_value = scalars_mock
|
||||
session.execute = AsyncMock(return_value=result_mock)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_calls_send_response_not_update_slack(self) -> None:
|
||||
"""_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent()
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
msg = make_process_message_msg(channel="slack")
|
||||
extras = {
|
||||
"placeholder_ts": "1234.5678",
|
||||
"channel_id": "C123",
|
||||
@@ -287,11 +338,14 @@ class TestProcessMessageUsesSSendResponse:
|
||||
"bot_token": "",
|
||||
}
|
||||
|
||||
redis_mock = make_fake_redis(escalated=False)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
|
||||
patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -306,15 +360,8 @@ class TestProcessMessageUsesSSendResponse:
|
||||
patch("orchestrator.tasks.AuditLogger"),
|
||||
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
||||
):
|
||||
# Setup redis mock to return None for pre-check (not escalated)
|
||||
redis_mock = self._make_fake_redis(escalated=False)
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
|
||||
# Setup session to return our agent
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await _process_message(msg, extras=extras)
|
||||
|
||||
# _send_response MUST be called for delivery
|
||||
@@ -326,37 +373,30 @@ class TestProcessMessageUsesSSendResponse:
|
||||
async def test_whatsapp_extras_passed_to_send_response(self) -> None:
|
||||
"""_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent()
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(
|
||||
channel="whatsapp",
|
||||
tenant_id=tenant_id,
|
||||
user_id="15551234567",
|
||||
)
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
# user_id = "15551234567" -> becomes wa_id
|
||||
msg = make_process_message_msg(channel="whatsapp", user_id="15551234567")
|
||||
extras = {
|
||||
"placeholder_ts": "",
|
||||
"channel_id": "",
|
||||
"phone_number_id": "9876543210",
|
||||
"bot_token": "EAAaccess",
|
||||
"wa_id": "15551234567",
|
||||
}
|
||||
|
||||
captured_extras: list[dict] = []
|
||||
|
||||
async def capture_send(channel: str, text: str, ext: dict) -> None:
|
||||
captured_extras.append({"channel": channel, "text": text, "extras": ext})
|
||||
async def capture_send(channel: Any, text: str, ext: dict) -> None:
|
||||
captured_extras.append({"channel": str(channel), "text": text, "extras": ext})
|
||||
|
||||
redis_mock = make_fake_redis(escalated=False)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", side_effect=capture_send),
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -371,15 +411,7 @@ class TestProcessMessageUsesSSendResponse:
|
||||
patch("orchestrator.tasks.AuditLogger"),
|
||||
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
||||
):
|
||||
redis_mock = self._make_fake_redis(escalated=False)
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
# Configure context manager for redis client used in finally blocks
|
||||
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
|
||||
redis_mock.__aexit__ = AsyncMock()
|
||||
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await _process_message(msg, extras=extras)
|
||||
|
||||
@@ -389,7 +421,7 @@ class TestProcessMessageUsesSSendResponse:
|
||||
wa_extras = wa_calls[0]["extras"]
|
||||
assert wa_extras.get("phone_number_id") == "9876543210"
|
||||
assert wa_extras.get("bot_token") == "EAAaccess"
|
||||
# wa_id must come from sender.user_id
|
||||
# wa_id must come from extras (which was populated from sender.user_id in handle_message)
|
||||
assert wa_extras.get("wa_id") == "15551234567"
|
||||
|
||||
|
||||
@@ -401,42 +433,22 @@ class TestProcessMessageUsesSSendResponse:
|
||||
class TestEscalationPreCheck:
|
||||
"""When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""
|
||||
|
||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
||||
session = AsyncMock()
|
||||
scalars_mock = MagicMock()
|
||||
scalars_mock.first.return_value = agent
|
||||
result_mock = AsyncMock()
|
||||
result_mock.scalars.return_value = scalars_mock
|
||||
session.execute = AsyncMock(return_value=result_mock)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalated_conversation_skips_run_agent(self) -> None:
|
||||
"""When Redis shows 'escalated', run_agent must NOT be called."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent()
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
msg = make_process_message_msg(channel="slack")
|
||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=b"escalated")
|
||||
redis_mock.aclose = AsyncMock()
|
||||
redis_mock = make_fake_redis(escalated=True)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -444,12 +456,6 @@ class TestEscalationPreCheck:
|
||||
patch("orchestrator.tasks.AuditLogger"),
|
||||
):
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
|
||||
redis_mock.__aexit__ = AsyncMock()
|
||||
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
result = await _process_message(msg, extras=extras)
|
||||
|
||||
@@ -468,42 +474,24 @@ class TestEscalationPreCheck:
|
||||
class TestEscalationPostCheck:
|
||||
"""check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""
|
||||
|
||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
||||
session = AsyncMock()
|
||||
scalars_mock = MagicMock()
|
||||
scalars_mock.first.return_value = agent
|
||||
result_mock = AsyncMock()
|
||||
result_mock.scalars.return_value = scalars_mock
|
||||
session.execute = AsyncMock(return_value=result_mock)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_escalation_rules_called_after_run_agent(self) -> None:
|
||||
"""check_escalation_rules must be called after run_agent returns."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}])
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
agent = make_agent(
|
||||
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]
|
||||
)
|
||||
msg = make_process_message_msg(channel="slack")
|
||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=None)
|
||||
redis_mock.aclose = AsyncMock()
|
||||
redis_mock = make_fake_redis(escalated=False)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -520,10 +508,6 @@ class TestEscalationPostCheck:
|
||||
):
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await _process_message(msg, extras=extras)
|
||||
|
||||
# check_escalation_rules must have been called
|
||||
@@ -533,32 +517,22 @@ class TestEscalationPostCheck:
|
||||
async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
|
||||
"""When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent(
|
||||
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
||||
escalation_assignee="U_MANAGER",
|
||||
)
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again")
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
msg = make_process_message_msg(channel="slack", text="refund issue again")
|
||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||
|
||||
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
||||
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=None)
|
||||
redis_mock.aclose = AsyncMock()
|
||||
redis_mock = make_fake_redis(escalated=False)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -572,51 +546,41 @@ class TestEscalationPostCheck:
|
||||
patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
|
||||
patch("orchestrator.tasks.AuditLogger"),
|
||||
patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
|
||||
patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate,
|
||||
patch(
|
||||
"orchestrator.tasks.escalate_to_human",
|
||||
new_callable=AsyncMock,
|
||||
return_value="I've brought in a team member",
|
||||
) as mock_escalate,
|
||||
):
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
result = await _process_message(msg, extras=extras)
|
||||
|
||||
# escalate_to_human must be called
|
||||
mock_escalate.assert_called_once()
|
||||
# Response should be the escalation confirmation
|
||||
assert "team member" in result["response"] or mock_escalate.return_value in result["response"]
|
||||
assert "team member" in result["response"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
|
||||
"""When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
|
||||
from orchestrator.tasks import _process_message
|
||||
from shared.models.message import KonstructMessage
|
||||
|
||||
agent = make_agent(
|
||||
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
||||
escalation_assignee=None,
|
||||
)
|
||||
tenant_id = str(uuid.uuid4())
|
||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
||||
msg_data.pop("placeholder_ts", None)
|
||||
msg_data.pop("channel_id", None)
|
||||
msg_data.pop("phone_number_id", None)
|
||||
msg_data.pop("bot_token", None)
|
||||
msg = KonstructMessage.model_validate(msg_data)
|
||||
|
||||
msg = make_process_message_msg(channel="slack")
|
||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||
|
||||
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
||||
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=None)
|
||||
redis_mock.aclose = AsyncMock()
|
||||
redis_mock = make_fake_redis(escalated=False)
|
||||
sf_mock = make_session_factory_mock(agent)
|
||||
|
||||
with (
|
||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
||||
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||
patch("orchestrator.tasks.configure_rls_hook"),
|
||||
patch("orchestrator.tasks.current_tenant_id"),
|
||||
patch("orchestrator.tasks.engine"),
|
||||
@@ -634,10 +598,6 @@ class TestEscalationPostCheck:
|
||||
):
|
||||
mock_aioredis.from_url.return_value = redis_mock
|
||||
|
||||
session_mock = self._make_session_mock(agent)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await _process_message(msg, extras=extras)
|
||||
|
||||
# escalate_to_human must NOT be called — no assignee configured
|
||||
|
||||
Reference in New Issue
Block a user