"""
Unit tests for the tool executor.

Tests:
- Valid args pass schema validation and the handler is called
- Invalid args are rejected before the handler is called
- An unknown tool name yields an error string (it is returned, not raised)
- requires_confirmation=True returns a confirmation message without executing
- audit_logger.log_tool_call is called on every invocation

NOTE(review): the test methods are plain ``async def`` with no explicit
``@pytest.mark.asyncio`` markers, so this module presumably relies on an async
pytest plugin running in auto mode (e.g. pytest-asyncio with
``asyncio_mode = auto``) -- confirm against the project's pytest config.
"""

from __future__ import annotations

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

# Return value for handlers in tests that must prove the handler never ran:
# if this string appears in the executor's result, the handler was called.
# It deliberately contains none of the words ("error", "invalid", "required")
# that the error-path assertions look for.
HANDLER_CALLED_SENTINEL = "HANDLER_WAS_CALLED"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_registry(requires_confirmation: bool = False, handler_return: str = "tool result") -> dict:
    """Build a minimal single-tool registry for testing.

    The registry maps ``"test_tool"`` to a ``ToolDefinition`` whose JSON schema
    requires a string ``query`` and accepts an optional integer ``count``.

    Args:
        requires_confirmation: Passed straight through to ``ToolDefinition``.
        handler_return: Value returned by the tool's async handler when called.

    Returns:
        A ``{tool_name: ToolDefinition}`` dict.
    """
    # Imported inside the function (like every project import in this file),
    # so an import failure surfaces per test rather than at module import.
    from orchestrator.tools.registry import ToolDefinition

    async def mock_handler(**kwargs: object) -> str:
        return handler_return

    return {
        "test_tool": ToolDefinition(
            name="test_tool",
            description="A test tool",
            parameters={
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "count": {"type": "integer"},
                },
                "required": ["query"],
            },
            requires_confirmation=requires_confirmation,
            handler=mock_handler,
        )
    }


def make_audit_logger() -> MagicMock:
    """Return a mock AuditLogger whose ``log_tool_call`` is awaitable."""
    mock = MagicMock()
    mock.log_tool_call = AsyncMock()
    return mock


def make_tool_call(arguments: str, name: str = "test_tool") -> dict:
    """Build the tool-call payload shape that these tests pass to ``execute_tool``.

    Args:
        arguments: JSON-encoded arguments object, passed through as a string.
        name: Name of the tool being invoked.
    """
    return {"function": {"name": name, "arguments": arguments}}


async def run_execute_tool(tool_call: dict, registry: dict, audit_logger: MagicMock) -> str:
    """Invoke ``execute_tool`` with fresh random tenant/agent IDs and return its result.

    Keeping the import here means no individual test can forget it.
    """
    from orchestrator.tools.executor import execute_tool

    return await execute_tool(
        tool_call=tool_call,
        registry=registry,
        tenant_id=uuid.uuid4(),
        agent_id=uuid.uuid4(),
        audit_logger=audit_logger,
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestExecuteToolValidArgs:
    """execute_tool with valid args calls the handler and returns its result."""

    async def test_valid_args_returns_result(self):
        registry = make_registry(handler_return="search results here")
        audit_logger = make_audit_logger()

        result = await run_execute_tool(
            make_tool_call('{"query": "hello world"}'), registry, audit_logger
        )

        assert result == "search results here"

    async def test_audit_logger_called_on_success(self):
        audit_logger = make_audit_logger()

        await run_execute_tool(make_tool_call('{"query": "test"}'), make_registry(), audit_logger)

        audit_logger.log_tool_call.assert_called_once()
        # These assertions require the executor to pass keyword arguments.
        call_kwargs = audit_logger.log_tool_call.call_args.kwargs
        assert call_kwargs["tool_name"] == "test_tool"
        assert call_kwargs["error"] is None


class TestExecuteToolInvalidArgs:
    """execute_tool with invalid args rejects them before calling the handler."""

    async def test_missing_required_arg_returns_error(self):
        registry = make_registry(handler_return=HANDLER_CALLED_SENTINEL)

        # 'query' is required but not provided
        result = await run_execute_tool(
            make_tool_call('{"count": 5}'), registry, make_audit_logger()
        )

        lowered = result.lower()
        assert "error" in lowered or "invalid" in lowered or "required" in lowered
        assert HANDLER_CALLED_SENTINEL not in result, "Handler must not run on invalid args"

    async def test_wrong_type_returns_error(self):
        registry = make_registry(handler_return=HANDLER_CALLED_SENTINEL)

        # count should be an integer but a string is provided
        result = await run_execute_tool(
            make_tool_call('{"query": "hello", "count": "not_a_number"}'),
            registry,
            make_audit_logger(),
        )

        lowered = result.lower()
        assert "error" in lowered or "invalid" in lowered
        assert HANDLER_CALLED_SENTINEL not in result, "Handler must not run on invalid args"

    async def test_audit_logger_called_with_error_on_invalid_args(self):
        audit_logger = make_audit_logger()

        # Missing required 'query'
        await run_execute_tool(make_tool_call("{}"), make_registry(), audit_logger)

        audit_logger.log_tool_call.assert_called_once()
        call_kwargs = audit_logger.log_tool_call.call_args.kwargs
        assert call_kwargs["error"] is not None


class TestExecuteToolUnknownTool:
    """execute_tool with an unknown tool name returns an error string."""

    async def test_unknown_tool_returns_error(self):
        result = await run_execute_tool(
            make_tool_call("{}", name="nonexistent_tool"),
            make_registry(),
            make_audit_logger(),
        )

        lowered = result.lower()
        assert "unknown" in lowered or "not found" in lowered or "error" in lowered


class TestExecuteToolConfirmation:
    """execute_tool with requires_confirmation=True returns a confirmation message."""

    async def test_confirmation_required_returns_confirmation_message(self):
        registry = make_registry(requires_confirmation=True)

        result = await run_execute_tool(
            make_tool_call('{"query": "confirm this"}'), registry, make_audit_logger()
        )

        # Should return a confirmation message, not the tool result
        lowered = result.lower()
        assert "confirm" in lowered or "permission" in lowered or "approval" in lowered

    async def test_confirmation_required_does_not_call_handler(self):
        """Handler must NOT be called when requires_confirmation=True."""
        from orchestrator.tools.registry import ToolDefinition

        handler_called = False

        async def tracking_handler(**kwargs: object) -> str:
            nonlocal handler_called
            handler_called = True
            return "should not be called"

        registry = {
            "test_tool": ToolDefinition(
                name="test_tool",
                description="A test tool",
                parameters={
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                    "required": ["query"],
                },
                requires_confirmation=True,
                handler=tracking_handler,
            )
        }

        # Previously this test called execute_tool without importing it
        # (NameError); the helper owns the import.
        await run_execute_tool(make_tool_call('{"query": "test"}'), registry, make_audit_logger())

        assert not handler_called, "Handler should not be called when requires_confirmation=True"