test(02-02): add failing tool registry and executor unit tests
- Tests for BUILTIN_TOOLS (4 tools present, correct fields, confirmation flags) - Tests for get_tools_for_agent filtering and to_litellm_format conversion - Tests for execute_tool: valid args, invalid args, unknown tool, confirmation flow - Tests for audit logger called on every invocation
This commit is contained in:
285
tests/unit/test_tool_executor.py
Normal file
285
tests/unit/test_tool_executor.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Unit tests for the tool executor.
|
||||
|
||||
Tests:
|
||||
- Valid args pass schema validation and handler is called
|
||||
- Invalid args are rejected before the handler is called
|
||||
- Unknown tool name raises ValueError
|
||||
- requires_confirmation=True returns confirmation message without executing
|
||||
- audit_logger.log_tool_call is called on every invocation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_registry(requires_confirmation: bool = False, handler_return: str = "tool result") -> dict:
|
||||
"""Build a minimal tool registry for testing."""
|
||||
from orchestrator.tools.registry import ToolDefinition
|
||||
|
||||
async def mock_handler(**kwargs: object) -> str:
|
||||
return handler_return
|
||||
|
||||
return {
|
||||
"test_tool": ToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
requires_confirmation=requires_confirmation,
|
||||
handler=mock_handler,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def make_audit_logger() -> MagicMock:
|
||||
"""Return a mock AuditLogger."""
|
||||
mock = MagicMock()
|
||||
mock.log_tool_call = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecuteToolValidArgs:
|
||||
"""execute_tool with valid args calls the handler and returns result."""
|
||||
|
||||
async def test_valid_args_returns_result(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry(handler_return="search results here")
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"query": "hello world"}',
|
||||
}
|
||||
}
|
||||
|
||||
result = await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
assert result == "search results here"
|
||||
|
||||
async def test_audit_logger_called_on_success(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry()
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"query": "test"}',
|
||||
}
|
||||
}
|
||||
|
||||
await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
audit_logger.log_tool_call.assert_called_once()
|
||||
call_kwargs = audit_logger.log_tool_call.call_args[1]
|
||||
assert call_kwargs["tool_name"] == "test_tool"
|
||||
assert call_kwargs["error"] is None
|
||||
|
||||
|
||||
class TestExecuteToolInvalidArgs:
|
||||
"""execute_tool with invalid args rejects before calling handler."""
|
||||
|
||||
async def test_missing_required_arg_returns_error(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry()
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
# 'query' is required but not provided
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"count": 5}',
|
||||
}
|
||||
}
|
||||
|
||||
result = await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
assert "error" in result.lower() or "invalid" in result.lower() or "required" in result.lower()
|
||||
|
||||
async def test_wrong_type_returns_error(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry()
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
# count should be integer but string provided
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"query": "hello", "count": "not_a_number"}',
|
||||
}
|
||||
}
|
||||
|
||||
result = await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
assert "error" in result.lower() or "invalid" in result.lower()
|
||||
|
||||
async def test_audit_logger_called_with_error_on_invalid_args(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry()
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{}', # Missing required 'query'
|
||||
}
|
||||
}
|
||||
|
||||
await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
audit_logger.log_tool_call.assert_called_once()
|
||||
call_kwargs = audit_logger.log_tool_call.call_args[1]
|
||||
assert call_kwargs["error"] is not None
|
||||
|
||||
|
||||
class TestExecuteToolUnknownTool:
|
||||
"""execute_tool with unknown tool name returns an error string."""
|
||||
|
||||
async def test_unknown_tool_returns_error(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry()
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "nonexistent_tool",
|
||||
"arguments": "{}",
|
||||
}
|
||||
}
|
||||
|
||||
result = await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
assert "unknown" in result.lower() or "not found" in result.lower() or "error" in result.lower()
|
||||
|
||||
|
||||
class TestExecuteToolConfirmation:
|
||||
"""execute_tool with requires_confirmation=True returns confirmation message."""
|
||||
|
||||
async def test_confirmation_required_returns_confirmation_message(self):
|
||||
from orchestrator.tools.executor import execute_tool
|
||||
|
||||
registry = make_registry(requires_confirmation=True)
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"query": "confirm this"}',
|
||||
}
|
||||
}
|
||||
|
||||
result = await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
# Should return a confirmation message, not the tool result
|
||||
assert "confirm" in result.lower() or "permission" in result.lower() or "approval" in result.lower()
|
||||
|
||||
async def test_confirmation_required_does_not_call_handler(self):
|
||||
"""Handler must NOT be called when requires_confirmation=True."""
|
||||
from orchestrator.tools.registry import ToolDefinition
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def tracking_handler(**kwargs: object) -> str:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return "should not be called"
|
||||
|
||||
registry = {
|
||||
"test_tool": ToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=tracking_handler,
|
||||
)
|
||||
}
|
||||
audit_logger = make_audit_logger()
|
||||
|
||||
tool_call = {
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"query": "test"}',
|
||||
}
|
||||
}
|
||||
|
||||
await execute_tool(
|
||||
tool_call=tool_call,
|
||||
registry=registry,
|
||||
tenant_id=uuid.uuid4(),
|
||||
agent_id=uuid.uuid4(),
|
||||
audit_logger=audit_logger,
|
||||
)
|
||||
|
||||
assert not handler_called, "Handler should not be called when requires_confirmation=True"
|
||||
Reference in New Issue
Block a user