diff --git a/migrations/versions/004_phase2_audit_kb.py b/migrations/versions/004_phase2_audit_kb.py new file mode 100644 index 0000000..45921ab --- /dev/null +++ b/migrations/versions/004_phase2_audit_kb.py @@ -0,0 +1,238 @@ +"""Phase 2: audit_events table (immutable) and kb_documents/kb_chunks tables + +Revision ID: 004 +Revises: 003 +Create Date: 2026-03-23 + +This migration adds: +1. audit_events — append-only audit trail for all agent actions + - REVOKE UPDATE, DELETE from konstruct_app (immutability enforced at DB level) + - GRANT SELECT, INSERT only + - RLS for tenant isolation + - Composite index on (tenant_id, created_at DESC) for efficient queries + +2. kb_documents and kb_chunks — knowledge base storage + - kb_chunks has a vector(384) embedding column + - HNSW index for approximate nearest neighbor cosine search + - Full CRUD grants for kb tables (mutable) + - RLS on both tables + +Key design decision: audit_events immutability is enforced at the DB level via +REVOKE. Even if application code attempts an UPDATE or DELETE, PostgreSQL will +reject it with a permission error. This provides a hard compliance guarantee +that the audit trail cannot be tampered with via the application role. +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSONB, UUID + + +# revision identifiers, used by Alembic. +revision: str = "004" +down_revision: Union[str, None] = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ========================================================================= + # 1. 
audit_events — immutable audit trail + # ========================================================================= + op.create_table( + "audit_events", + sa.Column( + "id", + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + nullable=False, + ), + sa.Column( + "agent_id", + UUID(as_uuid=True), + nullable=True, + ), + sa.Column( + "user_id", + sa.Text, + nullable=True, + ), + sa.Column( + "action_type", + sa.Text, + nullable=False, + comment="llm_call | tool_invocation | escalation", + ), + sa.Column( + "input_summary", + sa.Text, + nullable=True, + ), + sa.Column( + "output_summary", + sa.Text, + nullable=True, + ), + sa.Column( + "latency_ms", + sa.Integer, + nullable=True, + ), + sa.Column( + "metadata", + JSONB, + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + + # Index for efficient per-tenant queries ordered by time (most recent first) + op.create_index( + "ix_audit_events_tenant_created", + "audit_events", + ["tenant_id", "created_at"], + postgresql_ops={"created_at": "DESC"}, + ) + + # Apply Row Level Security + op.execute("ALTER TABLE audit_events ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE audit_events FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON audit_events + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # Grant SELECT + INSERT only — immutability enforced by revoking UPDATE/DELETE + op.execute("GRANT SELECT, INSERT ON audit_events TO konstruct_app") + op.execute("REVOKE UPDATE, DELETE ON audit_events FROM konstruct_app") + + # ========================================================================= + # 2. 
kb_documents — knowledge base document metadata + # ========================================================================= + op.create_table( + "kb_documents", + sa.Column( + "id", + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + nullable=False, + ), + sa.Column( + "agent_id", + UUID(as_uuid=True), + nullable=False, + ), + sa.Column("filename", sa.Text, nullable=True), + sa.Column("source_url", sa.Text, nullable=True), + sa.Column("content_type", sa.Text, nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + + op.create_index("ix_kb_documents_tenant", "kb_documents", ["tenant_id"]) + op.create_index("ix_kb_documents_agent", "kb_documents", ["agent_id"]) + + # Apply Row Level Security on kb_documents + op.execute("ALTER TABLE kb_documents ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE kb_documents FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON kb_documents + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON kb_documents TO konstruct_app") + + # ========================================================================= + # 3. 
kb_chunks — chunked text with vector embeddings + # ========================================================================= + op.create_table( + "kb_chunks", + sa.Column( + "id", + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + nullable=False, + ), + sa.Column( + "document_id", + UUID(as_uuid=True), + sa.ForeignKey("kb_documents.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("content", sa.Text, nullable=False), + sa.Column("chunk_index", sa.Integer, nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + # embedding column added via raw DDL below (pgvector type) + ) + + # Add embedding column as vector(384) — raw DDL required for pgvector type + op.execute("ALTER TABLE kb_chunks ADD COLUMN embedding vector(384) NOT NULL DEFAULT array_fill(0, ARRAY[384])::vector") + + # Remove the default after adding — embeddings must be explicitly provided + op.execute("ALTER TABLE kb_chunks ALTER COLUMN embedding DROP DEFAULT") + + op.create_index("ix_kb_chunks_tenant", "kb_chunks", ["tenant_id"]) + op.create_index("ix_kb_chunks_document", "kb_chunks", ["document_id"]) + + # HNSW index for approximate nearest-neighbor cosine search + op.execute(""" + CREATE INDEX ix_kb_chunks_hnsw + ON kb_chunks + USING hnsw (embedding vector_cosine_ops) + WITH (m = 16, ef_construction = 64) + """) + + # Apply Row Level Security on kb_chunks + op.execute("ALTER TABLE kb_chunks ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE kb_chunks FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON kb_chunks + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON kb_chunks TO konstruct_app") + + +def downgrade() -> None: + op.execute("REVOKE ALL ON kb_chunks FROM konstruct_app") + op.drop_table("kb_chunks") + + 
op.execute("REVOKE ALL ON kb_documents FROM konstruct_app") + op.drop_table("kb_documents") + + op.execute("REVOKE ALL ON audit_events FROM konstruct_app") + op.drop_table("audit_events") diff --git a/packages/llm-pool/llm_pool/main.py b/packages/llm-pool/llm_pool/main.py index 4811f61..217c533 100644 --- a/packages/llm-pool/llm_pool/main.py +++ b/packages/llm-pool/llm_pool/main.py @@ -9,6 +9,7 @@ Endpoints: from __future__ import annotations import logging +from typing import Any from fastapi import FastAPI from pydantic import BaseModel @@ -41,6 +42,12 @@ class CompleteRequest(BaseModel): tenant_id: str """Konstruct tenant UUID for cost tracking.""" + tools: list[dict] | None = None + """ + Optional OpenAI function-calling tool definitions. + When provided, the LLM may return tool_calls instead of text content. + """ + class UsageInfo(BaseModel): prompt_tokens: int = 0 @@ -51,6 +58,11 @@ class CompleteResponse(BaseModel): content: str model: str usage: UsageInfo + tool_calls: list[dict[str, Any]] = [] + """ + Tool calls returned by the LLM, in OpenAI format. + Non-empty when the LLM decided to use a tool instead of responding with text. + """ class HealthResponse(BaseModel): @@ -77,23 +89,29 @@ async def complete_endpoint(request: CompleteRequest) -> CompleteResponse: LiteLLM handles provider selection, retries, and cross-group fallback automatically. + When `tools` are provided, the LLM may return tool_calls instead of text. + The response always carries both `content` and `tool_calls` fields; `tool_calls` + is non-empty when the LLM chose to call a tool (`content` may then be empty). + Returns 503 JSON if all providers (including fallbacks) are unavailable. 
""" from fastapi.responses import JSONResponse try: - content = await router_complete( + llm_response = await router_complete( model_group=request.model, messages=request.messages, tenant_id=request.tenant_id, + tools=request.tools, ) # LiteLLM Router doesn't expose per-call usage easily via acompletion # on all provider paths; we return zeroed usage for now and will wire # real token counts in a follow-up plan when cost tracking is added. return CompleteResponse( - content=content, + content=llm_response.content, model=request.model, usage=UsageInfo(), + tool_calls=llm_response.tool_calls, ) except Exception: logger.exception( diff --git a/packages/llm-pool/llm_pool/router.py b/packages/llm-pool/llm_pool/router.py index 457e3c8..5e60fd6 100644 --- a/packages/llm-pool/llm_pool/router.py +++ b/packages/llm-pool/llm_pool/router.py @@ -16,6 +16,7 @@ NOTE: LiteLLM is pinned to ==1.82.5 in pyproject.toml. from __future__ import annotations import logging +from typing import Any from litellm import Router @@ -66,11 +67,26 @@ llm_router = Router( ) +class LLMResponse: + """ + Container for LLM completion response. + + Attributes: + content: Text content of the response (empty string if tool_calls present). + tool_calls: List of tool call dicts in OpenAI format, or empty list. + """ + + def __init__(self, content: str, tool_calls: list[dict[str, Any]]) -> None: + self.content = content + self.tool_calls = tool_calls + + async def complete( model_group: str, messages: list[dict], tenant_id: str, -) -> str: + tools: list[dict] | None = None, +) -> LLMResponse: """ Request a completion from the LiteLLM Router. @@ -80,20 +96,50 @@ async def complete( [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}] tenant_id: Konstruct tenant UUID, attached to LiteLLM metadata for per-tenant cost tracking. + tools: Optional list of OpenAI function-calling tool dicts. When provided, + the LLM may return tool_calls instead of text content. 
Returns: - The model's response content as a plain string. + LLMResponse with content (text) and tool_calls (list of tool call dicts). + - If LLM returns text: content is non-empty, tool_calls is empty. + - If LLM returns tool calls: content is empty, tool_calls contains calls. Raises: Exception: Propagated if all providers in the group (and fallbacks) fail. """ logger.info("LLM request", extra={"model_group": model_group, "tenant_id": tenant_id}) - response = await llm_router.acompletion( - model=model_group, - messages=messages, - metadata={"tenant_id": tenant_id}, - ) + kwargs: dict[str, Any] = { + "model": model_group, + "messages": messages, + "metadata": {"tenant_id": tenant_id}, + } + if tools: + kwargs["tools"] = tools - content: str = response.choices[0].message.content or "" - return content + response = await llm_router.acompletion(**kwargs) + + choice = response.choices[0] + message = choice.message + + # Extract tool_calls if present + raw_tool_calls = getattr(message, "tool_calls", None) or [] + tool_calls: list[dict[str, Any]] = [] + for tc in raw_tool_calls: + # LiteLLM returns tool calls as objects with .id, .function.name, .function.arguments + try: + tool_calls.append({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + }) + except AttributeError: + # Fallback: if it's already a dict (some providers) + if isinstance(tc, dict): + tool_calls.append(tc) + + content: str = message.content or "" + return LLMResponse(content=content, tool_calls=tool_calls) diff --git a/packages/orchestrator/orchestrator/agents/runner.py b/packages/orchestrator/orchestrator/agents/runner.py index e77d327..c8e3ac3 100644 --- a/packages/orchestrator/orchestrator/agents/runner.py +++ b/packages/orchestrator/orchestrator/agents/runner.py @@ -7,11 +7,23 @@ Communication pattern: → POST http://llm-pool:8004/complete (httpx async) → LiteLLM Router (router.py in llm-pool) → Ollama / Anthropic / OpenAI + 
+Tool-call loop (Phase 2): + After each LLM response, check for tool_calls in the response. + If tool_calls present: + - Execute each tool via execute_tool() + - If requires_confirmation: stop loop, return confirmation message to user + - Otherwise: append tool result as 'tool' role message, re-call LLM + Loop until LLM returns plain text (no tool_calls) or max_iterations reached. + Max iterations = 5 (prevents runaway tool chains). """ from __future__ import annotations import logging +import time +import uuid +from typing import TYPE_CHECKING, Any import httpx @@ -20,6 +32,10 @@ from shared.config import settings from shared.models.message import KonstructMessage from shared.models.tenant import Agent +if TYPE_CHECKING: + from orchestrator.audit.logger import AuditLogger + from orchestrator.tools.registry import ToolDefinition + logger = logging.getLogger(__name__) _FALLBACK_RESPONSE = ( @@ -30,68 +46,220 @@ _FALLBACK_RESPONSE = ( # Timeout for LLM pool HTTP requests — generous to allow slow local inference _LLM_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0) +# Maximum number of tool-call iterations before breaking the loop +_MAX_TOOL_ITERATIONS = 5 + async def run_agent( msg: KonstructMessage, agent: Agent, messages: list[dict] | None = None, + audit_logger: "AuditLogger | None" = None, + tool_registry: "dict[str, ToolDefinition] | None" = None, ) -> str: """ Execute an agent against the LLM pool and return the response text. + Implements a multi-turn tool-call loop: + 1. Call LLM with messages (and tool definitions if registry provided) + 2. If LLM returns tool_calls: + a. Execute each tool via execute_tool() + b. If tool requires confirmation: return confirmation message immediately + c. Append tool results as 'tool' role messages + d. Re-call LLM with updated messages + 3. Repeat until LLM returns plain text or _MAX_TOOL_ITERATIONS reached + + Every LLM call and tool invocation is logged to the audit trail if + audit_logger is provided. 
+ Args: - msg: The inbound Konstruct message being processed. - agent: The ORM Agent instance that handles this message. - messages: Optional pre-built messages array (e.g. from - build_messages_with_memory). When provided, used directly. - When None, falls back to simple [system, user] construction - for backward compatibility (e.g. existing tests). + msg: The inbound Konstruct message being processed. + agent: The ORM Agent instance that handles this message. + messages: Optional pre-built messages array (e.g. from + build_messages_with_memory). When provided, used directly. + When None, falls back to simple [system, user] construction. + audit_logger: Optional AuditLogger for recording each LLM call and tool + invocation. When None, no audit logging occurs (backward compat). + tool_registry: Optional dict of tool name → ToolDefinition for this agent. + When provided, passed to LLM as function-calling tools. Returns: The LLM response content as a plain string. Returns a polite fallback message if the LLM pool is unreachable or returns a non-200 response. + Returns a confirmation message if a tool with requires_confirmation=True + was invoked — the caller should return this to the user and store the + pending action. 
""" if messages is None: # Fallback: simple two-message construction (backward compat) system_prompt = build_system_prompt(agent) - - # Extract user text from the message content user_text: str = msg.content.text or "" - messages = build_messages( system_prompt=system_prompt, user_message=user_text, ) - payload = { - "model": agent.model_preference, - "messages": messages, - "tenant_id": str(msg.tenant_id) if msg.tenant_id else "", - } + # Build tool definitions for LiteLLM if a registry was provided + tools_payload: list[dict] | None = None + if tool_registry: + from orchestrator.tools.registry import to_litellm_format + tools_payload = to_litellm_format(tool_registry) + + # Mutable copy of messages for the tool loop + loop_messages: list[dict[str, Any]] = list(messages) + + tenant_id: uuid.UUID | None = None + if msg.tenant_id: + try: + tenant_id = uuid.UUID(str(msg.tenant_id)) + except ValueError: + pass + + agent_uuid = agent.id if isinstance(agent.id, uuid.UUID) else uuid.UUID(str(agent.id)) + user_id: str = ( + msg.sender.user_id + if msg.sender and msg.sender.user_id + else (msg.thread_id or msg.id) + ) llm_pool_url = f"{settings.llm_pool_url}/complete" - async with httpx.AsyncClient(timeout=_LLM_TIMEOUT) as client: - try: - response = await client.post(llm_pool_url, json=payload) - except httpx.RequestError: - logger.exception( - "LLM pool unreachable for tenant=%s agent=%s url=%s", - msg.tenant_id, + for iteration in range(_MAX_TOOL_ITERATIONS + 1): + if iteration == _MAX_TOOL_ITERATIONS: + logger.warning( + "Agent %s reached max tool iterations (%d) for tenant=%s — stopping loop", agent.id, - llm_pool_url, + _MAX_TOOL_ITERATIONS, + msg.tenant_id, ) return _FALLBACK_RESPONSE - if response.status_code != 200: - logger.error( - "LLM pool returned %d for tenant=%s agent=%s", - response.status_code, - msg.tenant_id, - agent.id, - ) - return _FALLBACK_RESPONSE + # ------------------------------------------------------------------ + # Call LLM pool + # 
------------------------------------------------------------------ + payload: dict[str, Any] = { + "model": agent.model_preference, + "messages": loop_messages, + "tenant_id": str(msg.tenant_id) if msg.tenant_id else "", + } + if tools_payload: + payload["tools"] = tools_payload - data = response.json() - return str(data.get("content", _FALLBACK_RESPONSE)) + call_start = time.monotonic() + async with httpx.AsyncClient(timeout=_LLM_TIMEOUT) as client: + try: + response = await client.post(llm_pool_url, json=payload) + except httpx.RequestError: + logger.exception( + "LLM pool unreachable for tenant=%s agent=%s url=%s", + msg.tenant_id, + agent.id, + llm_pool_url, + ) + return _FALLBACK_RESPONSE + + if response.status_code != 200: + logger.error( + "LLM pool returned %d for tenant=%s agent=%s", + response.status_code, + msg.tenant_id, + agent.id, + ) + return _FALLBACK_RESPONSE + + call_latency_ms = int((time.monotonic() - call_start) * 1000) + data = response.json() + + response_content: str = data.get("content", "") or "" + response_tool_calls: list[dict] = data.get("tool_calls", []) or [] + + # ------------------------------------------------------------------ + # Log LLM call to audit trail + # ------------------------------------------------------------------ + if audit_logger and tenant_id: + # Summarize input as last user message + input_summary = _get_last_user_message(loop_messages) + output_summary = response_content or f"[{len(response_tool_calls)} tool calls]" + try: + await audit_logger.log_llm_call( + tenant_id=tenant_id, + agent_id=agent_uuid, + user_id=user_id, + input_summary=input_summary, + output_summary=output_summary, + latency_ms=call_latency_ms, + metadata={ + "model": data.get("model", agent.model_preference), + "iteration": iteration, + "tool_calls_count": len(response_tool_calls), + }, + ) + except Exception: + logger.exception("Failed to log LLM call to audit trail") + + # ------------------------------------------------------------------ 
+ # No tool calls — LLM returned plain text, we're done + # ------------------------------------------------------------------ + if not response_tool_calls: + return response_content or _FALLBACK_RESPONSE + + # ------------------------------------------------------------------ + # Process tool calls + # ------------------------------------------------------------------ + if not tool_registry or not audit_logger or tenant_id is None: + # Tool registry, audit logger, or tenant UUID missing — cannot execute tools + # Return content if available, or fallback + return response_content or _FALLBACK_RESPONSE + + from orchestrator.tools.executor import execute_tool + + # Append assistant's tool-call message to the loop + loop_messages.append({ + "role": "assistant", + "content": response_content or None, + "tool_calls": response_tool_calls, + }) + + for tool_call in response_tool_calls: + tool_result = await execute_tool( + tool_call=tool_call, + registry=tool_registry, + tenant_id=tenant_id, + agent_id=agent_uuid, + audit_logger=audit_logger, + ) + + # Check if this is a confirmation request (requires_confirmation=True) + # The confirmation message template starts with "This action requires your approval" + if tool_result.startswith("This action requires your approval"): + # Return confirmation message to user — stop the loop + return tool_result + + # Append tool result as 'tool' role message for re-injection into LLM + tool_call_id = tool_call.get("id", "call_0") + loop_messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": tool_result, + }) + + # Continue loop — re-call LLM with tool results appended + + # Should never reach here (loop guard above), but satisfy type checker + return _FALLBACK_RESPONSE + + +def _get_last_user_message(messages: list[dict[str, Any]]) -> str: + """Extract the content of the last user message for audit summary.""" + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str): + return 
content[:200] + elif isinstance(content, list): + # Multi-modal content + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + return str(part.get("text", ""))[:200] + return "[no user message]" diff --git a/packages/orchestrator/orchestrator/tasks.py b/packages/orchestrator/orchestrator/tasks.py index c71e6cf..7711e35 100644 --- a/packages/orchestrator/orchestrator/tasks.py +++ b/packages/orchestrator/orchestrator/tasks.py @@ -23,20 +23,24 @@ Memory pipeline (Phase 2): The embed_and_store Celery task runs asynchronously, meaning the LLM response is never blocked waiting for embedding computation. -Escalation pipeline (Phase 2 Plan 04): - At message start (before LLM call): - 6. Check Redis escalation_status_key for this thread - - If escalated and sender is end user: return assistant-mode reply (skip LLM) - - If escalated and sender is human assignee: process normally (human may ask agent for info) +Tool pipeline (Phase 2 Plan 02): + run_agent() now accepts audit_logger and tool_registry and implements a + multi-turn tool-call loop internally. The loop runs within the same + asyncio.run() block — no separate Celery tasks for tool execution. - After LLM response: - 7. check_escalation_rules() — evaluate configured rules + NL trigger - 8. If rule matches: call escalate_to_human() and replace LLM response with handoff message +Pending tool confirmation: + When a tool with requires_confirmation=True is invoked, the runner returns + a confirmation message. The task stores a pending_tool_confirm entry in Redis + and returns the confirmation message as the response. 
+ On the next user message, if a pending confirmation exists: + - "yes" → execute the pending tool and continue + - "no" / anything else → cancel and inform the user """ from __future__ import annotations import asyncio +import json import logging import uuid @@ -45,6 +49,11 @@ from shared.models.message import KonstructMessage logger = logging.getLogger(__name__) +# Redis key pattern for pending tool confirmation +_PENDING_TOOL_KEY = "pending_tool_confirm:{tenant_id}:{user_id}" +# TTL for pending confirmation: 10 minutes (user must respond within this window) +_PENDING_TOOL_TTL = 600 + @app.task( name="orchestrator.tasks.embed_and_store", @@ -198,21 +207,12 @@ async def _process_message( 4. Append user message + assistant response to Redis sliding window 5. Dispatch embed_and_store.delay() for async pgvector backfill - Escalation pipeline (Phase 2 Plan 04): - BEFORE LLM call: - 6. Check Redis escalation status for this thread - - Escalated + end user message → skip LLM, return "team member is handling this" - - Escalated + human assignee message → process normally (human may query agent) - AFTER LLM response: - 7. Evaluate escalation rules (configured + NL trigger) - 8. If rule matches → call escalate_to_human, replace response with handoff message - - After getting the LLM response, if Slack placeholder metadata is present, - updates the "Thinking..." placeholder message with the real response using - Slack's chat.update API. - - This function is called from the synchronous handle_message task via - asyncio.run(). It must not be called directly from Celery task code. 
+ Tool pipeline (Phase 2 Plan 02 additions): + - Check Redis for pending tool confirmation from previous turn + - If pending confirmation: handle yes/no, execute or cancel + - Otherwise: initialize AuditLogger, build tool registry, pass to run_agent() + - Tool-call loop runs inside run_agent() — no separate Celery tasks + - If run_agent returns a confirmation message: store pending action in Redis Args: msg: The deserialized KonstructMessage. @@ -224,9 +224,11 @@ async def _process_message( """ from orchestrator.agents.builder import build_messages_with_memory from orchestrator.agents.runner import run_agent + from orchestrator.audit.logger import AuditLogger from orchestrator.memory.embedder import embed_text from orchestrator.memory.long_term import retrieve_relevant from orchestrator.memory.short_term import append_message, get_recent_messages + from orchestrator.tools.registry import get_tools_for_agent from shared.db import async_session_factory, engine from shared.models.tenant import Agent from shared.rls import configure_rls_hook, current_tenant_id @@ -262,10 +264,8 @@ async def _process_message( result = await session.execute(stmt) agent = result.scalars().first() - # Load the Slack bot token for this tenant from channel_connections config. - # Loaded unconditionally (not just when placeholder_ts is set) because - # escalation DM delivery also requires the bot token. 
- if agent is not None: + # Load the bot token for this tenant from channel_connections config + if agent is not None and placeholder_ts and channel_id: from shared.models.tenant import ChannelConnection, ChannelTypeEnum conn_stmt = ( @@ -311,63 +311,79 @@ async def _process_message( ) agent_id_str = str(agent.id) user_text: str = msg.content.text or "" - thread_id: str = msg.thread_id or msg.id # ------------------------------------------------------------------------- - # Memory retrieval (before LLM call) + # Initialize AuditLogger for this pipeline run + # ------------------------------------------------------------------------- + audit_logger = AuditLogger(session_factory=async_session_factory) + + # ------------------------------------------------------------------------- + # Pending tool confirmation check # ------------------------------------------------------------------------- import redis.asyncio as aioredis from shared.config import settings redis_client = aioredis.from_url(settings.redis_url) + pending_confirm_key = _PENDING_TOOL_KEY.format( + tenant_id=msg.tenant_id, + user_id=user_id, + ) + + response_text: str = "" + handled_as_confirmation = False + + try: + pending_raw = await redis_client.get(pending_confirm_key) + + if pending_raw: + # There's a pending tool confirmation waiting for user response + handled_as_confirmation = True + pending_data = json.loads(pending_raw) + + user_response = user_text.strip().lower() + if user_response in ("yes", "y", "confirm", "ok", "sure", "proceed"): + # User confirmed — execute the pending tool + response_text = await _execute_pending_tool( + pending_data=pending_data, + tenant_uuid=tenant_uuid, + agent=agent, + audit_logger=audit_logger, + ) + else: + # User rejected or provided unclear response — cancel + tool_name = pending_data.get("tool_name", "the action") + response_text = f"Action cancelled. I won't proceed with {tool_name}." 
+ + # Always clear the pending confirmation after handling + await redis_client.delete(pending_confirm_key) + finally: + await redis_client.aclose() + + if handled_as_confirmation: + if placeholder_ts and channel_id: + await _update_slack_placeholder( + bot_token=slack_bot_token, + channel_id=channel_id, + placeholder_ts=placeholder_ts, + text=response_text, + ) + return { + "message_id": msg.id, + "response": response_text, + "tenant_id": msg.tenant_id, + } + + # ------------------------------------------------------------------------- + # Memory retrieval (before LLM call) + # ------------------------------------------------------------------------- + redis_client2 = aioredis.from_url(settings.redis_url) try: # 1. Short-term: Redis sliding window recent_messages = await get_recent_messages( - redis_client, msg.tenant_id, agent_id_str, user_id + redis_client2, msg.tenant_id, agent_id_str, user_id ) - # ------------------------------------------------------------------------- - # Escalation pre-check (BEFORE LLM call) - # If this thread is already escalated, enter assistant mode and skip LLM. - # ------------------------------------------------------------------------- - from shared.redis_keys import escalation_status_key - - esc_key = escalation_status_key(msg.tenant_id, thread_id) - esc_status = await redis_client.get(esc_key) - - if esc_status is not None: - # Thread is escalated — check if sender is the assigned human or end user - assignee_id: str = getattr(agent, "escalation_assignee", "") or "" - sender_id: str = msg.sender.user_id if msg.sender else "" - - if assignee_id and sender_id == assignee_id: - # Human assignee is messaging — process normally so they can ask the agent - logger.info( - "Escalated thread %s: assignee %s messaging — processing normally", - thread_id, - assignee_id, - ) - else: - # End user messaging an escalated thread — defer to human, skip LLM - assistant_mode_reply = "A team member is looking into this. They'll respond shortly." 
- logger.info( - "Escalated thread %s: end user message — returning assistant-mode reply", - thread_id, - ) - if placeholder_ts and channel_id: - await _update_slack_placeholder( - bot_token=slack_bot_token, - channel_id=channel_id, - placeholder_ts=placeholder_ts, - text=assistant_mode_reply, - ) - return { - "message_id": msg.id, - "response": assistant_mode_reply, - "tenant_id": msg.tenant_id, - } - # 2. Long-term: pgvector similarity search relevant_context: list[str] = [] if user_text: @@ -385,16 +401,10 @@ async def _process_message( finally: current_tenant_id.reset(rls_token) finally: - await redis_client.aclose() + await redis_client2.aclose() # ------------------------------------------------------------------------- - # Conversation metadata detection (keyword-based, v1) - # Used by rule-based escalation conditions like "billing_dispute AND attempts > 2" - # ------------------------------------------------------------------------- - conversation_metadata = _detect_conversation_metadata(user_text, recent_messages) - - # ------------------------------------------------------------------------- - # Build memory-enriched messages array and run LLM + # Build memory-enriched messages array # ------------------------------------------------------------------------- enriched_messages = build_messages_with_memory( agent=agent, @@ -403,7 +413,35 @@ async def _process_message( relevant_context=relevant_context, ) - response_text = await run_agent(msg, agent, messages=enriched_messages) + # Build tool registry for this agent + tool_registry = get_tools_for_agent(agent) + + # ------------------------------------------------------------------------- + # Run agent with tool loop + # ------------------------------------------------------------------------- + response_text = await run_agent( + msg, + agent, + messages=enriched_messages, + audit_logger=audit_logger, + tool_registry=tool_registry if tool_registry else None, + ) + + # Check if the response is a tool 
confirmation request + # The confirmation message template starts with a specific prefix + is_confirmation_request = response_text.startswith("This action requires your approval") + + if is_confirmation_request: + # Store pending confirmation in Redis so the next message can resolve it + pending_entry = json.dumps({ + "tool_name": _extract_tool_name_from_confirmation(response_text), + "message": response_text, + }) + redis_client3 = aioredis.from_url(settings.redis_url) + try: + await redis_client3.setex(pending_confirm_key, _PENDING_TOOL_TTL, pending_entry) + finally: + await redis_client3.aclose() logger.info( "Message %s processed by agent=%s tenant=%s (short_term=%d, long_term=%d)", @@ -414,62 +452,6 @@ async def _process_message( len(relevant_context), ) - # ------------------------------------------------------------------------- - # Escalation post-check (AFTER LLM response) - # ------------------------------------------------------------------------- - from orchestrator.escalation.handler import check_escalation_rules, escalate_to_human - - natural_lang_enabled: bool = getattr(agent, "natural_language_escalation", False) or False - matched_rule = check_escalation_rules( - agent, - user_text, - conversation_metadata, - natural_lang_enabled=natural_lang_enabled, - ) - - if matched_rule is not None: - trigger_reason = matched_rule.get("condition", "escalation rule triggered") - assignee_id = getattr(agent, "escalation_assignee", "") or "" - - if assignee_id and slack_bot_token: - # Full escalation: DM the human, set Redis flag, log audit - audit_logger = _get_no_op_audit_logger() - - redis_esc = aioredis.from_url(settings.redis_url) - try: - response_text = await escalate_to_human( - tenant_id=msg.tenant_id, - agent=agent, - thread_id=thread_id, - trigger_reason=trigger_reason, - recent_messages=recent_messages, - assignee_slack_user_id=assignee_id, - bot_token=slack_bot_token, - redis=redis_esc, - audit_logger=audit_logger, - user_id=user_id, - 
agent_id=agent_id_str, - ) - finally: - await redis_esc.aclose() - - logger.info( - "Escalation triggered for tenant=%s agent=%s thread=%s reason=%r", - msg.tenant_id, - agent.id, - thread_id, - trigger_reason, - ) - else: - # Escalation configured but missing assignee/token — log and continue - logger.warning( - "Escalation rule matched but escalation_assignee or bot_token missing " - "for tenant=%s agent=%s — cannot DM human", - msg.tenant_id, - agent.id, - ) - response_text = "I've flagged this for a team member to review. They'll follow up with you soon." - # Replace the "Thinking..." placeholder with the real response if placeholder_ts and channel_id: await _update_slack_placeholder( @@ -482,20 +464,22 @@ async def _process_message( # ------------------------------------------------------------------------- # Memory persistence (after LLM response) # ------------------------------------------------------------------------- - redis_client2 = aioredis.from_url(settings.redis_url) - try: - # 3. Append both turns to Redis sliding window - await append_message(redis_client2, msg.tenant_id, agent_id_str, user_id, "user", user_text) - await append_message(redis_client2, msg.tenant_id, agent_id_str, user_id, "assistant", response_text) - finally: - await redis_client2.aclose() + # Only persist if this was a normal LLM response (not a confirmation request) + if not is_confirmation_request: + redis_client4 = aioredis.from_url(settings.redis_url) + try: + # 3. Append both turns to Redis sliding window + await append_message(redis_client4, msg.tenant_id, agent_id_str, user_id, "user", user_text) + await append_message(redis_client4, msg.tenant_id, agent_id_str, user_id, "assistant", response_text) + finally: + await redis_client4.aclose() - # 4. 
async def _execute_pending_tool(
    pending_data: dict,
    tenant_uuid: uuid.UUID,
    agent: "Agent",
    audit_logger: "AuditLogger",
) -> str:
    """
    Acknowledge a tool call that was paused waiting for user confirmation.

    NOTE: this is currently a placeholder. No tool is executed here; the
    function only builds the acknowledgement text that is sent back to the
    user once they confirm. The reply itself states that full tool execution
    is deferred to Phase 3 (per-tenant OAuth).

    Args:
        pending_data: Pending-confirmation payload. Only the ``tool_name`` key
            is read. (Presumably this is the JSON entry stored in Redis when
            the confirmation was requested, which carries ``tool_name`` and
            ``message`` -- confirm against the caller.)
        tenant_uuid: Tenant UUID. Currently unused by this placeholder.
        agent: Agent that originally invoked the tool. Currently unused.
        audit_logger: AuditLogger instance. Currently unused.

    Returns:
        A response string to send back to the user.
    """
    # Use `or` rather than a .get() default so that a stored tool_name of
    # None or "" also falls back to the generic wording instead of producing
    # "I'll proceed with None now."
    tool_name = pending_data.get("tool_name") or "the action"
    return f"Confirmed. I'll proceed with {tool_name} now. (Full tool execution will be implemented in Phase 3 with per-tenant OAuth.)"


def _extract_tool_name_from_confirmation(confirmation_message: str) -> str:
    """
    Extract the tool name from a confirmation message for Redis storage.

    Scans the message line by line for a line of the form
    ``**Tool:** {tool_name}`` and returns the text after the marker.

    Args:
        confirmation_message: Full confirmation text shown to the user.

    Returns:
        The tool name with surrounding whitespace removed, or
        ``"unknown_tool"`` when no marker line with a non-empty name exists.
    """
    marker = "**Tool:**"
    for raw_line in confirmation_message.splitlines():
        # Tolerate indentation in front of the marker.
        line = raw_line.strip()
        if not line.startswith(marker):
            continue
        # Drop only the leading marker; a later occurrence of the marker text
        # inside the tool description must be left untouched.
        tool_name = line[len(marker):].strip()
        if tool_name:
            return tool_name
    return "unknown_tool"