feat(02-06): re-wire escalation and WhatsApp outbound routing in pipeline
- Move key imports to module level in tasks.py for testability and clarity
- Pop WhatsApp extras (phone_number_id, bot_token) in handle_message before model_validate
- Build unified extras dict and extract wa_id from sender.user_id
- Change _process_message signature to accept extras dict
- Add _build_response_extras() helper for channel-aware extras assembly
- Replace all _update_slack_placeholder calls in _process_message with _send_response()
- Add escalation pre-check: skip LLM when Redis escalation_status_key == 'escalated'
- Add escalation post-check: check_escalation_rules after run_agent; call escalate_to_human
when rule matches and agent.escalation_assignee is set
- Add _build_conversation_metadata() helper (billing keyword v1 detection)
- Add channel parameter to build_system_prompt(), build_messages_with_memory(),
build_messages_with_media() for WhatsApp tier-2 business-function scoping
- WhatsApp scoping appends 'You only handle: {topics}' when tool_assignments non-empty
- Pass msg.channel to build_messages_with_memory() in _process_message
- All 26 new tests pass; all existing escalation/WhatsApp tests pass (no regressions)
This commit is contained in:
@@ -133,7 +133,7 @@ def generate_presigned_url(storage_key: str, expiry: int = 3600) -> str:
|
|||||||
return presigned_url
|
return presigned_url
|
||||||
|
|
||||||
|
|
||||||
def build_system_prompt(agent: Agent) -> str:
|
def build_system_prompt(agent: Agent, channel: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Assemble the full system prompt for an agent.
|
Assemble the full system prompt for an agent.
|
||||||
|
|
||||||
@@ -142,9 +142,20 @@ def build_system_prompt(agent: Agent) -> str:
|
|||||||
- Identity section: name + role
|
- Identity section: name + role
|
||||||
- Persona section (if agent.persona is non-empty)
|
- Persona section (if agent.persona is non-empty)
|
||||||
- AI transparency clause (always appended)
|
- AI transparency clause (always appended)
|
||||||
|
- WhatsApp tier-2 scoping (only when channel == "whatsapp" and tool_assignments set)
|
||||||
|
|
||||||
|
WhatsApp Tier-2 Business-Function Scoping:
|
||||||
|
Per Meta 2026 policy, agents responding on WhatsApp must constrain themselves
|
||||||
|
to declared business functions. When channel == "whatsapp" and the agent has
|
||||||
|
non-empty tool_assignments, a scoping clause is appended that instructs the LLM
|
||||||
|
to handle only the declared topics and redirect off-topic requests.
|
||||||
|
This is the tier-2 gate (borderline messages reach the LLM with this constraint;
|
||||||
|
clearly off-topic messages are caught by the tier-1 keyword gate in the gateway).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent: ORM Agent instance.
|
agent: ORM Agent instance.
|
||||||
|
channel: Channel name (e.g. "slack", "whatsapp"). Used for tier-2 scoping.
|
||||||
|
Default empty string means no channel-specific scoping.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A complete system prompt string ready to pass to the LLM.
|
A complete system prompt string ready to pass to the LLM.
|
||||||
@@ -167,6 +178,17 @@ def build_system_prompt(agent: Agent) -> str:
|
|||||||
"If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
|
"If asked directly whether you are an AI, always respond honestly that you are an AI assistant."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 5. WhatsApp tier-2 scoping — constrain LLM to declared business functions
|
||||||
|
if channel == "whatsapp":
|
||||||
|
functions: list[str] = getattr(agent, "tool_assignments", []) or []
|
||||||
|
if functions:
|
||||||
|
topics = ", ".join(functions)
|
||||||
|
parts.append(
|
||||||
|
f"You are responding on WhatsApp. You only handle: {topics}. "
|
||||||
|
f"If the user asks about something outside these topics, "
|
||||||
|
f"politely redirect them to the allowed topics."
|
||||||
|
)
|
||||||
|
|
||||||
return "\n\n".join(parts)
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
@@ -204,12 +226,14 @@ def build_messages_with_memory(
|
|||||||
current_message: str,
|
current_message: str,
|
||||||
recent_messages: list[dict],
|
recent_messages: list[dict],
|
||||||
relevant_context: list[str],
|
relevant_context: list[str],
|
||||||
|
channel: str = "",
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Build an LLM messages array enriched with two-layer memory.
|
Build an LLM messages array enriched with two-layer memory.
|
||||||
|
|
||||||
Structure (in order):
|
Structure (in order):
|
||||||
1. System message — agent identity, persona, AI transparency clause
|
1. System message — agent identity, persona, AI transparency clause,
|
||||||
|
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
|
||||||
2. System message — long-term context from pgvector (ONLY if non-empty)
|
2. System message — long-term context from pgvector (ONLY if non-empty)
|
||||||
Injected as a system message before the sliding window so the LLM
|
Injected as a system message before the sliding window so the LLM
|
||||||
has relevant background without it appearing in the conversation.
|
has relevant background without it appearing in the conversation.
|
||||||
@@ -226,11 +250,13 @@ def build_messages_with_memory(
|
|||||||
from Redis sliding window (oldest first).
|
from Redis sliding window (oldest first).
|
||||||
relevant_context: Long-term memory — list of content strings from
|
relevant_context: Long-term memory — list of content strings from
|
||||||
pgvector similarity search (most relevant first).
|
pgvector similarity search (most relevant first).
|
||||||
|
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
|
||||||
|
Default empty string means no channel-specific scoping.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of message dicts suitable for an OpenAI-compatible API call.
|
List of message dicts suitable for an OpenAI-compatible API call.
|
||||||
"""
|
"""
|
||||||
system_prompt = build_system_prompt(agent)
|
system_prompt = build_system_prompt(agent, channel=channel)
|
||||||
messages: list[dict] = [{"role": "system", "content": system_prompt}]
|
messages: list[dict] = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
# Inject long-term pgvector context as a system message BEFORE sliding window
|
# Inject long-term pgvector context as a system message BEFORE sliding window
|
||||||
@@ -256,6 +282,7 @@ def build_messages_with_media(
|
|||||||
media_attachments: list[MediaAttachment],
|
media_attachments: list[MediaAttachment],
|
||||||
recent_messages: list[dict],
|
recent_messages: list[dict],
|
||||||
relevant_context: list[str],
|
relevant_context: list[str],
|
||||||
|
channel: str = "",
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Build an LLM messages array with memory enrichment AND multimodal media injection.
|
Build an LLM messages array with memory enrichment AND multimodal media injection.
|
||||||
@@ -277,7 +304,8 @@ def build_messages_with_media(
|
|||||||
gracefully rather than raising an error.
|
gracefully rather than raising an error.
|
||||||
|
|
||||||
Structure (in order):
|
Structure (in order):
|
||||||
1. System message — agent identity, persona, AI transparency clause
|
1. System message — agent identity, persona, AI transparency clause,
|
||||||
|
and optional WhatsApp tier-2 scoping (when channel == "whatsapp")
|
||||||
2. System message — long-term context (ONLY if non-empty)
|
2. System message — long-term context (ONLY if non-empty)
|
||||||
3. Sliding window messages — recent history
|
3. Sliding window messages — recent history
|
||||||
4. Current user message (plain string or multipart content list)
|
4. Current user message (plain string or multipart content list)
|
||||||
@@ -289,6 +317,8 @@ def build_messages_with_media(
|
|||||||
Empty list produces the same output as build_messages_with_memory().
|
Empty list produces the same output as build_messages_with_memory().
|
||||||
recent_messages: Short-term memory — list of {"role", "content"} dicts.
|
recent_messages: Short-term memory — list of {"role", "content"} dicts.
|
||||||
relevant_context: Long-term memory — list of content strings from pgvector.
|
relevant_context: Long-term memory — list of content strings from pgvector.
|
||||||
|
channel: Channel name for tier-2 scoping (e.g. "whatsapp").
|
||||||
|
Default empty string means no channel-specific scoping.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of message dicts suitable for an OpenAI/LiteLLM-compatible API call.
|
List of message dicts suitable for an OpenAI/LiteLLM-compatible API call.
|
||||||
@@ -301,6 +331,7 @@ def build_messages_with_media(
|
|||||||
current_message=current_message,
|
current_message=current_message,
|
||||||
recent_messages=recent_messages,
|
recent_messages=recent_messages,
|
||||||
relevant_context=relevant_context,
|
relevant_context=relevant_context,
|
||||||
|
channel=channel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If no media attachments, return the base messages unchanged
|
# If no media attachments, return the base messages unchanged
|
||||||
|
|||||||
@@ -35,6 +35,23 @@ Pending tool confirmation:
|
|||||||
On the next user message, if a pending confirmation exists:
|
On the next user message, if a pending confirmation exists:
|
||||||
- "yes" → execute the pending tool and continue
|
- "yes" → execute the pending tool and continue
|
||||||
- "no" / anything else → cancel and inform the user
|
- "no" / anything else → cancel and inform the user
|
||||||
|
|
||||||
|
Escalation pipeline (Phase 2 Plan 06):
|
||||||
|
Pre-check (before LLM call):
|
||||||
|
If Redis escalation_status_key == "escalated", return early assistant-mode reply.
|
||||||
|
This prevents the LLM from being called when a human has already taken over.
|
||||||
|
|
||||||
|
Post-check (after LLM response):
|
||||||
|
check_escalation_rules() evaluates configured rules against conversation metadata.
|
||||||
|
If a rule matches AND agent.escalation_assignee is set, escalate_to_human() is
|
||||||
|
called and its return value replaces the LLM response.
|
||||||
|
|
||||||
|
Outbound routing (Phase 2 Plan 06):
|
||||||
|
All response delivery goes through _send_response() which routes to:
|
||||||
|
- Slack: _update_slack_placeholder() via chat.update
|
||||||
|
- WhatsApp: send_whatsapp_message() via Meta Cloud API
|
||||||
|
handle_message now pops WhatsApp extras (phone_number_id, bot_token) and
|
||||||
|
passes them through to _process_message via the extras dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -43,10 +60,25 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
from gateway.channels.whatsapp import send_whatsapp_message
|
from gateway.channels.whatsapp import send_whatsapp_message
|
||||||
|
from orchestrator.agents.builder import build_messages_with_memory
|
||||||
|
from orchestrator.agents.runner import run_agent
|
||||||
|
from orchestrator.audit.logger import AuditLogger
|
||||||
|
from orchestrator.escalation.handler import check_escalation_rules, escalate_to_human
|
||||||
from orchestrator.main import app
|
from orchestrator.main import app
|
||||||
|
from orchestrator.memory.embedder import embed_text
|
||||||
|
from orchestrator.memory.long_term import retrieve_relevant
|
||||||
|
from orchestrator.memory.short_term import append_message, get_recent_messages
|
||||||
|
from orchestrator.tools.registry import get_tools_for_agent
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import async_session_factory, engine
|
||||||
from shared.models.message import KonstructMessage
|
from shared.models.message import KonstructMessage
|
||||||
|
from shared.redis_keys import escalation_status_key
|
||||||
|
from shared.rls import configure_rls_hook, current_tenant_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -152,23 +184,28 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
|||||||
by the Channel Gateway after tenant resolution completes.
|
by the Channel Gateway after tenant resolution completes.
|
||||||
|
|
||||||
The ``message_data`` dict MAY contain extra keys beyond KonstructMessage
|
The ``message_data`` dict MAY contain extra keys beyond KonstructMessage
|
||||||
fields. Specifically, the Slack handler injects:
|
fields. Specifically:
|
||||||
- ``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder
|
- Slack handler injects:
|
||||||
- ``channel_id``: Slack channel ID where the response should be posted
|
``placeholder_ts``: Slack message timestamp of the "Thinking..." placeholder
|
||||||
|
``channel_id``: Slack channel ID where the response should be posted
|
||||||
|
- WhatsApp gateway injects:
|
||||||
|
``phone_number_id``: WhatsApp phone number ID for outbound messaging
|
||||||
|
``bot_token``: WhatsApp access_token (injected as bot_token by gateway)
|
||||||
|
|
||||||
These are extracted before KonstructMessage validation and used to update
|
These are extracted before KonstructMessage validation and used to route
|
||||||
the placeholder with the real LLM response via chat.update.
|
outbound responses via _send_response().
|
||||||
|
|
||||||
Pipeline:
|
Pipeline:
|
||||||
1. Extract Slack reply metadata (placeholder_ts, channel_id) if present
|
1. Extract channel reply metadata (Slack + WhatsApp extras) if present
|
||||||
2. Deserialize message_data -> KonstructMessage
|
2. Deserialize message_data -> KonstructMessage
|
||||||
3. Run async agent pipeline via asyncio.run()
|
3. Extract wa_id from sender.user_id for WhatsApp messages
|
||||||
4. If Slack metadata present: call chat.update to replace placeholder
|
4. Build extras dict for channel-aware outbound routing
|
||||||
5. Return response dict
|
5. Run async agent pipeline via asyncio.run()
|
||||||
|
6. Return response dict
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_data: JSON-serializable dict. Must contain KonstructMessage
|
message_data: JSON-serializable dict. Must contain KonstructMessage
|
||||||
fields plus optional ``placeholder_ts`` and ``channel_id``.
|
fields plus optional channel-specific extras.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with keys:
|
Dict with keys:
|
||||||
@@ -181,20 +218,39 @@ def handle_message(self, message_data: dict) -> dict: # type: ignore[no-untyped
|
|||||||
placeholder_ts: str = message_data.pop("placeholder_ts", "") or ""
|
placeholder_ts: str = message_data.pop("placeholder_ts", "") or ""
|
||||||
channel_id: str = message_data.pop("channel_id", "") or ""
|
channel_id: str = message_data.pop("channel_id", "") or ""
|
||||||
|
|
||||||
|
# Extract WhatsApp-specific extras before model validation
|
||||||
|
# The WhatsApp gateway injects these alongside the normalized KonstructMessage fields
|
||||||
|
phone_number_id: str = message_data.pop("phone_number_id", "") or ""
|
||||||
|
bot_token: str = message_data.pop("bot_token", "") or ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = KonstructMessage.model_validate(message_data)
|
msg = KonstructMessage.model_validate(message_data)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
|
logger.exception("Failed to deserialize KonstructMessage: %s", message_data)
|
||||||
raise self.retry(exc=exc)
|
raise self.retry(exc=exc)
|
||||||
|
|
||||||
result = asyncio.run(_process_message(msg, placeholder_ts=placeholder_ts, channel_id=channel_id))
|
# Extract wa_id from sender.user_id — WhatsApp normalizer sets sender.user_id
|
||||||
|
# to the wa_id (recipient phone number). This must happen AFTER model_validate.
|
||||||
|
wa_id: str = ""
|
||||||
|
if msg.channel == "whatsapp" and msg.sender and msg.sender.user_id:
|
||||||
|
wa_id = msg.sender.user_id
|
||||||
|
|
||||||
|
# Build the unified extras dict for channel-aware outbound routing
|
||||||
|
extras: dict[str, Any] = {
|
||||||
|
"placeholder_ts": placeholder_ts,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"phone_number_id": phone_number_id,
|
||||||
|
"bot_token": bot_token,
|
||||||
|
"wa_id": wa_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = asyncio.run(_process_message(msg, extras=extras))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
msg: KonstructMessage,
|
msg: KonstructMessage,
|
||||||
placeholder_ts: str = "",
|
extras: dict[str, Any] | None = None,
|
||||||
channel_id: str = "",
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool.
|
Async agent pipeline — load agent config, build memory-enriched prompt, call LLM pool.
|
||||||
@@ -215,24 +271,26 @@ async def _process_message(
|
|||||||
- Tool-call loop runs inside run_agent() — no separate Celery tasks
|
- Tool-call loop runs inside run_agent() — no separate Celery tasks
|
||||||
- If run_agent returns a confirmation message: store pending action in Redis
|
- If run_agent returns a confirmation message: store pending action in Redis
|
||||||
|
|
||||||
|
Escalation pipeline (Phase 2 Plan 06 additions):
|
||||||
|
Pre-check: if Redis escalation_status_key == "escalated", return assistant-mode reply
|
||||||
|
Post-check: check_escalation_rules after LLM response; if triggered, escalate_to_human
|
||||||
|
|
||||||
|
Outbound routing (Phase 2 Plan 06 additions):
|
||||||
|
All response delivery goes through _send_response() — never direct _update_slack_placeholder.
|
||||||
|
extras dict is passed through to _send_response for channel-aware routing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The deserialized KonstructMessage.
|
msg: The deserialized KonstructMessage.
|
||||||
placeholder_ts: Slack message timestamp of the "Thinking..." placeholder.
|
extras: Channel-specific routing metadata. For Slack: placeholder_ts, channel_id,
|
||||||
channel_id: Slack channel ID for the chat.update call.
|
bot_token. For WhatsApp: phone_number_id, bot_token, wa_id.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with message_id, response, and tenant_id.
|
Dict with message_id, response, and tenant_id.
|
||||||
"""
|
"""
|
||||||
from orchestrator.agents.builder import build_messages_with_memory
|
|
||||||
from orchestrator.agents.runner import run_agent
|
|
||||||
from orchestrator.audit.logger import AuditLogger
|
|
||||||
from orchestrator.memory.embedder import embed_text
|
|
||||||
from orchestrator.memory.long_term import retrieve_relevant
|
|
||||||
from orchestrator.memory.short_term import append_message, get_recent_messages
|
|
||||||
from orchestrator.tools.registry import get_tools_for_agent
|
|
||||||
from shared.db import async_session_factory, engine
|
|
||||||
from shared.models.tenant import Agent
|
from shared.models.tenant import Agent
|
||||||
from shared.rls import configure_rls_hook, current_tenant_id
|
|
||||||
|
if extras is None:
|
||||||
|
extras = {}
|
||||||
|
|
||||||
if msg.tenant_id is None:
|
if msg.tenant_id is None:
|
||||||
logger.warning("Message %s has no tenant_id — cannot process", msg.id)
|
logger.warning("Message %s has no tenant_id — cannot process", msg.id)
|
||||||
@@ -265,8 +323,10 @@ async def _process_message(
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
agent = result.scalars().first()
|
agent = result.scalars().first()
|
||||||
|
|
||||||
# Load the bot token for this tenant from channel_connections config
|
# Load the Slack bot token for this tenant from channel_connections config.
|
||||||
if agent is not None and placeholder_ts and channel_id:
|
# This is needed for escalation DM delivery even on WhatsApp messages —
|
||||||
|
# the escalation handler always sends via Slack DM to the assignee.
|
||||||
|
if agent is not None and (extras.get("placeholder_ts") and extras.get("channel_id")):
|
||||||
from shared.models.tenant import ChannelConnection, ChannelTypeEnum
|
from shared.models.tenant import ChannelConnection, ChannelTypeEnum
|
||||||
|
|
||||||
conn_stmt = (
|
conn_stmt = (
|
||||||
@@ -290,13 +350,9 @@ async def _process_message(
|
|||||||
msg.id,
|
msg.id,
|
||||||
)
|
)
|
||||||
no_agent_response = "No active agent is configured for your workspace. Please contact your administrator."
|
no_agent_response = "No active agent is configured for your workspace. Please contact your administrator."
|
||||||
if placeholder_ts and channel_id:
|
# Build response_extras for channel-aware delivery
|
||||||
await _update_slack_placeholder(
|
response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
|
||||||
bot_token=slack_bot_token,
|
await _send_response(msg.channel, no_agent_response, response_extras)
|
||||||
channel_id=channel_id,
|
|
||||||
placeholder_ts=placeholder_ts,
|
|
||||||
text=no_agent_response,
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"message_id": msg.id,
|
"message_id": msg.id,
|
||||||
"response": no_agent_response,
|
"response": no_agent_response,
|
||||||
@@ -318,14 +374,30 @@ async def _process_message(
|
|||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
audit_logger = AuditLogger(session_factory=async_session_factory)
|
audit_logger = AuditLogger(session_factory=async_session_factory)
|
||||||
|
|
||||||
|
# Build response_extras dict used for all outbound delivery in this pipeline run.
|
||||||
|
# For Slack: merges DB-loaded slack_bot_token with incoming extras.
|
||||||
|
# For WhatsApp: extras already contain phone_number_id, bot_token, wa_id.
|
||||||
|
response_extras = _build_response_extras(msg.channel, extras, slack_bot_token)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Escalation pre-check — if conversation already escalated, reply in assistant mode
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
thread_key = msg.thread_id or user_id
|
||||||
|
esc_key = escalation_status_key(msg.tenant_id, thread_key)
|
||||||
|
pre_check_redis = aioredis.from_url(settings.redis_url)
|
||||||
|
try:
|
||||||
|
esc_status = await pre_check_redis.get(esc_key)
|
||||||
|
finally:
|
||||||
|
await pre_check_redis.aclose()
|
||||||
|
|
||||||
|
if esc_status == b"escalated":
|
||||||
|
assistant_reply = "I've already connected you with a team member. They'll continue assisting you."
|
||||||
|
await _send_response(msg.channel, assistant_reply, response_extras)
|
||||||
|
return {"message_id": msg.id, "response": assistant_reply, "tenant_id": msg.tenant_id}
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# Pending tool confirmation check
|
# Pending tool confirmation check
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
redis_client = aioredis.from_url(settings.redis_url)
|
|
||||||
pending_confirm_key = _PENDING_TOOL_KEY.format(
|
pending_confirm_key = _PENDING_TOOL_KEY.format(
|
||||||
tenant_id=msg.tenant_id,
|
tenant_id=msg.tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -334,6 +406,7 @@ async def _process_message(
|
|||||||
response_text: str = ""
|
response_text: str = ""
|
||||||
handled_as_confirmation = False
|
handled_as_confirmation = False
|
||||||
|
|
||||||
|
redis_client = aioredis.from_url(settings.redis_url)
|
||||||
try:
|
try:
|
||||||
pending_raw = await redis_client.get(pending_confirm_key)
|
pending_raw = await redis_client.get(pending_confirm_key)
|
||||||
|
|
||||||
@@ -362,13 +435,7 @@ async def _process_message(
|
|||||||
await redis_client.aclose()
|
await redis_client.aclose()
|
||||||
|
|
||||||
if handled_as_confirmation:
|
if handled_as_confirmation:
|
||||||
if placeholder_ts and channel_id:
|
await _send_response(msg.channel, response_text, response_extras)
|
||||||
await _update_slack_placeholder(
|
|
||||||
bot_token=slack_bot_token,
|
|
||||||
channel_id=channel_id,
|
|
||||||
placeholder_ts=placeholder_ts,
|
|
||||||
text=response_text,
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"message_id": msg.id,
|
"message_id": msg.id,
|
||||||
"response": response_text,
|
"response": response_text,
|
||||||
@@ -412,6 +479,7 @@ async def _process_message(
|
|||||||
current_message=user_text,
|
current_message=user_text,
|
||||||
recent_messages=recent_messages,
|
recent_messages=recent_messages,
|
||||||
relevant_context=relevant_context,
|
relevant_context=relevant_context,
|
||||||
|
channel=str(msg.channel) if msg.channel else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build tool registry for this agent
|
# Build tool registry for this agent
|
||||||
@@ -428,6 +496,37 @@ async def _process_message(
|
|||||||
tool_registry=tool_registry if tool_registry else None,
|
tool_registry=tool_registry if tool_registry else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Escalation post-check — evaluate rules against conversation metadata
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
conversation_metadata = _build_conversation_metadata(recent_messages, user_text)
|
||||||
|
|
||||||
|
triggered_rule = check_escalation_rules(
|
||||||
|
agent=agent,
|
||||||
|
message_text=user_text,
|
||||||
|
conversation_metadata=conversation_metadata,
|
||||||
|
natural_lang_enabled=getattr(agent, "natural_language_escalation", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if triggered_rule and getattr(agent, "escalation_assignee", None):
|
||||||
|
escalation_redis = aioredis.from_url(settings.redis_url)
|
||||||
|
try:
|
||||||
|
response_text = await escalate_to_human(
|
||||||
|
tenant_id=msg.tenant_id,
|
||||||
|
agent=agent,
|
||||||
|
thread_id=thread_key,
|
||||||
|
trigger_reason=triggered_rule.get("condition", "rule triggered"),
|
||||||
|
recent_messages=recent_messages,
|
||||||
|
assignee_slack_user_id=agent.escalation_assignee,
|
||||||
|
bot_token=slack_bot_token,
|
||||||
|
redis=escalation_redis,
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=agent_id_str,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await escalation_redis.aclose()
|
||||||
|
|
||||||
# Check if the response is a tool confirmation request
|
# Check if the response is a tool confirmation request
|
||||||
# The confirmation message template starts with a specific prefix
|
# The confirmation message template starts with a specific prefix
|
||||||
is_confirmation_request = response_text.startswith("This action requires your approval")
|
is_confirmation_request = response_text.startswith("This action requires your approval")
|
||||||
@@ -453,14 +552,8 @@ async def _process_message(
|
|||||||
len(relevant_context),
|
len(relevant_context),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace the "Thinking..." placeholder with the real response
|
# Send response via channel-aware routing
|
||||||
if placeholder_ts and channel_id:
|
await _send_response(msg.channel, response_text, response_extras)
|
||||||
await _update_slack_placeholder(
|
|
||||||
bot_token=slack_bot_token,
|
|
||||||
channel_id=channel_id,
|
|
||||||
placeholder_ts=placeholder_ts,
|
|
||||||
text=response_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# Memory persistence (after LLM response)
|
# Memory persistence (after LLM response)
|
||||||
@@ -489,6 +582,74 @@ async def _process_message(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_response_extras(
|
||||||
|
channel: Any,
|
||||||
|
extras: dict[str, Any],
|
||||||
|
slack_bot_token: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build the response_extras dict for channel-aware outbound delivery.
|
||||||
|
|
||||||
|
For Slack: injects slack_bot_token into extras["bot_token"] so _send_response
|
||||||
|
can use it for chat.update calls.
|
||||||
|
For WhatsApp: extras already contain phone_number_id, bot_token (access_token),
|
||||||
|
and wa_id — no transformation needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: Channel name from KonstructMessage.channel.
|
||||||
|
extras: Incoming extras from handle_message.
|
||||||
|
slack_bot_token: Bot token loaded from DB channel_connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict suitable for passing to _send_response.
|
||||||
|
"""
|
||||||
|
channel_str = str(channel) if channel else ""
|
||||||
|
if channel_str == "slack":
|
||||||
|
return {
|
||||||
|
"bot_token": slack_bot_token,
|
||||||
|
"channel_id": extras.get("channel_id", "") or "",
|
||||||
|
"placeholder_ts": extras.get("placeholder_ts", "") or "",
|
||||||
|
}
|
||||||
|
elif channel_str == "whatsapp":
|
||||||
|
return {
|
||||||
|
"phone_number_id": extras.get("phone_number_id", "") or "",
|
||||||
|
"bot_token": extras.get("bot_token", "") or "",
|
||||||
|
"wa_id": extras.get("wa_id", "") or "",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return dict(extras)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_conversation_metadata(
|
||||||
|
recent_messages: list[dict],
|
||||||
|
current_text: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build conversation metadata dict for escalation rule evaluation.
|
||||||
|
|
||||||
|
Scans recent messages and current_text for billing keywords and counts occurrences.
|
||||||
|
Returns a dict with:
|
||||||
|
- "billing_dispute" (bool): True if any billing keyword found in any message
|
||||||
|
- "attempts" (int): count of messages containing billing keywords
|
||||||
|
|
||||||
|
This is the v1 keyword-based metadata detection (per STATE.md decisions).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recent_messages: List of {"role", "content"} dicts from Redis sliding window.
|
||||||
|
current_text: The user's current message text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with billing_dispute and attempts keys.
|
||||||
|
"""
|
||||||
|
billing_keywords = {"billing", "invoice", "charge", "refund", "payment", "subscription"}
|
||||||
|
all_texts = [m.get("content", "") for m in recent_messages] + [current_text]
|
||||||
|
billing_count = sum(1 for t in all_texts if any(kw in t.lower() for kw in billing_keywords))
|
||||||
|
return {
|
||||||
|
"billing_dispute": billing_count > 0,
|
||||||
|
"attempts": billing_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _execute_pending_tool(
|
async def _execute_pending_tool(
|
||||||
pending_data: dict,
|
pending_data: dict,
|
||||||
tenant_uuid: uuid.UUID,
|
tenant_uuid: uuid.UUID,
|
||||||
@@ -526,7 +687,7 @@ def _extract_tool_name_from_confirmation(confirmation_message: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def _send_response(
|
async def _send_response(
|
||||||
channel: str,
|
channel: Any,
|
||||||
text: str,
|
text: str,
|
||||||
extras: dict,
|
extras: dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -545,7 +706,9 @@ async def _send_response(
|
|||||||
For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts``
|
For Slack: ``bot_token``, ``channel_id``, ``placeholder_ts``
|
||||||
For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id``
|
For WhatsApp: ``phone_number_id``, ``bot_token`` (access_token), ``wa_id``
|
||||||
"""
|
"""
|
||||||
if channel == "slack":
|
channel_str = str(channel) if channel else ""
|
||||||
|
|
||||||
|
if channel_str == "slack":
|
||||||
bot_token: str = extras.get("bot_token", "") or ""
|
bot_token: str = extras.get("bot_token", "") or ""
|
||||||
channel_id: str = extras.get("channel_id", "") or ""
|
channel_id: str = extras.get("channel_id", "") or ""
|
||||||
placeholder_ts: str = extras.get("placeholder_ts", "") or ""
|
placeholder_ts: str = extras.get("placeholder_ts", "") or ""
|
||||||
@@ -563,7 +726,7 @@ async def _send_response(
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif channel == "whatsapp":
|
elif channel_str == "whatsapp":
|
||||||
phone_number_id: str = extras.get("phone_number_id", "") or ""
|
phone_number_id: str = extras.get("phone_number_id", "") or ""
|
||||||
access_token: str = extras.get("bot_token", "") or ""
|
access_token: str = extras.get("bot_token", "") or ""
|
||||||
wa_id: str = extras.get("wa_id", "") or ""
|
wa_id: str = extras.get("wa_id", "") or ""
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ Tests verify:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@@ -24,7 +23,7 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Shared test helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -93,6 +92,89 @@ def make_message_data(
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def make_fake_redis(escalated: bool = False) -> AsyncMock:
|
||||||
|
"""Build a fakeredis-like mock for the async redis client."""
|
||||||
|
redis_mock = AsyncMock()
|
||||||
|
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
|
||||||
|
redis_mock.delete = AsyncMock()
|
||||||
|
redis_mock.set = AsyncMock()
|
||||||
|
redis_mock.setex = AsyncMock()
|
||||||
|
redis_mock.aclose = AsyncMock()
|
||||||
|
return redis_mock
|
||||||
|
|
||||||
|
|
||||||
|
def make_session_factory_mock(agent: Any) -> MagicMock:
|
||||||
|
"""
|
||||||
|
Build a mock for async_session_factory that returns the given agent.
|
||||||
|
|
||||||
|
The returned mock is callable and its return value acts as an async
|
||||||
|
context manager that yields a session whose execute() returns the agent.
|
||||||
|
"""
|
||||||
|
# Build result chain: session.execute() -> result.scalars().first() -> agent
|
||||||
|
scalars_mock = MagicMock()
|
||||||
|
scalars_mock.first.return_value = agent
|
||||||
|
result_mock = MagicMock()
|
||||||
|
result_mock.scalars.return_value = scalars_mock
|
||||||
|
|
||||||
|
session = AsyncMock()
|
||||||
|
session.execute = AsyncMock(return_value=result_mock)
|
||||||
|
|
||||||
|
# Make session_factory() return an async context manager
|
||||||
|
cm = AsyncMock()
|
||||||
|
cm.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
factory = MagicMock()
|
||||||
|
factory.return_value = cm
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
def make_process_message_msg(
|
||||||
|
channel: str = "slack",
|
||||||
|
tenant_id: str | None = None,
|
||||||
|
user_id: str = "U123",
|
||||||
|
text: str = "hello",
|
||||||
|
thread_id: str = "thread_001",
|
||||||
|
) -> Any:
|
||||||
|
"""Build a KonstructMessage for use with _process_message tests."""
|
||||||
|
from shared.models.message import KonstructMessage
|
||||||
|
|
||||||
|
tid = tenant_id or str(uuid.uuid4())
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"tenant_id": tid,
|
||||||
|
"channel": channel,
|
||||||
|
"channel_metadata": {},
|
||||||
|
"sender": {"user_id": user_id, "display_name": "Test User"},
|
||||||
|
"content": {"text": text, "attachments": []},
|
||||||
|
"timestamp": "2026-01-01T00:00:00Z",
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"reply_to": None,
|
||||||
|
"context": {},
|
||||||
|
}
|
||||||
|
return KonstructMessage.model_validate(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _standard_process_patches(
|
||||||
|
agent: Any,
|
||||||
|
escalated: bool = False,
|
||||||
|
run_agent_return: str = "Agent reply",
|
||||||
|
escalation_rule: dict | None = None,
|
||||||
|
escalate_return: str = "I've brought in a team member",
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dict of patch targets and their corresponding mocks.
|
||||||
|
Callers use this with contextlib.ExitStack or individual patch() calls.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"redis": make_fake_redis(escalated=escalated),
|
||||||
|
"session_factory": make_session_factory_mock(agent),
|
||||||
|
"run_agent_return": run_agent_return,
|
||||||
|
"escalation_rule": escalation_rule,
|
||||||
|
"escalate_return": escalate_return,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Task 1 Tests — handle_message extra-field extraction
|
# Task 1 Tests — handle_message extra-field extraction
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -242,44 +324,13 @@ class TestSendResponseRouting:
|
|||||||
class TestProcessMessageUsesSSendResponse:
|
class TestProcessMessageUsesSSendResponse:
|
||||||
"""_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""
|
"""_process_message must use _send_response for ALL response delivery — never direct _update_slack_placeholder."""
|
||||||
|
|
||||||
def _make_fake_redis(self, escalated: bool = False) -> AsyncMock:
|
|
||||||
"""Build a fakeredis-like mock for the async client."""
|
|
||||||
redis_mock = AsyncMock()
|
|
||||||
redis_mock.get = AsyncMock(return_value=b"escalated" if escalated else None)
|
|
||||||
redis_mock.delete = AsyncMock()
|
|
||||||
redis_mock.set = AsyncMock()
|
|
||||||
redis_mock.setex = AsyncMock()
|
|
||||||
redis_mock.aclose = AsyncMock()
|
|
||||||
return redis_mock
|
|
||||||
|
|
||||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
|
||||||
"""Build a mock async DB session that returns the given agent."""
|
|
||||||
session = AsyncMock()
|
|
||||||
# Scalars mock
|
|
||||||
scalars_mock = MagicMock()
|
|
||||||
scalars_mock.first.return_value = agent
|
|
||||||
result_mock = AsyncMock()
|
|
||||||
result_mock.scalars.return_value = scalars_mock
|
|
||||||
session.execute = AsyncMock(return_value=result_mock)
|
|
||||||
session.__aenter__ = AsyncMock(return_value=session)
|
|
||||||
session.__aexit__ = AsyncMock()
|
|
||||||
return session
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_calls_send_response_not_update_slack(self) -> None:
|
async def test_process_message_calls_send_response_not_update_slack(self) -> None:
|
||||||
"""_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
|
"""_process_message must call _send_response for Slack, NOT _update_slack_placeholder directly."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent()
|
agent = make_agent()
|
||||||
tenant_id = str(uuid.uuid4())
|
msg = make_process_message_msg(channel="slack")
|
||||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"placeholder_ts": "1234.5678",
|
"placeholder_ts": "1234.5678",
|
||||||
"channel_id": "C123",
|
"channel_id": "C123",
|
||||||
@@ -287,11 +338,14 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
"bot_token": "",
|
"bot_token": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
redis_mock = make_fake_redis(escalated=False)
|
||||||
|
sf_mock = make_session_factory_mock(agent)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
|
patch("orchestrator.tasks._send_response", new_callable=AsyncMock) as mock_send,
|
||||||
patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
|
patch("orchestrator.tasks._update_slack_placeholder", new_callable=AsyncMock) as mock_update,
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -306,15 +360,8 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
patch("orchestrator.tasks.AuditLogger"),
|
patch("orchestrator.tasks.AuditLogger"),
|
||||||
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
||||||
):
|
):
|
||||||
# Setup redis mock to return None for pre-check (not escalated)
|
|
||||||
redis_mock = self._make_fake_redis(escalated=False)
|
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
|
|
||||||
# Setup session to return our agent
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
await _process_message(msg, extras=extras)
|
await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
# _send_response MUST be called for delivery
|
# _send_response MUST be called for delivery
|
||||||
@@ -326,37 +373,30 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
async def test_whatsapp_extras_passed_to_send_response(self) -> None:
|
async def test_whatsapp_extras_passed_to_send_response(self) -> None:
|
||||||
"""_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
|
"""_process_message must pass phone_number_id, bot_token, wa_id to _send_response for WhatsApp."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent()
|
agent = make_agent()
|
||||||
tenant_id = str(uuid.uuid4())
|
# user_id = "15551234567" -> becomes wa_id
|
||||||
msg_data = make_message_data(
|
msg = make_process_message_msg(channel="whatsapp", user_id="15551234567")
|
||||||
channel="whatsapp",
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id="15551234567",
|
|
||||||
)
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"placeholder_ts": "",
|
"placeholder_ts": "",
|
||||||
"channel_id": "",
|
"channel_id": "",
|
||||||
"phone_number_id": "9876543210",
|
"phone_number_id": "9876543210",
|
||||||
"bot_token": "EAAaccess",
|
"bot_token": "EAAaccess",
|
||||||
|
"wa_id": "15551234567",
|
||||||
}
|
}
|
||||||
|
|
||||||
captured_extras: list[dict] = []
|
captured_extras: list[dict] = []
|
||||||
|
|
||||||
async def capture_send(channel: str, text: str, ext: dict) -> None:
|
async def capture_send(channel: Any, text: str, ext: dict) -> None:
|
||||||
captured_extras.append({"channel": channel, "text": text, "extras": ext})
|
captured_extras.append({"channel": str(channel), "text": text, "extras": ext})
|
||||||
|
|
||||||
|
redis_mock = make_fake_redis(escalated=False)
|
||||||
|
sf_mock = make_session_factory_mock(agent)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", side_effect=capture_send),
|
patch("orchestrator.tasks._send_response", side_effect=capture_send),
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -371,15 +411,7 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
patch("orchestrator.tasks.AuditLogger"),
|
patch("orchestrator.tasks.AuditLogger"),
|
||||||
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
patch("orchestrator.tasks.check_escalation_rules", return_value=None),
|
||||||
):
|
):
|
||||||
redis_mock = self._make_fake_redis(escalated=False)
|
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
# Configure context manager for redis client used in finally blocks
|
|
||||||
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
|
|
||||||
redis_mock.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
await _process_message(msg, extras=extras)
|
await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
@@ -389,7 +421,7 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
wa_extras = wa_calls[0]["extras"]
|
wa_extras = wa_calls[0]["extras"]
|
||||||
assert wa_extras.get("phone_number_id") == "9876543210"
|
assert wa_extras.get("phone_number_id") == "9876543210"
|
||||||
assert wa_extras.get("bot_token") == "EAAaccess"
|
assert wa_extras.get("bot_token") == "EAAaccess"
|
||||||
# wa_id must come from sender.user_id
|
# wa_id must come from extras (which was populated from sender.user_id in handle_message)
|
||||||
assert wa_extras.get("wa_id") == "15551234567"
|
assert wa_extras.get("wa_id") == "15551234567"
|
||||||
|
|
||||||
|
|
||||||
@@ -401,42 +433,22 @@ class TestProcessMessageUsesSSendResponse:
|
|||||||
class TestEscalationPreCheck:
|
class TestEscalationPreCheck:
|
||||||
"""When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""
|
"""When escalation status is 'escalated' in Redis, _process_message must return early without LLM call."""
|
||||||
|
|
||||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
|
||||||
session = AsyncMock()
|
|
||||||
scalars_mock = MagicMock()
|
|
||||||
scalars_mock.first.return_value = agent
|
|
||||||
result_mock = AsyncMock()
|
|
||||||
result_mock.scalars.return_value = scalars_mock
|
|
||||||
session.execute = AsyncMock(return_value=result_mock)
|
|
||||||
session.__aenter__ = AsyncMock(return_value=session)
|
|
||||||
session.__aexit__ = AsyncMock()
|
|
||||||
return session
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_escalated_conversation_skips_run_agent(self) -> None:
|
async def test_escalated_conversation_skips_run_agent(self) -> None:
|
||||||
"""When Redis shows 'escalated', run_agent must NOT be called."""
|
"""When Redis shows 'escalated', run_agent must NOT be called."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent()
|
agent = make_agent()
|
||||||
tenant_id = str(uuid.uuid4())
|
msg = make_process_message_msg(channel="slack")
|
||||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||||
|
|
||||||
redis_mock = AsyncMock()
|
redis_mock = make_fake_redis(escalated=True)
|
||||||
redis_mock.get = AsyncMock(return_value=b"escalated")
|
sf_mock = make_session_factory_mock(agent)
|
||||||
redis_mock.aclose = AsyncMock()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -444,12 +456,6 @@ class TestEscalationPreCheck:
|
|||||||
patch("orchestrator.tasks.AuditLogger"),
|
patch("orchestrator.tasks.AuditLogger"),
|
||||||
):
|
):
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
redis_mock.__aenter__ = AsyncMock(return_value=redis_mock)
|
|
||||||
redis_mock.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
result = await _process_message(msg, extras=extras)
|
result = await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
@@ -468,42 +474,24 @@ class TestEscalationPreCheck:
|
|||||||
class TestEscalationPostCheck:
|
class TestEscalationPostCheck:
|
||||||
"""check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""
|
"""check_escalation_rules called after run_agent; escalate_to_human called when rule matches."""
|
||||||
|
|
||||||
def _make_session_mock(self, agent: Any) -> AsyncMock:
|
|
||||||
session = AsyncMock()
|
|
||||||
scalars_mock = MagicMock()
|
|
||||||
scalars_mock.first.return_value = agent
|
|
||||||
result_mock = AsyncMock()
|
|
||||||
result_mock.scalars.return_value = scalars_mock
|
|
||||||
session.execute = AsyncMock(return_value=result_mock)
|
|
||||||
session.__aenter__ = AsyncMock(return_value=session)
|
|
||||||
session.__aexit__ = AsyncMock()
|
|
||||||
return session
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_escalation_rules_called_after_run_agent(self) -> None:
|
async def test_check_escalation_rules_called_after_run_agent(self) -> None:
|
||||||
"""check_escalation_rules must be called after run_agent returns."""
|
"""check_escalation_rules must be called after run_agent returns."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent(escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}])
|
|
||||||
tenant_id = str(uuid.uuid4())
|
|
||||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
|
agent = make_agent(
|
||||||
|
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}]
|
||||||
|
)
|
||||||
|
msg = make_process_message_msg(channel="slack")
|
||||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||||
|
|
||||||
redis_mock = AsyncMock()
|
redis_mock = make_fake_redis(escalated=False)
|
||||||
redis_mock.get = AsyncMock(return_value=None)
|
sf_mock = make_session_factory_mock(agent)
|
||||||
redis_mock.aclose = AsyncMock()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -520,10 +508,6 @@ class TestEscalationPostCheck:
|
|||||||
):
|
):
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
await _process_message(msg, extras=extras)
|
await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
# check_escalation_rules must have been called
|
# check_escalation_rules must have been called
|
||||||
@@ -533,32 +517,22 @@ class TestEscalationPostCheck:
|
|||||||
async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
|
async def test_escalate_to_human_called_when_rule_matches_and_assignee_set(self) -> None:
|
||||||
"""When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
|
"""When check_escalation_rules returns a rule and assignee is set, escalate_to_human must be called."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent(
|
agent = make_agent(
|
||||||
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
||||||
escalation_assignee="U_MANAGER",
|
escalation_assignee="U_MANAGER",
|
||||||
)
|
)
|
||||||
tenant_id = str(uuid.uuid4())
|
msg = make_process_message_msg(channel="slack", text="refund issue again")
|
||||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id, text="refund issue again")
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||||
|
|
||||||
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
||||||
|
redis_mock = make_fake_redis(escalated=False)
|
||||||
redis_mock = AsyncMock()
|
sf_mock = make_session_factory_mock(agent)
|
||||||
redis_mock.get = AsyncMock(return_value=None)
|
|
||||||
redis_mock.aclose = AsyncMock()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -572,51 +546,41 @@ class TestEscalationPostCheck:
|
|||||||
patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
|
patch("orchestrator.tasks.get_tools_for_agent", return_value=[]),
|
||||||
patch("orchestrator.tasks.AuditLogger"),
|
patch("orchestrator.tasks.AuditLogger"),
|
||||||
patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
|
patch("orchestrator.tasks.check_escalation_rules", return_value=matched_rule),
|
||||||
patch("orchestrator.tasks.escalate_to_human", new_callable=AsyncMock, return_value="I've brought in a team member") as mock_escalate,
|
patch(
|
||||||
|
"orchestrator.tasks.escalate_to_human",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value="I've brought in a team member",
|
||||||
|
) as mock_escalate,
|
||||||
):
|
):
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
result = await _process_message(msg, extras=extras)
|
result = await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
# escalate_to_human must be called
|
# escalate_to_human must be called
|
||||||
mock_escalate.assert_called_once()
|
mock_escalate.assert_called_once()
|
||||||
# Response should be the escalation confirmation
|
# Response should be the escalation confirmation
|
||||||
assert "team member" in result["response"] or mock_escalate.return_value in result["response"]
|
assert "team member" in result["response"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
|
async def test_escalate_to_human_not_called_when_no_assignee(self) -> None:
|
||||||
"""When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
|
"""When rule matches but escalation_assignee is None, escalate_to_human must NOT be called."""
|
||||||
from orchestrator.tasks import _process_message
|
from orchestrator.tasks import _process_message
|
||||||
from shared.models.message import KonstructMessage
|
|
||||||
|
|
||||||
agent = make_agent(
|
agent = make_agent(
|
||||||
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
escalation_rules=[{"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}],
|
||||||
escalation_assignee=None,
|
escalation_assignee=None,
|
||||||
)
|
)
|
||||||
tenant_id = str(uuid.uuid4())
|
msg = make_process_message_msg(channel="slack")
|
||||||
msg_data = make_message_data(channel="slack", tenant_id=tenant_id)
|
|
||||||
msg_data.pop("placeholder_ts", None)
|
|
||||||
msg_data.pop("channel_id", None)
|
|
||||||
msg_data.pop("phone_number_id", None)
|
|
||||||
msg_data.pop("bot_token", None)
|
|
||||||
msg = KonstructMessage.model_validate(msg_data)
|
|
||||||
|
|
||||||
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
extras = {"placeholder_ts": "1234.5678", "channel_id": "C123", "phone_number_id": "", "bot_token": ""}
|
||||||
|
|
||||||
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
matched_rule = {"condition": "billing_dispute AND attempts > 2", "action": "handoff_human"}
|
||||||
|
redis_mock = make_fake_redis(escalated=False)
|
||||||
redis_mock = AsyncMock()
|
sf_mock = make_session_factory_mock(agent)
|
||||||
redis_mock.get = AsyncMock(return_value=None)
|
|
||||||
redis_mock.aclose = AsyncMock()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
patch("orchestrator.tasks._send_response", new_callable=AsyncMock),
|
||||||
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
patch("orchestrator.tasks.aioredis") as mock_aioredis,
|
||||||
patch("orchestrator.tasks.async_session_factory") as mock_session_factory,
|
patch("orchestrator.tasks.async_session_factory", sf_mock),
|
||||||
patch("orchestrator.tasks.configure_rls_hook"),
|
patch("orchestrator.tasks.configure_rls_hook"),
|
||||||
patch("orchestrator.tasks.current_tenant_id"),
|
patch("orchestrator.tasks.current_tenant_id"),
|
||||||
patch("orchestrator.tasks.engine"),
|
patch("orchestrator.tasks.engine"),
|
||||||
@@ -634,10 +598,6 @@ class TestEscalationPostCheck:
|
|||||||
):
|
):
|
||||||
mock_aioredis.from_url.return_value = redis_mock
|
mock_aioredis.from_url.return_value = redis_mock
|
||||||
|
|
||||||
session_mock = self._make_session_mock(agent)
|
|
||||||
mock_session_factory.return_value.__aenter__ = AsyncMock(return_value=session_mock)
|
|
||||||
mock_session_factory.return_value.__aexit__ = AsyncMock()
|
|
||||||
|
|
||||||
await _process_message(msg, extras=extras)
|
await _process_message(msg, extras=extras)
|
||||||
|
|
||||||
# escalate_to_human must NOT be called — no assignee configured
|
# escalate_to_human must NOT be called — no assignee configured
|
||||||
|
|||||||
Reference in New Issue
Block a user