test(02-02): add failing tool registry and executor unit tests
- Tests for BUILTIN_TOOLS (4 tools present, correct fields, confirmation flags) - Tests for get_tools_for_agent filtering and to_litellm_format conversion - Tests for execute_tool: valid args, invalid args, unknown tool, confirmation flow - Tests for audit logger called on every invocation
This commit is contained in:
285
tests/unit/test_tool_executor.py
Normal file
285
tests/unit/test_tool_executor.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the tool executor.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- Valid args pass schema validation and handler is called
|
||||||
|
- Invalid args are rejected before the handler is called
|
||||||
|
- Unknown tool name raises ValueError
|
||||||
|
- requires_confirmation=True returns confirmation message without executing
|
||||||
|
- audit_logger.log_tool_call is called on every invocation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def make_registry(requires_confirmation: bool = False, handler_return: str = "tool result") -> dict:
|
||||||
|
"""Build a minimal tool registry for testing."""
|
||||||
|
from orchestrator.tools.registry import ToolDefinition
|
||||||
|
|
||||||
|
async def mock_handler(**kwargs: object) -> str:
|
||||||
|
return handler_return
|
||||||
|
|
||||||
|
return {
|
||||||
|
"test_tool": ToolDefinition(
|
||||||
|
name="test_tool",
|
||||||
|
description="A test tool",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"count": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
requires_confirmation=requires_confirmation,
|
||||||
|
handler=mock_handler,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_audit_logger() -> MagicMock:
|
||||||
|
"""Return a mock AuditLogger."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.log_tool_call = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteToolValidArgs:
|
||||||
|
"""execute_tool with valid args calls the handler and returns result."""
|
||||||
|
|
||||||
|
async def test_valid_args_returns_result(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry(handler_return="search results here")
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"query": "hello world"}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "search results here"
|
||||||
|
|
||||||
|
async def test_audit_logger_called_on_success(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry()
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"query": "test"}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
audit_logger.log_tool_call.assert_called_once()
|
||||||
|
call_kwargs = audit_logger.log_tool_call.call_args[1]
|
||||||
|
assert call_kwargs["tool_name"] == "test_tool"
|
||||||
|
assert call_kwargs["error"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteToolInvalidArgs:
|
||||||
|
"""execute_tool with invalid args rejects before calling handler."""
|
||||||
|
|
||||||
|
async def test_missing_required_arg_returns_error(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry()
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
# 'query' is required but not provided
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"count": 5}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "error" in result.lower() or "invalid" in result.lower() or "required" in result.lower()
|
||||||
|
|
||||||
|
async def test_wrong_type_returns_error(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry()
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
# count should be integer but string provided
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"query": "hello", "count": "not_a_number"}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "error" in result.lower() or "invalid" in result.lower()
|
||||||
|
|
||||||
|
async def test_audit_logger_called_with_error_on_invalid_args(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry()
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{}', # Missing required 'query'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
audit_logger.log_tool_call.assert_called_once()
|
||||||
|
call_kwargs = audit_logger.log_tool_call.call_args[1]
|
||||||
|
assert call_kwargs["error"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteToolUnknownTool:
|
||||||
|
"""execute_tool with unknown tool name returns an error string."""
|
||||||
|
|
||||||
|
async def test_unknown_tool_returns_error(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry()
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "nonexistent_tool",
|
||||||
|
"arguments": "{}",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "unknown" in result.lower() or "not found" in result.lower() or "error" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteToolConfirmation:
|
||||||
|
"""execute_tool with requires_confirmation=True returns confirmation message."""
|
||||||
|
|
||||||
|
async def test_confirmation_required_returns_confirmation_message(self):
|
||||||
|
from orchestrator.tools.executor import execute_tool
|
||||||
|
|
||||||
|
registry = make_registry(requires_confirmation=True)
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"query": "confirm this"}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return a confirmation message, not the tool result
|
||||||
|
assert "confirm" in result.lower() or "permission" in result.lower() or "approval" in result.lower()
|
||||||
|
|
||||||
|
async def test_confirmation_required_does_not_call_handler(self):
|
||||||
|
"""Handler must NOT be called when requires_confirmation=True."""
|
||||||
|
from orchestrator.tools.registry import ToolDefinition
|
||||||
|
|
||||||
|
handler_called = False
|
||||||
|
|
||||||
|
async def tracking_handler(**kwargs: object) -> str:
|
||||||
|
nonlocal handler_called
|
||||||
|
handler_called = True
|
||||||
|
return "should not be called"
|
||||||
|
|
||||||
|
registry = {
|
||||||
|
"test_tool": ToolDefinition(
|
||||||
|
name="test_tool",
|
||||||
|
description="A test tool",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
requires_confirmation=True,
|
||||||
|
handler=tracking_handler,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
audit_logger = make_audit_logger()
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"query": "test"}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await execute_tool(
|
||||||
|
tool_call=tool_call,
|
||||||
|
registry=registry,
|
||||||
|
tenant_id=uuid.uuid4(),
|
||||||
|
agent_id=uuid.uuid4(),
|
||||||
|
audit_logger=audit_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not handler_called, "Handler should not be called when requires_confirmation=True"
|
||||||
149
tests/unit/test_tool_registry.py
Normal file
149
tests/unit/test_tool_registry.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the tool registry.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- BUILTIN_TOOLS contains all 4 expected tools
|
||||||
|
- get_tools_for_agent filters correctly based on agent.tool_assignments
|
||||||
|
- to_litellm_format produces valid OpenAI function-calling schema
|
||||||
|
- ToolDefinition model validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def make_agent(tool_list: list[str]) -> MagicMock:
|
||||||
|
"""Create a mock Agent with tool_assignments set."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.id = uuid.uuid4()
|
||||||
|
agent.tool_assignments = tool_list
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuiltinTools:
|
||||||
|
"""BUILTIN_TOOLS registry contains the correct tool definitions."""
|
||||||
|
|
||||||
|
def test_all_four_tools_present(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
assert "web_search" in BUILTIN_TOOLS
|
||||||
|
assert "kb_search" in BUILTIN_TOOLS
|
||||||
|
assert "http_request" in BUILTIN_TOOLS
|
||||||
|
assert "calendar_lookup" in BUILTIN_TOOLS
|
||||||
|
|
||||||
|
def test_tool_definitions_have_required_fields(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
for name, tool in BUILTIN_TOOLS.items():
|
||||||
|
assert tool.name == name, f"Tool name mismatch for {name}"
|
||||||
|
assert tool.description, f"Tool {name} missing description"
|
||||||
|
assert isinstance(tool.parameters, dict), f"Tool {name} parameters must be dict"
|
||||||
|
assert "type" in tool.parameters, f"Tool {name} parameters missing 'type' key"
|
||||||
|
assert tool.handler is not None, f"Tool {name} has no handler"
|
||||||
|
|
||||||
|
def test_http_request_requires_confirmation(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
assert BUILTIN_TOOLS["http_request"].requires_confirmation is True
|
||||||
|
|
||||||
|
def test_web_search_no_confirmation(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
assert BUILTIN_TOOLS["web_search"].requires_confirmation is False
|
||||||
|
|
||||||
|
def test_kb_search_no_confirmation(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
assert BUILTIN_TOOLS["kb_search"].requires_confirmation is False
|
||||||
|
|
||||||
|
def test_calendar_lookup_no_confirmation(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
assert BUILTIN_TOOLS["calendar_lookup"].requires_confirmation is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetToolsForAgent:
|
||||||
|
"""get_tools_for_agent filters BUILTIN_TOOLS by agent's tool_assignments list."""
|
||||||
|
|
||||||
|
def test_filters_to_assigned_tools(self):
|
||||||
|
from orchestrator.tools.registry import get_tools_for_agent
|
||||||
|
|
||||||
|
agent = make_agent(["web_search", "kb_search"])
|
||||||
|
result = get_tools_for_agent(agent)
|
||||||
|
|
||||||
|
assert set(result.keys()) == {"web_search", "kb_search"}
|
||||||
|
|
||||||
|
def test_empty_tool_list_returns_empty(self):
|
||||||
|
from orchestrator.tools.registry import get_tools_for_agent
|
||||||
|
|
||||||
|
agent = make_agent([])
|
||||||
|
result = get_tools_for_agent(agent)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_unknown_tools_ignored_silently(self):
|
||||||
|
"""Tools in agent.tool_assignments that don't exist in BUILTIN_TOOLS are skipped."""
|
||||||
|
from orchestrator.tools.registry import get_tools_for_agent
|
||||||
|
|
||||||
|
agent = make_agent(["web_search", "nonexistent_tool"])
|
||||||
|
result = get_tools_for_agent(agent)
|
||||||
|
|
||||||
|
assert "web_search" in result
|
||||||
|
assert "nonexistent_tool" not in result
|
||||||
|
|
||||||
|
def test_all_tools_accessible(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS, get_tools_for_agent
|
||||||
|
|
||||||
|
agent = make_agent(list(BUILTIN_TOOLS.keys()))
|
||||||
|
result = get_tools_for_agent(agent)
|
||||||
|
|
||||||
|
assert set(result.keys()) == set(BUILTIN_TOOLS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class TestToLitellmFormat:
|
||||||
|
"""to_litellm_format converts tool definitions to OpenAI function-calling schema."""
|
||||||
|
|
||||||
|
def test_returns_list_of_dicts(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
|
||||||
|
|
||||||
|
result = to_litellm_format({"web_search": BUILTIN_TOOLS["web_search"]})
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_openai_schema_structure(self):
|
||||||
|
"""Each entry must have type='function' and a nested function object."""
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
|
||||||
|
|
||||||
|
result = to_litellm_format({"web_search": BUILTIN_TOOLS["web_search"]})
|
||||||
|
entry = result[0]
|
||||||
|
|
||||||
|
assert entry["type"] == "function"
|
||||||
|
assert "function" in entry
|
||||||
|
func = entry["function"]
|
||||||
|
assert func["name"] == "web_search"
|
||||||
|
assert "description" in func
|
||||||
|
assert "parameters" in func
|
||||||
|
|
||||||
|
def test_empty_tools_returns_empty_list(self):
|
||||||
|
from orchestrator.tools.registry import to_litellm_format
|
||||||
|
|
||||||
|
assert to_litellm_format({}) == []
|
||||||
|
|
||||||
|
def test_multiple_tools_converted(self):
|
||||||
|
from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
|
||||||
|
|
||||||
|
result = to_litellm_format(BUILTIN_TOOLS)
|
||||||
|
assert len(result) == 4
|
||||||
Reference in New Issue
Block a user