diff --git a/tests/unit/test_tool_executor.py b/tests/unit/test_tool_executor.py
new file mode 100644
index 0000000..e5d7736
--- /dev/null
+++ b/tests/unit/test_tool_executor.py
@@ -0,0 +1,293 @@
+"""
+Unit tests for the tool executor.
+
+Tests:
+  - Valid args pass schema validation and handler is called
+  - Invalid args are rejected before the handler is called
+  - Unknown tool name returns an error string
+  - requires_confirmation=True returns confirmation message without executing
+  - audit_logger.log_tool_call is called on every invocation
+"""
+
+from __future__ import annotations
+
+import uuid
+from unittest.mock import AsyncMock, MagicMock, patch
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def make_registry(requires_confirmation: bool = False, handler_return: str = "tool result") -> dict:
+    """Build a minimal tool registry for testing."""
+    from orchestrator.tools.registry import ToolDefinition
+
+    async def mock_handler(**kwargs: object) -> str:
+        return handler_return
+
+    return {
+        "test_tool": ToolDefinition(
+            name="test_tool",
+            description="A test tool",
+            parameters={
+                "type": "object",
+                "properties": {
+                    "query": {"type": "string"},
+                    "count": {"type": "integer"},
+                },
+                "required": ["query"],
+            },
+            requires_confirmation=requires_confirmation,
+            handler=mock_handler,
+        )
+    }
+
+
+def make_audit_logger() -> MagicMock:
+    """Return a mock AuditLogger."""
+    mock = MagicMock()
+    mock.log_tool_call = AsyncMock()
+    return mock
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+class TestExecuteToolValidArgs:
+    """execute_tool with valid args calls the handler and returns result."""
+
+    async def test_valid_args_returns_result(self):
+        """Arguments matching the schema yield the handler's return value."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry(handler_return="search results here")
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"query": "hello world"}',
+            }
+        }
+
+        result = await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        assert result == "search results here"
+
+    async def test_audit_logger_called_on_success(self):
+        """A successful call is logged exactly once with the tool name and no error."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry()
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"query": "test"}',
+            }
+        }
+
+        await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        audit_logger.log_tool_call.assert_called_once()
+        call_kwargs = audit_logger.log_tool_call.call_args[1]
+        assert call_kwargs["tool_name"] == "test_tool"
+        assert call_kwargs["error"] is None
+
+
+class TestExecuteToolInvalidArgs:
+    """execute_tool with invalid args rejects before calling handler."""
+
+    async def test_missing_required_arg_returns_error(self):
+        """Omitting the required 'query' argument yields an error-style message."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry()
+        audit_logger = make_audit_logger()
+
+        # 'query' is required but not provided
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"count": 5}',
+            }
+        }
+
+        result = await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        assert "error" in result.lower() or "invalid" in result.lower() or "required" in result.lower()
+
+    async def test_wrong_type_returns_error(self):
+        """A string supplied for the integer 'count' yields an error-style message."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry()
+        audit_logger = make_audit_logger()
+
+        # count should be integer but string provided
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"query": "hello", "count": "not_a_number"}',
+            }
+        }
+
+        result = await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        assert "error" in result.lower() or "invalid" in result.lower()
+
+    async def test_audit_logger_called_with_error_on_invalid_args(self):
+        """Invalid arguments are still logged exactly once, with a non-None error."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry()
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{}',  # Missing required 'query'
+            }
+        }
+
+        await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        audit_logger.log_tool_call.assert_called_once()
+        call_kwargs = audit_logger.log_tool_call.call_args[1]
+        assert call_kwargs["error"] is not None
+
+
+class TestExecuteToolUnknownTool:
+    """execute_tool with unknown tool name returns an error string."""
+
+    async def test_unknown_tool_returns_error(self):
+        """A tool name absent from the registry yields an error-style message."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry()
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "nonexistent_tool",
+                "arguments": "{}",
+            }
+        }
+
+        result = await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        assert "unknown" in result.lower() or "not found" in result.lower() or "error" in result.lower()
+
+
+class TestExecuteToolConfirmation:
+    """execute_tool with requires_confirmation=True returns confirmation message."""
+
+    async def test_confirmation_required_returns_confirmation_message(self):
+        """A confirmation-gated tool returns a confirmation-style message."""
+        from orchestrator.tools.executor import execute_tool
+
+        registry = make_registry(requires_confirmation=True)
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"query": "confirm this"}',
+            }
+        }
+
+        result = await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        # Should return a confirmation message, not the tool result
+        assert "confirm" in result.lower() or "permission" in result.lower() or "approval" in result.lower()
+
+    async def test_confirmation_required_does_not_call_handler(self):
+        """Handler must NOT be called when requires_confirmation=True."""
+        from orchestrator.tools.executor import execute_tool
+        from orchestrator.tools.registry import ToolDefinition
+
+        handler_called = False
+
+        async def tracking_handler(**kwargs: object) -> str:
+            nonlocal handler_called
+            handler_called = True
+            return "should not be called"
+
+        registry = {
+            "test_tool": ToolDefinition(
+                name="test_tool",
+                description="A test tool",
+                parameters={
+                    "type": "object",
+                    "properties": {"query": {"type": "string"}},
+                    "required": ["query"],
+                },
+                requires_confirmation=True,
+                handler=tracking_handler,
+            )
+        }
+        audit_logger = make_audit_logger()
+
+        tool_call = {
+            "function": {
+                "name": "test_tool",
+                "arguments": '{"query": "test"}',
+            }
+        }
+
+        await execute_tool(
+            tool_call=tool_call,
+            registry=registry,
+            tenant_id=uuid.uuid4(),
+            agent_id=uuid.uuid4(),
+            audit_logger=audit_logger,
+        )
+
+        assert not handler_called, "Handler should not be called when requires_confirmation=True"
diff --git a/tests/unit/test_tool_registry.py b/tests/unit/test_tool_registry.py
new file mode 100644
index 0000000..46f8aed
--- /dev/null
+++ b/tests/unit/test_tool_registry.py
@@ -0,0 +1,161 @@
+"""
+Unit tests for the tool registry.
+
+Tests:
+  - BUILTIN_TOOLS contains all 4 expected tools
+  - get_tools_for_agent filters correctly based on agent.tool_assignments
+  - to_litellm_format produces valid OpenAI function-calling schema
+  - ToolDefinition model validation
+"""
+
+from __future__ import annotations
+
+import uuid
+from unittest.mock import MagicMock
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def make_agent(tool_list: list[str]) -> MagicMock:
+    """Create a mock Agent with tool_assignments set."""
+    agent = MagicMock()
+    agent.id = uuid.uuid4()
+    agent.tool_assignments = tool_list
+    return agent
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+class TestBuiltinTools:
+    """BUILTIN_TOOLS registry contains the correct tool definitions."""
+
+    def test_all_four_tools_present(self):
+        """The four expected built-in tool names are keys of BUILTIN_TOOLS."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        assert "web_search" in BUILTIN_TOOLS
+        assert "kb_search" in BUILTIN_TOOLS
+        assert "http_request" in BUILTIN_TOOLS
+        assert "calendar_lookup" in BUILTIN_TOOLS
+
+    def test_tool_definitions_have_required_fields(self):
+        """Every built-in has a matching name, a description, dict parameters and a handler."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        for name, tool in BUILTIN_TOOLS.items():
+            assert tool.name == name, f"Tool name mismatch for {name}"
+            assert tool.description, f"Tool {name} missing description"
+            assert isinstance(tool.parameters, dict), f"Tool {name} parameters must be dict"
+            assert "type" in tool.parameters, f"Tool {name} parameters missing 'type' key"
+            assert tool.handler is not None, f"Tool {name} has no handler"
+
+    def test_http_request_requires_confirmation(self):
+        """http_request is flagged requires_confirmation=True."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        assert BUILTIN_TOOLS["http_request"].requires_confirmation is True
+
+    def test_web_search_no_confirmation(self):
+        """web_search is flagged requires_confirmation=False."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        assert BUILTIN_TOOLS["web_search"].requires_confirmation is False
+
+    def test_kb_search_no_confirmation(self):
+        """kb_search is flagged requires_confirmation=False."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        assert BUILTIN_TOOLS["kb_search"].requires_confirmation is False
+
+    def test_calendar_lookup_no_confirmation(self):
+        """calendar_lookup is flagged requires_confirmation=False."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS
+
+        assert BUILTIN_TOOLS["calendar_lookup"].requires_confirmation is False
+
+
+class TestGetToolsForAgent:
+    """get_tools_for_agent filters BUILTIN_TOOLS by agent's tool_assignments list."""
+
+    def test_filters_to_assigned_tools(self):
+        """Only the tools named in agent.tool_assignments are returned."""
+        from orchestrator.tools.registry import get_tools_for_agent
+
+        agent = make_agent(["web_search", "kb_search"])
+        result = get_tools_for_agent(agent)
+
+        assert set(result.keys()) == {"web_search", "kb_search"}
+
+    def test_empty_tool_list_returns_empty(self):
+        """An agent with no assignments gets an empty dict."""
+        from orchestrator.tools.registry import get_tools_for_agent
+
+        agent = make_agent([])
+        result = get_tools_for_agent(agent)
+
+        assert result == {}
+
+    def test_unknown_tools_ignored_silently(self):
+        """Tools in agent.tool_assignments that don't exist in BUILTIN_TOOLS are skipped."""
+        from orchestrator.tools.registry import get_tools_for_agent
+
+        agent = make_agent(["web_search", "nonexistent_tool"])
+        result = get_tools_for_agent(agent)
+
+        assert "web_search" in result
+        assert "nonexistent_tool" not in result
+
+    def test_all_tools_accessible(self):
+        """Assigning every built-in name returns every built-in tool."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS, get_tools_for_agent
+
+        agent = make_agent(list(BUILTIN_TOOLS.keys()))
+        result = get_tools_for_agent(agent)
+
+        assert set(result.keys()) == set(BUILTIN_TOOLS.keys())
+
+
+class TestToLitellmFormat:
+    """to_litellm_format converts tool definitions to OpenAI function-calling schema."""
+
+    def test_returns_list_of_dicts(self):
+        """A single-tool mapping converts to a one-element list."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
+
+        result = to_litellm_format({"web_search": BUILTIN_TOOLS["web_search"]})
+
+        assert isinstance(result, list)
+        assert len(result) == 1
+
+    def test_openai_schema_structure(self):
+        """Each entry must have type='function' and a nested function object."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
+
+        result = to_litellm_format({"web_search": BUILTIN_TOOLS["web_search"]})
+        entry = result[0]
+
+        assert entry["type"] == "function"
+        assert "function" in entry
+        func = entry["function"]
+        assert func["name"] == "web_search"
+        assert "description" in func
+        assert "parameters" in func
+
+    def test_empty_tools_returns_empty_list(self):
+        """An empty mapping converts to an empty list."""
+        from orchestrator.tools.registry import to_litellm_format
+
+        assert to_litellm_format({}) == []
+
+    def test_multiple_tools_converted(self):
+        """Converting all of BUILTIN_TOOLS yields four entries."""
+        from orchestrator.tools.registry import BUILTIN_TOOLS, to_litellm_format
+
+        result = to_litellm_format(BUILTIN_TOOLS)
+        assert len(result) == 4