feat(02-02): tool registry, executor, and 4 built-in tools
- ToolDefinition Pydantic model with JSON Schema parameters + handler - BUILTIN_TOOLS: web_search, kb_search, http_request, calendar_lookup - http_request requires_confirmation=True (outbound side effects) - get_tools_for_agent filters by agent.tool_assignments - to_litellm_format converts to OpenAI function-calling schema - execute_tool: jsonschema validation before handler call - execute_tool: confirmation gate for requires_confirmation=True - execute_tool: audit logging on every invocation (success + failure) - web_search: Brave Search API with BRAVE_API_KEY env var - kb_search: pgvector cosine similarity with HNSW index - http_request: 30s timeout, 1MB cap, GET/POST/PUT/DELETE only - calendar_lookup: Google Calendar events.list read-only - jsonschema dependency added to orchestrator pyproject.toml - [Rule 1 - Bug] Added missing execute_tool import in test
This commit is contained in:
1
packages/orchestrator/orchestrator/tools/__init__.py
Normal file
1
packages/orchestrator/orchestrator/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tool framework for the Konstruct Agent Orchestrator."""
|
||||
@@ -0,0 +1 @@
|
||||
"""Built-in tool handlers for the Konstruct Agent Orchestrator."""
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Built-in tool: calendar_lookup
|
||||
|
||||
Reads calendar events from Google Calendar for a given date.
|
||||
|
||||
Authentication options (in priority order):
|
||||
1. GOOGLE_SERVICE_ACCOUNT_KEY env var — JSON key for service account impersonation
|
||||
2. Per-tenant OAuth (future: Phase 3 portal) — not yet implemented
|
||||
3. Graceful degradation: returns informative message if not configured
|
||||
|
||||
This tool is read-only (requires_confirmation=False in registry).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def calendar_lookup(
    date: str,
    calendar_id: str = "primary",
    **kwargs: object,
) -> str:
    """
    Look up calendar events for a specific date.

    The Google API client used by the fetch helper is synchronous, so the
    fetch is run on the default thread-pool executor rather than on the
    event loop itself.

    Args:
        date: Date in YYYY-MM-DD format.
        calendar_id: Google Calendar ID. Defaults to 'primary'.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        Formatted string listing events for the given date,
        an informative message if GOOGLE_SERVICE_ACCOUNT_KEY is not set, or a
        generic failure message if the lookup raises (the exception is logged,
        not propagated).
    """
    service_account_key_json = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY", "")
    if not service_account_key_json:
        return (
            "Calendar lookup is not configured. "
            "Set the GOOGLE_SERVICE_ACCOUNT_KEY environment variable to enable calendar access."
        )

    try:
        import asyncio

        # Inside a coroutine the asyncio docs recommend get_running_loop()
        # over get_event_loop(): it always returns the loop driving this task.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            _fetch_calendar_events_sync,
            service_account_key_json,
            calendar_id,
            date,
        )
    except Exception:
        logger.exception("Calendar lookup failed for date=%s calendar=%s", date, calendar_id)
        return f"Calendar lookup failed for {date}. Please try again."
|
||||
|
||||
|
||||
def _fetch_calendar_events_sync(
    service_account_key_json: str,
    calendar_id: str,
    date: str,
) -> str:
    """
    Synchronous implementation — runs in thread executor to avoid blocking event loop.

    Uses google-api-python-client with service account credentials and the
    read-only calendar scope.

    Args:
        service_account_key_json: Service account key as a JSON string.
        calendar_id: Google Calendar ID to query.
        date: Date in YYYY-MM-DD format. The queried window is that calendar
            day in UTC.
            NOTE(review): events are not shifted to the calendar's own
            timezone — confirm UTC day boundaries are what callers expect.

    Returns:
        A formatted event listing, "No events found on <date>.", or a
        human-readable error message. Library, key, credential, date and API
        errors are all reported as strings rather than raised.
    """
    from datetime import timedelta

    try:
        from google.oauth2 import service_account
        from googleapiclient.discovery import build
    except ImportError:
        return (
            "Google Calendar library not installed. "
            "Run: uv add google-api-python-client google-auth"
        )

    try:
        key_data = json.loads(service_account_key_json)
    except json.JSONDecodeError:
        return "Invalid GOOGLE_SERVICE_ACCOUNT_KEY: not valid JSON."

    try:
        credentials = service_account.Credentials.from_service_account_info(
            key_data,
            scopes=["https://www.googleapis.com/auth/calendar.readonly"],
        )
    except Exception as exc:
        return f"Failed to create Google credentials: {exc}"

    # Parse the date and create RFC3339 time boundaries for the day.
    try:
        day_start = datetime.strptime(date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        # timeMax is an exclusive bound, so the window must end at the next
        # midnight; ending at 23:59:59 would drop the final second of the day.
        day_end = day_start + timedelta(days=1)
    except (ValueError, OverflowError):
        # OverflowError: the day after datetime.max does not exist.
        return f"Invalid date format: {date!r}. Expected YYYY-MM-DD."

    # isoformat() always zero-pads the year, unlike strftime("%Y") on some platforms.
    time_min = day_start.isoformat().replace("+00:00", "Z")
    time_max = day_end.isoformat().replace("+00:00", "Z")

    try:
        service = build("calendar", "v3", credentials=credentials, cache_discovery=False)
        events_result = (
            service.events()
            .list(
                calendarId=calendar_id,
                timeMin=time_min,
                timeMax=time_max,
                singleEvents=True,
                orderBy="startTime",
            )
            .execute()
        )
    except Exception as exc:
        logger.warning("Google Calendar API error: %s", exc)
        return f"Calendar API error: {exc}"

    items = events_result.get("items", [])
    if not items:
        return f"No events found on {date}."

    lines = [f"Calendar events for {date}:\n"]
    for event in items:
        # Timed events carry "dateTime", all-day events carry "date"; tolerate
        # an event with no usable "start" instead of raising KeyError.
        start_info = event.get("start") or {}
        start = start_info.get("dateTime", start_info.get("date", "Unknown time"))
        summary = event.get("summary", "Untitled event")
        lines.append(f"- {start}: {summary}")

    return "\n".join(lines)
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Built-in tool: http_request
|
||||
|
||||
Makes an outbound HTTP request with safety constraints:
|
||||
- Timeout: 30 seconds
|
||||
- Response size cap: 1MB
|
||||
- Allowed methods: GET, POST, PUT, DELETE
|
||||
|
||||
This tool requires_confirmation=True (set in registry) because outbound HTTP
|
||||
requests can have side effects and should always be user-approved.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_METHODS = frozenset({"GET", "POST", "PUT", "DELETE"})
|
||||
_REQUEST_TIMEOUT = httpx.Timeout(timeout=30.0, connect=10.0)
|
||||
_MAX_RESPONSE_BYTES = 1 * 1024 * 1024 # 1MB
|
||||
|
||||
|
||||
async def http_request(
    url: str,
    method: str = "GET",
    body: str | None = None,
    **kwargs: object,
) -> str:
    """
    Make an outbound HTTP request.

    The response is streamed and reading stops as soon as the size cap is
    reached, so an oversized body is never fully buffered in memory.

    Args:
        url: The target URL.
        method: HTTP method — GET, POST, PUT, or DELETE (case-insensitive).
        body: Optional request body string (used for POST/PUT). Sent UTF-8
            encoded whenever it is not None.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        "HTTP <status> <method> <url>" followed by the response body decoded
        as UTF-8 (undecodable bytes replaced, capped at 1MB), or an error
        message for a disallowed method, invalid URL, non-success status,
        timeout, or connection failure.
    """
    method = method.upper()
    if method not in _ALLOWED_METHODS:
        return f"Invalid HTTP method '{method}'. Allowed: {', '.join(sorted(_ALLOWED_METHODS))}"

    request_kwargs: dict = {}
    if body is not None:
        request_kwargs["content"] = body.encode("utf-8")

    try:
        async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
            async with client.stream(method, url, **request_kwargs) as response:
                response.raise_for_status()

                # Cap response size at 1MB: stop pulling chunks once the
                # budget is spent instead of downloading everything first.
                chunks: list[bytes] = []
                received = 0
                truncated = False
                async for chunk in response.aiter_bytes():
                    remaining = _MAX_RESPONSE_BYTES - received
                    if len(chunk) > remaining:
                        chunks.append(chunk[:remaining])
                        truncated = True
                        break
                    chunks.append(chunk)
                    received += len(chunk)

                status_code = response.status_code

        # errors="replace" also covers a multi-byte character cut by the cap.
        content = b"".join(chunks).decode("utf-8", errors="replace")
        suffix = "\n[Response truncated at 1MB limit]" if truncated else ""

        return f"HTTP {status_code} {method} {url}\n\n{content}{suffix}"

    except httpx.HTTPStatusError as exc:
        logger.warning("http_request HTTP error: %s %s -> %s", method, url, exc.response.status_code)
        return f"HTTP {exc.response.status_code} error from {url}"
    except httpx.TimeoutException:
        logger.warning("http_request timeout: %s %s", method, url)
        return f"Request to {url} timed out after 30 seconds."
    except httpx.RequestError as exc:
        logger.warning("http_request connection error: %s", exc)
        return f"Failed to connect to {url}: {exc}"
    except httpx.InvalidURL as exc:
        # InvalidURL is not a RequestError subclass, so it needs its own clause.
        logger.warning("http_request invalid URL: %s", exc)
        return f"Invalid URL {url}: {exc}"
|
||||
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Built-in tool: kb_search
|
||||
|
||||
Searches the tenant's knowledge base using pgvector cosine similarity.
|
||||
|
||||
The query is embedded using the same all-MiniLM-L6-v2 model as conversation
|
||||
embeddings (vector(384)), then matched against kb_chunks via HNSW ANN search.
|
||||
|
||||
The tool accepts an optional tenant_id and agent_id via kwargs — these are
|
||||
injected by the executor using context that the LLM doesn't provide directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOP_K = 3
|
||||
|
||||
|
||||
async def kb_search(
    query: str,
    tenant_id: str | None = None,
    agent_id: str | None = None,
    **kwargs: object,
) -> str:
    """
    Find the knowledge-base chunks closest to a query.

    The query is embedded with embed_text() and compared to kb_chunks rows of
    the given tenant using the pgvector `<=>` operator; the closest _TOP_K
    rows are returned. The tenant context variable is set for the duration of
    the query and always reset afterwards.

    Args:
        query: The search query string.
        tenant_id: Tenant UUID as a string. Without it the search is refused.
        agent_id: Agent UUID as a string. Accepted but not used by the query.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        A numbered list of matching chunk contents, a "no content found"
        message when nothing matches, or an error message if tenant context
        is missing or anything in the lookup raises.
    """
    if not tenant_id:
        return "Knowledge base search unavailable: tenant context not set."

    try:
        # Imported here so a failure to import is reported like any other
        # search failure instead of breaking module import.
        import uuid as _uuid

        from shared.db import async_session_factory, engine
        from shared.rls import configure_rls_hook, current_tenant_id
        from orchestrator.memory.embedder import embed_text

        from sqlalchemy import text

        parsed_tenant = _uuid.UUID(tenant_id)
        vector = embed_text(query)

        # pgvector accepts a "[v1,v2,...]" literal that is CAST in the query.
        vector_literal = "[" + ",".join(map(str, vector)) + "]"

        statement = text("""
                        SELECT content, chunk_index,
                               embedding <=> CAST(:embedding AS vector) AS distance
                        FROM kb_chunks
                        WHERE tenant_id = :tenant_id
                        ORDER BY embedding <=> CAST(:embedding AS vector)
                        LIMIT :top_k
                    """)
        bind_params = {
            "embedding": vector_literal,
            "tenant_id": str(parsed_tenant),
            "top_k": _TOP_K,
        }

        configure_rls_hook(engine)
        context_token = current_tenant_id.set(parsed_tenant)
        try:
            async with async_session_factory() as session:
                outcome = await session.execute(statement, bind_params)
                rows = outcome.fetchall()
        finally:
            current_tenant_id.reset(context_token)

    except Exception:
        logger.exception("KB search failed for tenant=%s", tenant_id)
        return "Knowledge base search encountered an error. Please try again."

    if not rows:
        return f"No relevant knowledge base content found for: {query}"

    report = [f"Knowledge base results for: {query}\n"]
    report.extend(f"{rank}. {row.content}\n" for rank, row in enumerate(rows, start=1))
    return "\n".join(report)
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Built-in tool: web_search
|
||||
|
||||
Uses the Brave Search API to return top 3 search results.
|
||||
|
||||
Environment variable required:
|
||||
BRAVE_API_KEY — Brave Search API key. Set in .env.
|
||||
|
||||
If BRAVE_API_KEY is not set, returns an informative error message instead of
|
||||
raising an exception (graceful degradation for agents without search configured).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BRAVE_API_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
_BRAVE_TIMEOUT = httpx.Timeout(timeout=15.0, connect=5.0)
|
||||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
async def web_search(query: str, **kwargs: object) -> str:
    """
    Search the web using Brave Search API.

    Args:
        query: The search query string.
        **kwargs: Extra keyword arguments are accepted and ignored, matching
            the other built-in handlers.

    Returns:
        Formatted string with top 3 search results (title + URL + description),
        or an error message if the API key is missing, the API is unavailable,
        or the response cannot be parsed.
    """
    api_key = os.getenv("BRAVE_API_KEY", "")
    if not api_key:
        return (
            "Web search is not configured. "
            "Set the BRAVE_API_KEY environment variable to enable web search."
        )

    try:
        async with httpx.AsyncClient(timeout=_BRAVE_TIMEOUT) as client:
            response = await client.get(
                _BRAVE_API_URL,
                headers={
                    "Accept": "application/json",
                    "Accept-Encoding": "gzip",
                    "X-Subscription-Token": api_key,
                },
                params={"q": query, "count": _MAX_RESULTS},
            )
            response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        logger.warning("Brave Search API error: %s", exc.response.status_code)
        return f"Web search failed: HTTP {exc.response.status_code}"
    except httpx.RequestError as exc:
        logger.warning("Brave Search connection error: %s", exc)
        return "Web search is unavailable right now. Please try again later."

    try:
        data = response.json()
    except ValueError:
        # A 2xx response with a non-JSON body (e.g. a proxy error page).
        logger.warning("Brave Search returned a non-JSON response")
        return "Web search is unavailable right now. Please try again later."

    # Tolerate a missing/null "web" section or a non-list "results" value.
    web_section = data.get("web") if isinstance(data, dict) else None
    results = web_section.get("results") if isinstance(web_section, dict) else None
    if not isinstance(results, list):
        results = []

    if not results:
        return f"No results found for: {query}"

    lines = [f"Search results for: {query}\n"]
    for i, item in enumerate(results[:_MAX_RESULTS], start=1):
        title = item.get("title", "Untitled")
        url = item.get("url", "")
        description = item.get("description", "No description available.")
        lines.append(f"{i}. **{title}**\n   {url}\n   {description}\n")

    return "\n".join(lines)
|
||||
194
packages/orchestrator/orchestrator/tools/executor.py
Normal file
194
packages/orchestrator/orchestrator/tools/executor.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Tool executor — schema-validated tool dispatch with audit logging.
|
||||
|
||||
execute_tool() is the single entry point for all tool invocations. It:
|
||||
1. Looks up the tool in the registry
|
||||
2. Parses the tool call arguments (LLM output — always untrusted JSON)
|
||||
3. Validates args against the tool's JSON Schema (rejects invalid input)
|
||||
4. If requires_confirmation=True, returns a confirmation request without executing
|
||||
5. Calls the tool handler
|
||||
6. Logs the invocation (success or failure) to the audit trail
|
||||
7. Returns the result string
|
||||
|
||||
CRITICAL: Tool arguments come from LLM output, which is untrusted. Schema
|
||||
validation MUST happen before any handler is called. Invalid args are rejected
|
||||
with an error message — never silently coerced.
|
||||
|
||||
CRITICAL: All tool execution happens within the single Celery task context.
|
||||
Never dispatch separate Celery tasks for tool execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import jsonschema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from orchestrator.audit.logger import AuditLogger
|
||||
from orchestrator.tools.registry import ToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CONFIRMATION_MESSAGE_TEMPLATE = (
|
||||
"This action requires your approval before I proceed:\n\n"
|
||||
"**Tool:** {tool_name}\n"
|
||||
"**Arguments:** {args_summary}\n\n"
|
||||
"Please reply **yes** to confirm or **no** to cancel."
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool(
    tool_call: dict[str, Any],
    registry: dict[str, "ToolDefinition"],
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
    audit_logger: "AuditLogger",
) -> str:
    """
    Execute a tool call from LLM output.

    Args:
        tool_call: LLM tool call dict. Must contain:
            {"function": {"name": "...", "arguments": "{ JSON string }"}}
            An already-parsed dict is also accepted for "arguments".
        registry: Tool registry dict (name → ToolDefinition).
        tenant_id: Tenant UUID, used for audit logging and passed to handlers
            that accept a ``tenant_id`` keyword.
        agent_id: Agent UUID, used for audit logging and passed to handlers
            that accept an ``agent_id`` keyword.
        audit_logger: AuditLogger instance for recording the invocation.

    Returns:
        Tool result as a string. Returns an error message string (not an
        exception) for invalid args, unknown tools, or handler errors, and a
        confirmation prompt (without executing) for tools that require
        confirmation.
    """
    # Extract tool name and raw arguments
    function_data = tool_call.get("function") or {}
    tool_name: str = function_data.get("name") or ""
    raw_arguments: Any = function_data.get("arguments", "{}")

    # ------------------------------------------------------------------
    # 1. Look up tool in registry
    # ------------------------------------------------------------------
    if tool_name not in registry:
        error_msg = f"Unknown tool: '{tool_name}'. Available tools: {', '.join(registry.keys())}"
        logger.warning("execute_tool: %s", error_msg)
        return error_msg

    tool = registry[tool_name]

    # ------------------------------------------------------------------
    # 2. Parse JSON arguments (LLM output is untrusted)
    # ------------------------------------------------------------------
    try:
        if isinstance(raw_arguments, dict):
            args: dict[str, Any] = dict(raw_arguments)
        else:
            args = json.loads(raw_arguments)
        if not isinstance(args, dict):
            raise ValueError("Tool arguments must be a JSON object")
    except (TypeError, ValueError) as exc:
        # TypeError covers arguments=None or another non-string value;
        # json.JSONDecodeError is a ValueError subclass.
        error_msg = f"Invalid tool arguments for '{tool_name}': {exc}"
        logger.warning("execute_tool: %s", error_msg)
        await _log_failure(audit_logger, tool_name, {}, error_msg, tenant_id, agent_id)
        return f"Error: {error_msg}"

    # ------------------------------------------------------------------
    # 3. Schema validation — reject invalid args before calling handler
    # ------------------------------------------------------------------
    validation_error: str | None = _validate_args(args, tool.parameters)
    if validation_error:
        error_msg = f"Invalid arguments for '{tool_name}': {validation_error}"
        logger.warning("execute_tool schema validation failed: %s", error_msg)
        await _log_failure(audit_logger, tool_name, args, error_msg, tenant_id, agent_id)
        return f"Error: {error_msg}"

    # ------------------------------------------------------------------
    # 4. Confirmation gate — pause if tool requires user approval
    # ------------------------------------------------------------------
    # NOTE(review): nothing in this function lets an approved call proceed;
    # presumably the caller handles the user's reply — confirm.
    if tool.requires_confirmation:
        try:
            args_summary = json.dumps(args, ensure_ascii=False)
        except Exception:
            args_summary = repr(args)

        confirmation_msg = _CONFIRMATION_MESSAGE_TEMPLATE.format(
            tool_name=tool_name,
            args_summary=args_summary,
        )
        # Don't log a tool invocation — this is a confirmation request, not an execution
        return confirmation_msg

    # ------------------------------------------------------------------
    # 5. Execute the handler
    # ------------------------------------------------------------------
    call_kwargs = _build_handler_kwargs(tool.handler, args, tenant_id, agent_id)
    started = time.monotonic()
    try:
        result: str = await tool.handler(**call_kwargs)
    except Exception as exc:
        latency_ms = int((time.monotonic() - started) * 1000)
        error_msg = f"{type(exc).__name__}: {exc}"
        logger.exception("Tool handler '%s' raised exception: %s", tool_name, error_msg)
        await _audit_tool_call(
            audit_logger,
            tool_name=tool_name,
            args=args,
            result=None,
            error=error_msg,
            tenant_id=tenant_id,
            agent_id=agent_id,
            latency_ms=latency_ms,
        )
        return f"Tool '{tool_name}' encountered an error: {error_msg}"

    # The audit write happens outside the handler's try block: a failing
    # audit write must not turn a call that already succeeded (and may have
    # had side effects) into a reported tool error.
    latency_ms = int((time.monotonic() - started) * 1000)
    await _audit_tool_call(
        audit_logger,
        tool_name=tool_name,
        args=args,
        result=result,
        error=None,
        tenant_id=tenant_id,
        agent_id=agent_id,
        latency_ms=latency_ms,
    )
    return result


def _build_handler_kwargs(
    handler: Any,
    args: dict[str, Any],
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
) -> dict[str, Any]:
    """Merge validated LLM args with executor-owned context for a handler call.

    ``tenant_id`` and ``agent_id`` are reserved: any value the LLM supplied for
    them is dropped, and the executor's own values (as strings) are added only
    when the handler declares that parameter or accepts ``**kwargs``.
    """
    import inspect

    context = {"tenant_id": str(tenant_id), "agent_id": str(agent_id)}
    # Untrusted input must never choose the tenant/agent a tool runs as.
    call_kwargs = {key: value for key, value in args.items() if key not in context}

    try:
        params = inspect.signature(handler).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: pass the LLM args only.
        return call_kwargs

    takes_var_kwargs = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values()
    )
    for key, value in context.items():
        if takes_var_kwargs or key in params:
            call_kwargs[key] = value
    return call_kwargs


async def _audit_tool_call(
    audit_logger: "AuditLogger",
    *,
    tool_name: str,
    args: dict[str, Any],
    result: str | None,
    error: str | None,
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
    latency_ms: int,
) -> None:
    """Write one audit record, logging (not raising) if the write fails."""
    try:
        await audit_logger.log_tool_call(
            tool_name=tool_name,
            args=args,
            result=result,
            tenant_id=tenant_id,
            agent_id=agent_id,
            latency_ms=latency_ms,
            error=error,
        )
    except Exception:
        logger.exception("Failed to write audit log for tool call: %s", tool_name)
|
||||
|
||||
|
||||
def _validate_args(args: dict[str, Any], schema: dict[str, Any]) -> str | None:
    """
    Check tool arguments against the tool's JSON Schema.

    Args:
        args: Parsed argument object to validate.
        schema: JSON Schema the arguments must satisfy.

    Returns:
        None when the arguments satisfy the schema; otherwise a message.
        A validation failure yields the validator's own message, while a
        malformed schema is logged and reported with a "Schema error:" prefix.
    """
    try:
        jsonschema.validate(instance=args, schema=schema)
    except jsonschema.ValidationError as invalid_args:
        # Keep it concise: only the validator's one-line message.
        return invalid_args.message
    except jsonschema.SchemaError as invalid_schema:
        logger.error("Tool schema is invalid: %s", invalid_schema.message)
        return f"Schema error: {invalid_schema.message}"
    else:
        return None
|
||||
|
||||
|
||||
async def _log_failure(
    audit_logger: "AuditLogger",
    tool_name: str,
    args: dict[str, Any],
    error_msg: str,
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
) -> None:
    """
    Record a failed tool invocation in the audit trail.

    The record carries no result and a latency of 0. An exception raised by
    the audit logger is logged and suppressed, so this coroutine does not
    propagate audit failures to its caller.
    """
    record = {
        "tool_name": tool_name,
        "args": args,
        "result": None,
        "tenant_id": tenant_id,
        "agent_id": agent_id,
        "latency_ms": 0,
        "error": error_msg,
    }
    try:
        await audit_logger.log_tool_call(**record)
    except Exception:
        logger.exception("Failed to write audit log for tool failure: %s", tool_name)
|
||||
219
packages/orchestrator/orchestrator/tools/registry.py
Normal file
219
packages/orchestrator/orchestrator/tools/registry.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Tool registry — defines ToolDefinition and the BUILTIN_TOOLS catalog.
|
||||
|
||||
ToolDefinition:
|
||||
A Pydantic model describing a tool's name, description, JSON Schema
|
||||
parameters, whether it requires user confirmation, and its async handler.
|
||||
|
||||
BUILTIN_TOOLS:
|
||||
The four built-in tools available to all agents: web_search, kb_search,
|
||||
http_request, calendar_lookup.
|
||||
|
||||
Usage:
|
||||
from orchestrator.tools.registry import BUILTIN_TOOLS, get_tools_for_agent, to_litellm_format
|
||||
|
||||
agent_tools = get_tools_for_agent(agent)
|
||||
litellm_tools = to_litellm_format(agent_tools)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
    """
    Describes a tool that an agent can invoke.

    Attributes:
        name: Unique tool identifier (snake_case).
        description: Human-readable description for LLM function calling.
        parameters: JSON Schema object defining accepted arguments.
        requires_confirmation: If True, pause and ask user before executing.
        handler: Async callable that executes the tool.
                 Excluded from model_dump() output.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    description: str
    parameters: dict[str, Any]
    requires_confirmation: bool = False
    handler: Callable[..., Any]

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model as a dict, always leaving out ``handler``.

        The handler is a callable and not JSON-serializable, so it is added
        to whatever ``exclude`` the caller passed — whether that is omitted,
        None, a set of field names, or a dict-style exclude spec.
        """
        exclude = kwargs.get("exclude")
        if exclude is None:
            kwargs["exclude"] = {"handler"}
        elif isinstance(exclude, dict):
            # Dict form: a True value excludes the whole field.
            kwargs["exclude"] = {**exclude, "handler": True}
        else:
            kwargs["exclude"] = set(exclude) | {"handler"}
        return super().model_dump(**kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import builtin handlers (eager, at module import time; the handlers themselves import optional deps lazily inside their functions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from orchestrator.tools.builtins.web_search import web_search as _web_search_handler
|
||||
from orchestrator.tools.builtins.kb_search import kb_search as _kb_search_handler
|
||||
from orchestrator.tools.builtins.http_request import http_request as _http_request_handler
|
||||
from orchestrator.tools.builtins.calendar_lookup import calendar_lookup as _calendar_lookup_handler
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BUILTIN_TOOLS — the four built-in tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Every schema sets "additionalProperties": False. Tool arguments come from
# LLM output, so properties the schema does not declare are rejected at
# validation time instead of being forwarded to a handler's **kwargs.
BUILTIN_TOOLS: dict[str, ToolDefinition] = {
    "web_search": ToolDefinition(
        name="web_search",
        description=(
            "Search the web using Brave Search and return the top results. "
            "Use this to find current information, news, or facts not in your training data."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query string.",
                }
            },
            "required": ["query"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_web_search_handler,
    ),
    "kb_search": ToolDefinition(
        name="kb_search",
        description=(
            "Search the tenant's knowledge base using semantic similarity. "
            "Use this to find relevant internal documentation, policies, or FAQs."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query to find relevant knowledge base content.",
                }
            },
            "required": ["query"],
            # Keeps an LLM-supplied tenant_id/agent_id from passing validation.
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_kb_search_handler,
    ),
    "http_request": ToolDefinition(
        name="http_request",
        description=(
            "Make an outbound HTTP request to an external API or URL. "
            "Supports GET, POST, PUT, DELETE. Response is capped at 1MB."
        ),
        parameters={
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The target URL.",
                },
                "method": {
                    "type": "string",
                    "enum": ["GET", "POST", "PUT", "DELETE"],
                    "description": "HTTP method. Defaults to GET.",
                },
                "body": {
                    "type": "string",
                    "description": "Request body as a string (for POST/PUT).",
                },
            },
            "required": ["url"],
            "additionalProperties": False,
        },
        requires_confirmation=True,  # Outbound requests always require user approval
        handler=_http_request_handler,
    ),
    "calendar_lookup": ToolDefinition(
        name="calendar_lookup",
        description=(
            "Look up calendar events for a specific date. "
            "Returns availability and scheduled events from Google Calendar."
        ),
        parameters={
            "type": "object",
            "properties": {
                "date": {
                    "type": "string",
                    "description": "The date to check in YYYY-MM-DD format.",
                },
                "calendar_id": {
                    "type": "string",
                    "description": "Google Calendar ID. Defaults to 'primary'.",
                },
            },
            "required": ["date"],
            "additionalProperties": False,
        },
        requires_confirmation=False,  # Read-only calendar lookup
        handler=_calendar_lookup_handler,
    ),
}
|
||||
|
||||
|
||||
def get_tools_for_agent(agent: Any) -> dict[str, ToolDefinition]:
    """
    Select the built-in tools an agent is allowed to use.

    Args:
        agent: Object exposing a ``tool_assignments`` attribute holding tool
            names; a falsy value (None, empty list) means no tools.

    Returns:
        A new dict of tool name → ToolDefinition, in BUILTIN_TOOLS order,
        containing only the assigned names that exist in BUILTIN_TOOLS.
        Assigned names with no matching built-in are dropped without error.
    """
    wanted: list[str] = agent.tool_assignments or []

    selected: dict[str, ToolDefinition] = {}
    for tool_name, definition in BUILTIN_TOOLS.items():
        if tool_name in wanted:
            selected[tool_name] = definition
    return selected
|
||||
|
||||
|
||||
def to_litellm_format(tools: dict[str, ToolDefinition]) -> list[dict[str, Any]]:
    """
    Render tool definitions in the OpenAI function-calling layout.

    Each tool becomes one entry of the form used by the ``tools`` parameter
    of LiteLLM's acompletion():

        {
            "type": "function",
            "function": {
                "name": "...",
                "description": "...",
                "parameters": { ... JSON Schema ... }
            }
        }

    Args:
        tools: Dict of tool name → ToolDefinition. Only the values are read.

    Returns:
        One entry per tool, in the dict's iteration order. The ``parameters``
        object is the tool's own schema dict, not a copy.
    """
    payload: list[dict[str, Any]] = []
    for definition in tools.values():
        function_spec = {
            "name": definition.name,
            "description": definition.description,
            "parameters": definition.parameters,
        }
        payload.append({"type": "function", "function": function_spec})
    return payload
|
||||
@@ -13,6 +13,7 @@ dependencies = [
|
||||
"celery[redis]>=5.4.0",
|
||||
"httpx>=0.28.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
"jsonschema>=4.26.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
||||
Reference in New Issue
Block a user