From 28a5ee996e9ceb9194208ef53ad2f742e8a96621 Mon Sep 17 00:00:00 2001 From: Adolfo Delorenzo Date: Mon, 23 Mar 2026 14:41:57 -0600 Subject: [PATCH] =?UTF-8?q?feat(02-01):=20add=20two-layer=20memory=20syste?= =?UTF-8?q?m=20=E2=80=94=20Redis=20sliding=20window=20+=20pgvector=20long-?= =?UTF-8?q?term?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ConversationEmbedding ORM model with Vector(384) column (pgvector) - memory_short_key, escalation_status_key, pending_tool_confirm_key in redis_keys.py - orchestrator/memory/short_term.py: RPUSH/LTRIM sliding window (get_recent_messages, append_message) - orchestrator/memory/long_term.py: pgvector HNSW cosine search (retrieve_relevant, store_embedding) - Migration 002: conversation_embeddings table, HNSW index, RLS with FORCE, SELECT/INSERT only - 10 unit tests (fakeredis), 6 integration tests (pgvector) — all passing - Auto-fix [Rule 3]: postgres image updated to pgvector/pgvector:pg16 (extension required) --- docker-compose.yml | 2 +- migrations/versions/002_phase2_memory.py | 146 ++++++++++ .../orchestrator/memory/__init__.py | 22 ++ .../orchestrator/memory/long_term.py | 164 +++++++++++ .../orchestrator/memory/short_term.py | 111 ++++++++ packages/orchestrator/pyproject.toml | 1 + packages/shared/pyproject.toml | 1 + packages/shared/shared/models/memory.py | 96 +++++++ packages/shared/shared/redis_keys.py | 58 ++++ tests/integration/test_memory_long_term.py | 259 ++++++++++++++++++ tests/unit/test_memory_short_term.py | 139 ++++++++++ 11 files changed, 998 insertions(+), 1 deletion(-) create mode 100644 migrations/versions/002_phase2_memory.py create mode 100644 packages/orchestrator/orchestrator/memory/__init__.py create mode 100644 packages/orchestrator/orchestrator/memory/long_term.py create mode 100644 packages/orchestrator/orchestrator/memory/short_term.py create mode 100644 packages/shared/shared/models/memory.py create mode 100644 tests/integration/test_memory_long_term.py create mode 100644 tests/unit/test_memory_short_term.py diff --git a/docker-compose.yml b/docker-compose.yml index ce7646f..6493777 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ volumes: services: postgres: - image: postgres:16-alpine + image: pgvector/pgvector:pg16 container_name: konstruct-postgres environment: POSTGRES_DB: konstruct diff --git a/migrations/versions/002_phase2_memory.py b/migrations/versions/002_phase2_memory.py new file mode 100644 index 0000000..179f276 --- /dev/null +++ b/migrations/versions/002_phase2_memory.py @@ -0,0 +1,146 @@ +"""Phase 2: conversation_embeddings table with HNSW index and RLS + +Revision ID: 002 +Revises: 001 +Create Date: 2026-03-23 + +This migration adds the conversation_embeddings table for the long-term +conversational memory system. It stores pgvector embeddings of past +conversation turns for semantic similarity retrieval. + +Key design decisions: +1. pgvector extension is enabled (CREATE EXTENSION IF NOT EXISTS vector) +2. HNSW index with m=16, ef_construction=64 for approximate nearest neighbor + search — cosine distance operator (vector_cosine_ops) +3. Covering index on (tenant_id, agent_id, user_id, created_at DESC) for + pre-filtering before ANN search +4. RLS with FORCE — tenant_id isolation enforced at DB level +5. GRANT SELECT, INSERT only — embeddings are immutable (no UPDATE/DELETE) + This models conversation history as an append-only audit log +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = "002" +down_revision: Union[str, None] = "001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ------------------------------------------------------------------------- + # 1. Enable pgvector extension (idempotent) + # ------------------------------------------------------------------------- + op.execute("CREATE EXTENSION IF NOT EXISTS vector") + + # ------------------------------------------------------------------------- + # 2. Create conversation_embeddings table + # ------------------------------------------------------------------------- + op.create_table( + "conversation_embeddings", + sa.Column( + "id", + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "agent_id", + UUID(as_uuid=True), + nullable=False, + ), + sa.Column( + "user_id", + sa.Text, + nullable=False, + comment="Channel-native user identifier (e.g. Slack user ID U12345)", + ), + sa.Column( + "content", + sa.Text, + nullable=False, + comment="Original message text that was embedded", + ), + sa.Column( + "role", + sa.Text, + nullable=False, + comment="Message role: 'user' or 'assistant'", + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + # The embedding column uses pgvector type — created via raw DDL below + # because SQLAlchemy doesn't know the 'vector' type without pgvector extension + ) + + # Add embedding column as vector(384) — must be raw DDL for pgvector type + op.execute( + "ALTER TABLE conversation_embeddings " + "ADD COLUMN embedding vector(384) NOT NULL" + ) + + # ------------------------------------------------------------------------- + # 3. Create covering index for pre-filter (tenant + agent + user + time) + # Used to scope queries before the ANN operator for isolation + performance + # ------------------------------------------------------------------------- + op.create_index( + "ix_conv_embed_tenant_agent_user_time", + "conversation_embeddings", + ["tenant_id", "agent_id", "user_id", "created_at"], + postgresql_ops={"created_at": "DESC"}, + ) + + # ------------------------------------------------------------------------- + # 4. Create HNSW index for approximate nearest neighbor cosine search + # m=16: number of bidirectional links per node (quality vs. memory tradeoff) + # ef_construction=64: search width during build (quality vs. speed) + # vector_cosine_ops: uses cosine distance (compatible with <=> operator) + # ------------------------------------------------------------------------- + op.execute(""" + CREATE INDEX ix_conv_embed_hnsw + ON conversation_embeddings + USING hnsw (embedding vector_cosine_ops) + WITH (m = 16, ef_construction = 64) + """) + + # ------------------------------------------------------------------------- + # 5. Apply Row Level Security + # FORCE ensures even the table owner cannot bypass tenant isolation + # ------------------------------------------------------------------------- + op.execute("ALTER TABLE conversation_embeddings ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE conversation_embeddings FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON conversation_embeddings + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # ------------------------------------------------------------------------- + # 6. Grant permissions to konstruct_app role + # SELECT + INSERT only — embeddings are immutable (no UPDATE or DELETE) + # This models conversation history as an append-only audit log + # ------------------------------------------------------------------------- + op.execute("GRANT SELECT, INSERT ON conversation_embeddings TO konstruct_app") + + +def downgrade() -> None: + op.execute("REVOKE ALL ON conversation_embeddings FROM konstruct_app") + op.drop_table("conversation_embeddings") + # Note: We do NOT drop the vector extension — other tables may use it diff --git a/packages/orchestrator/orchestrator/memory/__init__.py b/packages/orchestrator/orchestrator/memory/__init__.py new file mode 100644 index 0000000..51be918 --- /dev/null +++ b/packages/orchestrator/orchestrator/memory/__init__.py @@ -0,0 +1,22 @@ +""" +Konstruct Agent Memory Layer. + +Two-layer conversational memory system: + +1. Short-term (Redis sliding window): + - Stores the last N messages verbatim + - Zero latency — Redis is always available + - Provides immediate in-session context continuity + - See: short_term.py + +2. Long-term (pgvector HNSW similarity search): + - Stores all messages as semantic embeddings + - Retrieves top-K semantically relevant past exchanges + - Provides cross-session recall (user preferences, past issues, etc.) + - Embedding model: all-MiniLM-L6-v2 (384 dimensions) + - See: long_term.py + +Memory scoping: All operations are keyed by (tenant_id, agent_id, user_id). +This ensures complete isolation — no cross-tenant, cross-agent, or cross-user +contamination is possible. +""" diff --git a/packages/orchestrator/orchestrator/memory/long_term.py b/packages/orchestrator/orchestrator/memory/long_term.py new file mode 100644 index 0000000..3e6b6c3 --- /dev/null +++ b/packages/orchestrator/orchestrator/memory/long_term.py @@ -0,0 +1,164 @@ +""" +pgvector-backed long-term conversational memory. + +Stores conversation turns as 384-dimensional embeddings (all-MiniLM-L6-v2) +and retrieves semantically relevant past exchanges using HNSW cosine similarity +search. + +CRITICAL SECURITY CONSTRAINTS: +1. ALL queries MUST pre-filter by (tenant_id, agent_id, user_id) BEFORE the + ANN operator. This prevents cross-tenant, cross-agent, or cross-user data + leakage even in the face of embedding collisions. +2. Cosine similarity threshold filters out low-relevance results — only content + genuinely related to the query should be injected into the LLM prompt. +3. RLS (Row Level Security) is the DB-level backstop — the application-level + filters above are the primary guard; RLS is the safety net. + +pgvector cosine operations: +- <=> operator: cosine DISTANCE (0 = identical, 2 = opposite) +- cosine similarity = 1 - cosine distance +- A threshold of 0.75 means: only return results where + 1 - (embedding <=> query) >= 0.75 → distance <= 0.25 +""" + +from __future__ import annotations + +import logging +import uuid + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + + +async def retrieve_relevant( + session: AsyncSession, + tenant_id: uuid.UUID, + agent_id: uuid.UUID, + user_id: str, + query_embedding: list[float], + top_k: int = 3, + threshold: float = 0.75, +) -> list[str]: + """ + Retrieve semantically relevant past conversation content. + + Performs an HNSW approximate nearest neighbor search scoped strictly to + (tenant_id, agent_id, user_id). Results below the cosine similarity + threshold are discarded. + + IMPORTANT: The (tenant_id, agent_id, user_id) pre-filter is applied + BEFORE the ANN operator to guarantee isolation. This is not optional — + removing these WHERE clauses would allow cross-tenant data leakage. + + Args: + session: Async SQLAlchemy session (must have RLS configured). + tenant_id: Tenant UUID — mandatory for isolation. + agent_id: Agent UUID — mandatory for isolation. + user_id: End-user identifier — mandatory for isolation. + query_embedding: 384-dimensional query vector (all-MiniLM-L6-v2). + top_k: Maximum number of results to return. Default 3. + threshold: Minimum cosine similarity (0.0–1.0). Default 0.75. + Set lower for broader recall, higher for precision. + + Returns: + List of content strings (original message text), most relevant first. + Returns empty list if no results meet the threshold. + """ + # Convert embedding list to pgvector string format: '[0.1, 0.2, ...]' + vec_str = "[" + ",".join(str(float(v)) for v in query_embedding) + "]" + + # CRITICAL: pre-filter by all three isolation columns BEFORE ANN search. + # The ORDER BY uses <=> (cosine distance) — lower is more similar. + # We convert to similarity (1 - distance) to apply the threshold filter. + stmt = text(""" + SELECT content, 1 - (embedding <=> CAST(:query AS vector)) AS similarity + FROM conversation_embeddings + WHERE tenant_id = :tenant_id + AND agent_id = :agent_id + AND user_id = :user_id + AND 1 - (embedding <=> CAST(:query AS vector)) >= :threshold + ORDER BY embedding <=> CAST(:query AS vector) + LIMIT :top_k + """) + + try: + result = await session.execute( + stmt, + { + "query": vec_str, + "tenant_id": str(tenant_id), + "agent_id": str(agent_id), + "user_id": user_id, + "threshold": threshold, + "top_k": top_k, + }, + ) + rows = result.fetchall() + except Exception: + logger.exception( + "pgvector retrieve_relevant failed for tenant=%s agent=%s user=%s", + tenant_id, + agent_id, + user_id, + ) + return [] + + return [row.content for row in rows] + + +async def store_embedding( + session: AsyncSession, + tenant_id: uuid.UUID, + agent_id: uuid.UUID, + user_id: str, + content: str, + role: str, + embedding: list[float], +) -> None: + """ + Store a conversation turn embedding in the database. + + Inserts a new row into conversation_embeddings. Embeddings are immutable + once stored — there is no UPDATE path. This matches the audit-log-like + nature of conversation history. + + Args: + session: Async SQLAlchemy session (must have RLS configured). + tenant_id: Tenant UUID for isolation. + agent_id: Agent UUID for isolation. + user_id: End-user identifier for isolation. + content: Original message text. + role: "user" or "assistant". + embedding: 384-dimensional float list (all-MiniLM-L6-v2). + """ + vec_str = "[" + ",".join(str(float(v)) for v in embedding) + "]" + + stmt = text(""" + INSERT INTO conversation_embeddings + (id, tenant_id, agent_id, user_id, content, role, embedding) + VALUES + (gen_random_uuid(), :tenant_id, :agent_id, :user_id, :content, :role, CAST(:embedding AS vector)) + """) + + try: + await session.execute( + stmt, + { + "tenant_id": str(tenant_id), + "agent_id": str(agent_id), + "user_id": user_id, + "content": content, + "role": role, + "embedding": vec_str, + }, + ) + except Exception: + logger.exception( + "pgvector store_embedding failed for tenant=%s agent=%s user=%s", + tenant_id, + agent_id, + user_id, + ) + raise diff --git a/packages/orchestrator/orchestrator/memory/short_term.py b/packages/orchestrator/orchestrator/memory/short_term.py new file mode 100644 index 0000000..efd4621 --- /dev/null +++ b/packages/orchestrator/orchestrator/memory/short_term.py @@ -0,0 +1,111 @@ +""" +Redis sliding window for short-term conversational memory. + +Implements a RPUSH + LTRIM pattern: +- RPUSH appends new messages to the right (tail) of the list +- LTRIM trims the list to the last `window` entries +- LRANGE retrieves all current entries + +This gives O(1) append + O(1) trim + O(N) read where N <= window size. + +Key format: {tenant_id}:memory:short:{agent_id}:{user_id} + +Messages are stored as JSON objects with "role" and "content" keys, +matching the OpenAI chat messages format for direct injection into +the LLM messages array. + +Design decisions: +- No TTL: message retention is indefinite per user preference. If TTL-based + expiry is needed in the future, add it via a separate expiry policy. +- No compression: messages are stored as plain JSON. At 20 messages * ~200 + bytes average, storage per user/agent is ~4KB — negligible. +- Parameterized window: callers control the window size, defaulting to 20. + This allows future policy changes without code modification. +""" + +from __future__ import annotations + +import json +import logging + +from shared.redis_keys import memory_short_key + +logger = logging.getLogger(__name__) + + +async def get_recent_messages( + redis: object, + tenant_id: str, + agent_id: str, + user_id: str, + n: int = 20, +) -> list[dict[str, str]]: + """ + Retrieve the most recent N messages from the sliding window. + + Returns messages in insertion order (oldest first) — this matches the + expected LLM message array format where conversation flows chronologically. + + Args: + redis: Redis async client (redis.asyncio.Redis or compatible). + tenant_id: Konstruct tenant identifier. + agent_id: Agent UUID string. + user_id: End-user identifier (channel-native). + n: Maximum number of messages to retrieve. Default 20. + Pass a larger value than the window size to get all messages. + + Returns: + List of message dicts with "role" and "content" keys, oldest first. + Returns empty list if no messages exist for this key. + """ + key = memory_short_key(tenant_id, agent_id, user_id) + + # LRANGE -n -1 returns the last n items in insertion order + raw_messages = await redis.lrange(key, -n, -1) # type: ignore[union-attr] + + messages: list[dict[str, str]] = [] + for raw in raw_messages: + try: + msg = json.loads(raw) + messages.append({"role": str(msg["role"]), "content": str(msg["content"])}) + except (json.JSONDecodeError, KeyError): + logger.warning("Malformed message in Redis key %s — skipping", key) + + return messages + + +async def append_message( + redis: object, + tenant_id: str, + agent_id: str, + user_id: str, + role: str, + content: str, + window: int = 20, +) -> None: + """ + Append a message to the sliding window and trim to window size. + + Uses a pipeline to make RPUSH + LTRIM atomic — no race condition + between append and trim even under concurrent writes. + + Args: + redis: Redis async client. + tenant_id: Konstruct tenant identifier. + agent_id: Agent UUID string. + user_id: End-user identifier (channel-native). + role: Message role: "user" or "assistant". + content: Message text content. + window: Maximum number of messages to retain. Default 20. + After this operation the list will contain at most `window` + entries (the most recent ones). + """ + key = memory_short_key(tenant_id, agent_id, user_id) + serialized = json.dumps({"role": role, "content": content}) + + # Pipeline ensures RPUSH + LTRIM are sent atomically + pipe = redis.pipeline() # type: ignore[union-attr] + pipe.rpush(key, serialized) + # LTRIM to last `window` entries: keep index -(window) through -1 + pipe.ltrim(key, -window, -1) + await pipe.execute() diff --git a/packages/orchestrator/pyproject.toml b/packages/orchestrator/pyproject.toml index 7223bea..69a05a8 100644 --- a/packages/orchestrator/pyproject.toml +++ b/packages/orchestrator/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "fastapi[standard]>=0.115.0", "celery[redis]>=5.4.0", "httpx>=0.28.0", + "sentence-transformers>=3.0.0", ] [tool.uv.sources] diff --git a/packages/shared/pyproject.toml b/packages/shared/pyproject.toml index 2c7592c..fb51dd8 100644 --- a/packages/shared/pyproject.toml +++ b/packages/shared/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "httpx>=0.28.0", "slowapi>=0.1.9", "bcrypt>=4.0.0", + "pgvector>=0.3.0", ] [tool.hatch.build.targets.wheel] diff --git a/packages/shared/shared/models/memory.py b/packages/shared/shared/models/memory.py new file mode 100644 index 0000000..c437407 --- /dev/null +++ b/packages/shared/shared/models/memory.py @@ -0,0 +1,96 @@ +""" +SQLAlchemy 2.0 ORM models for conversational memory. + +ConversationEmbedding stores pgvector embeddings of past conversation turns +for long-term semantic retrieval across sessions. This is the persistence layer +for the long-term memory module in the Agent Orchestrator. + +IMPORTANT: +- Embeddings are immutable (no UPDATE) — like audit records. We store and read + but never modify. This simplifies the data model and prevents mutation bugs. +- RLS is ENABLED with FORCE — tenant_id isolation is enforced at the DB level. +- The vector dimension (384) corresponds to all-MiniLM-L6-v2 output size. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from pgvector.sqlalchemy import Vector +from sqlalchemy import DateTime, ForeignKey, Text, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from shared.models.tenant import Base + + +class ConversationEmbedding(Base): + """ + A single embedded conversation turn stored for long-term recall. + + Each row represents one message (user or assistant) converted to a + 384-dimensional embedding via all-MiniLM-L6-v2. The Agent Orchestrator + queries this table at prompt assembly time to inject relevant past context. + + Scoped by: + - tenant_id: RLS enforced isolation between tenants + - agent_id: isolation between agents within a tenant + - user_id: isolation between end-users of the same agent + + RLS policy enforces: + tenant_id = current_setting('app.current_tenant', TRUE)::uuid + + FORCE ROW LEVEL SECURITY ensures even the table owner cannot bypass this. + """ + + __tablename__ = "conversation_embeddings" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + tenant_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + agent_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + nullable=False, + index=True, + ) + user_id: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Channel-native user identifier (e.g. Slack user ID U12345)", + ) + content: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Original message text that was embedded", + ) + role: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="Message role: 'user' or 'assistant'", + ) + embedding: Mapped[list[float]] = mapped_column( + Vector(384), + nullable=False, + comment="all-MiniLM-L6-v2 embedding (384 dimensions)", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/packages/shared/shared/redis_keys.py b/packages/shared/shared/redis_keys.py index 7a8db8a..fbce779 100644 --- a/packages/shared/shared/redis_keys.py +++ b/packages/shared/shared/redis_keys.py @@ -86,3 +86,61 @@ def engaged_thread_key(tenant_id: str, thread_id: str) -> str: Namespaced Redis key: "{tenant_id}:engaged:{thread_id}" """ return f"{tenant_id}:engaged:{thread_id}" + + +def memory_short_key(tenant_id: str, agent_id: str, user_id: str) -> str: + """ + Redis key for the short-term conversational memory sliding window. + + Stores the last N messages (serialized as JSON) for a specific + tenant + agent + user combination. Used by the Agent Orchestrator to + inject recent conversation history into every LLM prompt. + + Key includes all three discriminators to ensure: + - Two users talking to the same agent have separate histories + - The same user talking to two different agents has separate histories + - Two tenants with the same agent/user IDs are fully isolated + + Args: + tenant_id: Konstruct tenant identifier. + agent_id: Agent identifier (UUID string). + user_id: End-user identifier (channel-native, e.g. Slack user ID). + + Returns: + Namespaced Redis key: "{tenant_id}:memory:short:{agent_id}:{user_id}" + """ + return f"{tenant_id}:memory:short:{agent_id}:{user_id}" + + +def escalation_status_key(tenant_id: str, thread_id: str) -> str: + """ + Redis key for tracking escalation status of a thread. + + Stores the current escalation state for a conversation thread — + whether it has been escalated to a human or another agent. + + Args: + tenant_id: Konstruct tenant identifier. + thread_id: Thread identifier. + + Returns: + Namespaced Redis key: "{tenant_id}:escalation:{thread_id}" + """ + return f"{tenant_id}:escalation:{thread_id}" + + +def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str: + """ + Redis key for tracking pending tool confirmation requests. + + Stores the pending tool invocation that requires explicit user + confirmation before execution (e.g. destructive operations). + + Args: + tenant_id: Konstruct tenant identifier. + thread_id: Thread identifier. + + Returns: + Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}" + """ + return f"{tenant_id}:tool_confirm:{thread_id}" diff --git a/tests/integration/test_memory_long_term.py b/tests/integration/test_memory_long_term.py new file mode 100644 index 0000000..a6a4145 --- /dev/null +++ b/tests/integration/test_memory_long_term.py @@ -0,0 +1,259 @@ +""" +Integration tests for pgvector long-term memory. + +Requires a live PostgreSQL instance with pgvector extension installed. +Tests are automatically skipped if the database is not available +(fixture from conftest.py handles that via pytest.skip). + +Key scenarios tested: +- store_embedding inserts with correct scoping +- retrieve_relevant returns matching content above threshold +- Cross-tenant isolation: tenant A's embeddings never returned for tenant B +- High threshold returns empty list for dissimilar queries +""" + +from __future__ import annotations + +import uuid + +import pytest +import pytest_asyncio +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from orchestrator.memory.long_term import retrieve_relevant, store_embedding + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def agent_a_id() -> uuid.UUID: + """Return a stable agent UUID for tenant A tests.""" + return uuid.UUID("aaaaaaaa-0000-0000-0000-000000000001") + + +@pytest_asyncio.fixture +async def agent_b_id() -> uuid.UUID: + """Return a stable agent UUID for tenant B tests.""" + return uuid.UUID("bbbbbbbb-0000-0000-0000-000000000002") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_store_embedding_inserts_row( + db_session: AsyncSession, + tenant_a: dict, + agent_a_id: uuid.UUID, +): + """store_embedding inserts a row into conversation_embeddings.""" + from shared.rls import current_tenant_id + + tenant_id = tenant_a["id"] + user_id = "user-store-test" + embedding = [0.1] * 384 + content = "I prefer concise answers." + + token = current_tenant_id.set(tenant_id) + try: + await store_embedding(db_session, tenant_id, agent_a_id, user_id, content, "user", embedding) + await db_session.commit() + + result = await db_session.execute( + text("SELECT content, role FROM conversation_embeddings WHERE tenant_id = :tid AND user_id = :uid"), + {"tid": str(tenant_id), "uid": user_id}, + ) + rows = result.fetchall() + finally: + current_tenant_id.reset(token) + + assert len(rows) == 1 + assert rows[0].content == content + assert rows[0].role == "user" + + +async def test_retrieve_relevant_returns_similar_content( + db_session: AsyncSession, + tenant_a: dict, + agent_a_id: uuid.UUID, +): + """retrieve_relevant returns content above cosine similarity threshold.""" + from shared.rls import current_tenant_id + + tenant_id = tenant_a["id"] + user_id = "user-retrieve-test" + + # Store two embeddings: one very similar to the query, one dissimilar + # We simulate similarity by using identical embeddings + similar_embedding = [1.0] + [0.0] * 383 + dissimilar_embedding = [0.0] * 383 + [1.0] + query_embedding = [1.0] + [0.0] * 383 # identical to similar_embedding + + token = current_tenant_id.set(tenant_id) + try: + await store_embedding( + db_session, tenant_id, agent_a_id, user_id, + "The user likes Python programming.", "user", similar_embedding + ) + await store_embedding( + db_session, tenant_id, agent_a_id, user_id, + "This is completely unrelated content.", "user", dissimilar_embedding + ) + await db_session.commit() + + results = await retrieve_relevant( + db_session, tenant_id, agent_a_id, user_id, query_embedding, top_k=3, threshold=0.5 + ) + finally: + current_tenant_id.reset(token) + + # Should return the similar content + assert len(results) >= 1 + assert any("Python" in r for r in results) + + +async def test_retrieve_relevant_high_threshold_returns_empty( + db_session: AsyncSession, + tenant_a: dict, + agent_a_id: uuid.UUID, +): + """retrieve_relevant with threshold=0.99 and dissimilar query returns empty list.""" + from shared.rls import current_tenant_id + + tenant_id = tenant_a["id"] + user_id = "user-threshold-test" + + # Store an embedding pointing in one direction + stored_embedding = [1.0] + [0.0] * 383 + # Query pointing in orthogonal direction — cosine distance ~= 1.0, similarity ~= 0.0 + query_embedding = [0.0] + [1.0] + [0.0] * 382 + + token = current_tenant_id.set(tenant_id) + try: + await store_embedding( + db_session, tenant_id, agent_a_id, user_id, + "Some stored content.", "user", stored_embedding + ) + await db_session.commit() + + results = await retrieve_relevant( + db_session, tenant_id, agent_a_id, user_id, query_embedding, top_k=3, threshold=0.99 + ) + finally: + current_tenant_id.reset(token) + + assert results == [] + + +async def test_cross_tenant_isolation( + db_session: AsyncSession, + tenant_a: dict, + tenant_b: dict, + agent_a_id: uuid.UUID, + agent_b_id: uuid.UUID, +): + """ + retrieve_relevant with tenant_id=A NEVER returns tenant_id=B embeddings. + + This is the critical security test — cross-tenant contamination would be + a catastrophic data leak in a multi-tenant system. + """ + from shared.rls import current_tenant_id + + user_id = "shared-user-id" + tenant_a_id = tenant_a["id"] + tenant_b_id = tenant_b["id"] + + # Same query embedding for both tenants + embedding = [1.0] + [0.0] * 383 + + # Store embedding for tenant B + token = current_tenant_id.set(tenant_b_id) + try: + await store_embedding( + db_session, tenant_b_id, agent_b_id, user_id, + "Tenant B secret information.", "user", embedding + ) + await db_session.commit() + finally: + current_tenant_id.reset(token) + + # Query as tenant A — should NOT see tenant B's data + token = current_tenant_id.set(tenant_a_id) + try: + results = await retrieve_relevant( + db_session, tenant_a_id, agent_a_id, user_id, embedding, top_k=10, threshold=0.0 + ) + finally: + current_tenant_id.reset(token) + + # Tenant A should get nothing — it has no embeddings of its own + # and it MUST NOT see tenant B's embeddings + for result in results: + assert "Tenant B" not in result, "Cross-tenant data leakage detected!" + + +async def test_retrieve_relevant_user_isolation( + db_session: AsyncSession, + tenant_a: dict, + agent_a_id: uuid.UUID, +): + """retrieve_relevant for user A never returns user B embeddings.""" + from shared.rls import current_tenant_id + + tenant_id = tenant_a["id"] + embedding = [1.0] + [0.0] * 383 + + token = current_tenant_id.set(tenant_id) + try: + await store_embedding( + db_session, tenant_id, agent_a_id, "user-A", + "User A private information.", "user", embedding + ) + await db_session.commit() + + # Query as user B — should not see user A's data + results = await retrieve_relevant( + db_session, tenant_id, agent_a_id, "user-B", embedding, top_k=10, threshold=0.0 + ) + finally: + current_tenant_id.reset(token) + + for result in results: + assert "User A private" not in result + + +async def test_retrieve_relevant_top_k_limits_results( + db_session: AsyncSession, + tenant_a: dict, + agent_a_id: uuid.UUID, +): + """retrieve_relevant respects top_k limit.""" + from shared.rls import current_tenant_id + + tenant_id = tenant_a["id"] + user_id = "user-topk-test" + embedding = [1.0] + [0.0] * 383 + + token = current_tenant_id.set(tenant_id) + try: + # Store 5 very similar embeddings + for i in range(5): + await store_embedding( + db_session, tenant_id, agent_a_id, user_id, + f"Content item {i}", "user", embedding + ) + await db_session.commit() + + results = await retrieve_relevant( + db_session, tenant_id, agent_a_id, user_id, embedding, top_k=2, threshold=0.0 + ) + finally: + current_tenant_id.reset(token) + + assert len(results) <= 2 diff --git a/tests/unit/test_memory_short_term.py b/tests/unit/test_memory_short_term.py new file mode 100644 index 0000000..3b670dd --- /dev/null +++ b/tests/unit/test_memory_short_term.py @@ -0,0 +1,139 @@ +""" +Unit tests for the Redis short-term memory sliding window. + +Uses fakeredis to avoid requiring a real Redis connection. +All tests verify tenant+agent+user namespacing and RPUSH/LTRIM correctness. +""" + +from __future__ import annotations + +import json + +import fakeredis.aioredis +import pytest + +from orchestrator.memory.short_term import append_message, get_recent_messages + + +@pytest.fixture +async def redis(): + """Return a fakeredis async client for testing.""" + client = fakeredis.aioredis.FakeRedis() + yield client + await client.aclose() + + +TENANT = "tenant-abc" +AGENT = "agent-xyz" +USER = "user-123" + + +async def test_get_recent_messages_empty(redis): + """get_recent_messages on empty key returns empty list.""" + result = await get_recent_messages(redis, TENANT, AGENT, USER) + assert result == [] + + +async def test_append_and_get_single_message(redis): + """append_message stores a message; get_recent_messages retrieves it.""" + await append_message(redis, TENANT, AGENT, USER, role="user", content="Hello!") + result = await get_recent_messages(redis, TENANT, AGENT, USER) + assert len(result) == 1 + assert result[0] == {"role": "user", "content": "Hello!"} + + +async def test_append_multiple_messages_ordering(redis): + """Messages are returned in insertion order (oldest first).""" + await append_message(redis, TENANT, AGENT, USER, role="user", content="First") + await append_message(redis, TENANT, AGENT, USER, role="assistant", content="Second") + await append_message(redis, TENANT, AGENT, USER, role="user", content="Third") + + result = await get_recent_messages(redis, TENANT, AGENT, USER) + assert len(result) == 3 + assert result[0]["content"] == "First" + assert result[1]["content"] == "Second" + assert result[2]["content"] == "Third" + + +async def test_sliding_window_trims_to_window_size(redis): + """append_message with window=5 keeps only last 5 messages.""" + for i in range(10): + await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}", window=5) + + result = await get_recent_messages(redis, TENANT, AGENT, USER, n=20) + assert len(result) == 5 + # Should have the last 5 messages: msg-5 through msg-9 + contents = [m["content"] for m in result] + assert contents == ["msg-5", "msg-6", "msg-7", "msg-8", "msg-9"] + + +async def test_default_window_20(redis): + """Default window is 20 — 21st message pushes out the first.""" + for i in range(21): + await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}") + + result = await get_recent_messages(redis, TENANT, AGENT, USER) + assert len(result) == 20 + assert result[0]["content"] == "msg-1" + assert result[-1]["content"] == "msg-20" + + +async def test_get_recent_messages_n_parameter(redis): + """get_recent_messages n parameter limits results.""" + for i in range(10): + await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}") + + result = await get_recent_messages(redis, TENANT, AGENT, USER, n=3) + assert len(result) == 3 + # n=3 returns last 3: msg-7, msg-8, msg-9 + contents = [m["content"] for m in result] + assert contents == ["msg-7", "msg-8", "msg-9"] + + +async def test_key_namespacing_user_isolation(redis): + """Different users of the same agent have isolated memory.""" + await append_message(redis, TENANT, AGENT, "user-A", role="user", content="User A message") + await append_message(redis, TENANT, AGENT, "user-B", role="user", content="User B message") + + result_a = await get_recent_messages(redis, TENANT, AGENT, "user-A") + result_b = await get_recent_messages(redis, TENANT, AGENT, "user-B") + + assert len(result_a) == 1 + assert result_a[0]["content"] == "User A message" + assert len(result_b) == 1 + assert result_b[0]["content"] == "User B message" + + +async def test_key_namespacing_tenant_isolation(redis): + """Different tenants with same agent+user IDs have isolated memory.""" + await append_message(redis, "tenant-1", AGENT, USER, role="user", content="Tenant 1 message") + await append_message(redis, "tenant-2", AGENT, USER, role="user", content="Tenant 2 message") + + result_1 = await get_recent_messages(redis, "tenant-1", AGENT, USER) + result_2 = await get_recent_messages(redis, "tenant-2", AGENT, USER) + + assert result_1[0]["content"] == "Tenant 1 message" + assert result_2[0]["content"] == "Tenant 2 message" + + +async def test_key_namespacing_agent_isolation(redis): + """Different agents for the same tenant+user have isolated memory.""" + await append_message(redis, TENANT, "agent-1", USER, role="user", content="Agent 1 context") + await append_message(redis, TENANT, "agent-2", USER, role="user", content="Agent 2 context") + + result_1 = await get_recent_messages(redis, TENANT, "agent-1", USER) + result_2 = await get_recent_messages(redis, TENANT, "agent-2", USER) + + assert result_1[0]["content"] == "Agent 1 context" + assert result_2[0]["content"] == "Agent 2 context" + + +async def test_message_role_and_content_round_trip(redis): + """Messages store and retrieve role + content correctly.""" + await append_message(redis, TENANT, AGENT, USER, role="assistant", content="I can help you with that.") + result = await get_recent_messages(redis, TENANT, AGENT, USER) + msg = result[0] + assert msg["role"] == "assistant" + assert msg["content"] == "I can help you with that." + # Verify it has exactly role and content keys + assert set(msg.keys()) == {"role", "content"}