- ToolDefinition Pydantic model with JSON Schema parameters + handler
- BUILTIN_TOOLS: web_search, kb_search, http_request, calendar_lookup
- http_request requires_confirmation=True (outbound side effects)
- get_tools_for_agent filters by agent.tool_assignments
- to_litellm_format converts to the OpenAI function-calling schema
- execute_tool: jsonschema validation before the handler call
- execute_tool: confirmation gate for requires_confirmation=True
- execute_tool: audit logging on every invocation (success + failure)
- web_search: Brave Search API with BRAVE_API_KEY env var
- kb_search: pgvector cosine similarity with HNSW index
- http_request: 30s timeout, 1MB cap, GET/POST/PUT/DELETE only
- calendar_lookup: Google Calendar events.list (read-only)
- jsonschema dependency added to orchestrator pyproject.toml
- [Rule 1 - Bug] Added missing execute_tool import in test
287 lines
8.6 KiB
Python
287 lines
8.6 KiB
Python
"""
|
|
Unit tests for the tool executor.

Tests:

- Valid args pass schema validation and the handler is called
- Invalid args are rejected before the handler is called
- An unknown tool name returns an error string (the executor does not raise)
- requires_confirmation=True returns a confirmation message without executing
- audit_logger.log_tool_call is called on success and on validation failure
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def make_registry(requires_confirmation: bool = False, handler_return: str = "tool result") -> dict:
    """Create a one-entry tool registry keyed by tool name.

    The single entry is a ``ToolDefinition`` called ``"test_tool"``. Its JSON
    Schema requires a string ``query`` and allows an optional integer
    ``count``. The async stub handler ignores whatever keyword arguments it
    receives and always resolves to *handler_return*.

    Args:
        requires_confirmation: Passed straight through to the definition's
            ``requires_confirmation`` flag.
        handler_return: The string the stub handler returns when awaited.

    Returns:
        ``{"test_tool": ToolDefinition(...)}``.
    """
    from orchestrator.tools.registry import ToolDefinition

    schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string"},
            "count": {"type": "integer"},
        },
        "required": ["query"],
    }

    async def stub_handler(**_kwargs: object) -> str:
        return handler_return

    definition = ToolDefinition(
        name="test_tool",
        description="A test tool",
        parameters=schema,
        requires_confirmation=requires_confirmation,
        handler=stub_handler,
    )
    return {"test_tool": definition}
|
|
|
|
|
|
def make_audit_logger() -> MagicMock:
    """Build a stand-in for ``AuditLogger``.

    ``log_tool_call`` is an ``AsyncMock``, so the executor can ``await`` it
    and tests can inspect ``call_args`` afterwards. All other attributes
    fall back to ordinary ``MagicMock`` behaviour.
    """
    # MagicMock's keyword arguments are assigned as attributes on the mock.
    return MagicMock(log_tool_call=AsyncMock())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExecuteToolValidArgs:
    """Happy path: well-formed arguments reach the handler and its output is returned."""

    async def test_valid_args_returns_result(self):
        from orchestrator.tools.executor import execute_tool

        registry = make_registry(handler_return="search results here")
        call = {"function": {"name": "test_tool", "arguments": '{"query": "hello world"}'}}

        outcome = await execute_tool(
            tool_call=call,
            registry=registry,
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=make_audit_logger(),
        )

        assert outcome == "search results here"

    async def test_audit_logger_called_on_success(self):
        from orchestrator.tools.executor import execute_tool

        logger = make_audit_logger()
        call = {"function": {"name": "test_tool", "arguments": '{"query": "test"}'}}

        await execute_tool(
            tool_call=call,
            registry=make_registry(),
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=logger,
        )

        # Exactly one audit record for this tool, with no error attached.
        logger.log_tool_call.assert_called_once()
        logged = logger.log_tool_call.call_args.kwargs
        assert logged["tool_name"] == "test_tool"
        assert logged["error"] is None
|
|
|
|
|
class TestExecuteToolInvalidArgs:
    """execute_tool with invalid args rejects before calling handler.

    Every test wires in a recording handler and asserts that it never ran.
    This catches a regression that executes the tool first and only then
    reports a validation error, even if an error string is still returned.
    """

    @staticmethod
    def _make_tracking_registry() -> tuple[dict, list[dict]]:
        """Return ``(registry, calls)``; ``calls`` records each handler invocation.

        Uses the same schema as ``make_registry``: a required string
        ``query`` and an optional integer ``count``. Confirmation is disabled,
        so only schema validation can stop the handler from running.
        """
        from orchestrator.tools.registry import ToolDefinition

        calls: list[dict] = []

        async def tracking_handler(**kwargs: object) -> str:
            calls.append(kwargs)
            return "tool result"

        registry = {
            "test_tool": ToolDefinition(
                name="test_tool",
                description="A test tool",
                parameters={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "count": {"type": "integer"},
                    },
                    "required": ["query"],
                },
                requires_confirmation=False,
                handler=tracking_handler,
            )
        }
        return registry, calls

    async def test_missing_required_arg_returns_error(self):
        from orchestrator.tools.executor import execute_tool

        registry, handler_calls = self._make_tracking_registry()
        audit_logger = make_audit_logger()

        # 'query' is required but not provided
        tool_call = {
            "function": {
                "name": "test_tool",
                "arguments": '{"count": 5}',
            }
        }

        result = await execute_tool(
            tool_call=tool_call,
            registry=registry,
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=audit_logger,
        )

        assert "error" in result.lower() or "invalid" in result.lower() or "required" in result.lower()
        assert handler_calls == [], "Handler must not run when a required arg is missing"

    async def test_wrong_type_returns_error(self):
        from orchestrator.tools.executor import execute_tool

        registry, handler_calls = self._make_tracking_registry()
        audit_logger = make_audit_logger()

        # count should be integer but string provided
        tool_call = {
            "function": {
                "name": "test_tool",
                "arguments": '{"query": "hello", "count": "not_a_number"}',
            }
        }

        result = await execute_tool(
            tool_call=tool_call,
            registry=registry,
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=audit_logger,
        )

        assert "error" in result.lower() or "invalid" in result.lower()
        assert handler_calls == [], "Handler must not run when an arg has the wrong type"

    async def test_audit_logger_called_with_error_on_invalid_args(self):
        from orchestrator.tools.executor import execute_tool

        registry, handler_calls = self._make_tracking_registry()
        audit_logger = make_audit_logger()

        tool_call = {
            "function": {
                "name": "test_tool",
                "arguments": '{}',  # Missing required 'query'
            }
        }

        await execute_tool(
            tool_call=tool_call,
            registry=registry,
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=audit_logger,
        )

        # The failed validation is still audited, and the audit record carries the error.
        audit_logger.log_tool_call.assert_called_once()
        call_kwargs = audit_logger.log_tool_call.call_args[1]
        assert call_kwargs["error"] is not None
        assert handler_calls == [], "Handler must not run when validation fails"
|
|
|
|
|
|
class TestExecuteToolUnknownTool:
    """A tool name missing from the registry yields an error string rather than a tool result."""

    async def test_unknown_tool_returns_error(self):
        from orchestrator.tools.executor import execute_tool

        message = await execute_tool(
            tool_call={"function": {"name": "nonexistent_tool", "arguments": "{}"}},
            registry=make_registry(),
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=make_audit_logger(),
        )

        # Accept any of the common phrasings for a missing tool.
        lowered = message.lower()
        assert any(marker in lowered for marker in ("unknown", "not found", "error"))
|
|
|
|
|
|
class TestExecuteToolConfirmation:
    """Tools flagged requires_confirmation=True are gated behind a confirmation prompt."""

    async def test_confirmation_required_returns_confirmation_message(self):
        from orchestrator.tools.executor import execute_tool

        reply = await execute_tool(
            tool_call={"function": {"name": "test_tool", "arguments": '{"query": "confirm this"}'}},
            registry=make_registry(requires_confirmation=True),
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=make_audit_logger(),
        )

        # The stub handler would answer "tool result"; the gate must answer with a prompt instead.
        text = reply.lower()
        assert "confirm" in text or "permission" in text or "approval" in text

    async def test_confirmation_required_does_not_call_handler(self):
        """Handler must NOT be called when requires_confirmation=True."""
        from orchestrator.tools.executor import execute_tool
        from orchestrator.tools.registry import ToolDefinition

        invocations: list[dict] = []

        async def tracking_handler(**kwargs: object) -> str:
            invocations.append(kwargs)
            return "should not be called"

        query_only_schema = {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        }
        gated_tool = ToolDefinition(
            name="test_tool",
            description="A test tool",
            parameters=query_only_schema,
            requires_confirmation=True,
            handler=tracking_handler,
        )

        await execute_tool(
            tool_call={"function": {"name": "test_tool", "arguments": '{"query": "test"}'}},
            registry={"test_tool": gated_tool},
            tenant_id=uuid.uuid4(),
            agent_id=uuid.uuid4(),
            audit_logger=make_audit_logger(),
        )

        assert not invocations, "Handler should not be called when requires_confirmation=True"
|