feat(02-01): add two-layer memory system — Redis sliding window + pgvector long-term
- ConversationEmbedding ORM model with Vector(384) column (pgvector) - memory_short_key, escalation_status_key, pending_tool_confirm_key in redis_keys.py - orchestrator/memory/short_term.py: RPUSH/LTRIM sliding window (get_recent_messages, append_message) - orchestrator/memory/long_term.py: pgvector HNSW cosine search (retrieve_relevant, store_embedding) - Migration 002: conversation_embeddings table, HNSW index, RLS with FORCE, SELECT/INSERT only - 10 unit tests (fakeredis), 6 integration tests (pgvector) — all passing - Auto-fix [Rule 3]: postgres image updated to pgvector/pgvector:pg16 (extension required)
This commit is contained in:
@@ -11,7 +11,7 @@ volumes:
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
image: pgvector/pgvector:pg16
|
||||
container_name: konstruct-postgres
|
||||
environment:
|
||||
POSTGRES_DB: konstruct
|
||||
|
||||
146
migrations/versions/002_phase2_memory.py
Normal file
146
migrations/versions/002_phase2_memory.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Phase 2: conversation_embeddings table with HNSW index and RLS
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-03-23
|
||||
|
||||
This migration adds the conversation_embeddings table for the long-term
|
||||
conversational memory system. It stores pgvector embeddings of past
|
||||
conversation turns for semantic similarity retrieval.
|
||||
|
||||
Key design decisions:
|
||||
1. pgvector extension is enabled (CREATE EXTENSION IF NOT EXISTS vector)
|
||||
2. HNSW index with m=16, ef_construction=64 for approximate nearest neighbor
|
||||
search — cosine distance operator (vector_cosine_ops)
|
||||
3. Covering index on (tenant_id, agent_id, user_id, created_at DESC) for
|
||||
pre-filtering before ANN search
|
||||
4. RLS with FORCE — tenant_id isolation enforced at DB level
|
||||
5. GRANT SELECT, INSERT only — embeddings are immutable (no UPDATE/DELETE)
|
||||
This models conversation history as an append-only audit log
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the conversation_embeddings table, its indexes, RLS, and grants.

    Steps:
      1. Enable the pgvector extension (idempotent).
      2. Create the table (the embedding column is added via raw DDL below).
      3. Covering b-tree index for the (tenant, agent, user, time) pre-filter.
      4. HNSW index for approximate nearest neighbor cosine search.
      5. Row Level Security with FORCE — tenant isolation at the DB level.
      6. SELECT/INSERT-only grants — embeddings are append-only.
    """
    # -------------------------------------------------------------------------
    # 1. Enable pgvector extension (idempotent)
    # -------------------------------------------------------------------------
    op.execute("CREATE EXTENSION IF NOT EXISTS vector")

    # -------------------------------------------------------------------------
    # 2. Create conversation_embeddings table
    # -------------------------------------------------------------------------
    op.create_table(
        "conversation_embeddings",
        sa.Column(
            "id",
            UUID(as_uuid=True),
            primary_key=True,
            server_default=sa.text("gen_random_uuid()"),
        ),
        sa.Column(
            "tenant_id",
            UUID(as_uuid=True),
            sa.ForeignKey("tenants.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "agent_id",
            UUID(as_uuid=True),
            nullable=False,
        ),
        sa.Column(
            "user_id",
            sa.Text,
            nullable=False,
            comment="Channel-native user identifier (e.g. Slack user ID U12345)",
        ),
        sa.Column(
            "content",
            sa.Text,
            nullable=False,
            comment="Original message text that was embedded",
        ),
        sa.Column(
            "role",
            sa.Text,
            nullable=False,
            comment="Message role: 'user' or 'assistant'",
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("NOW()"),
        ),
        # The embedding column uses the pgvector type — created via raw DDL below
        # because SQLAlchemy doesn't know the 'vector' type without pgvector extension
    )

    # Add embedding column as vector(384) — must be raw DDL for pgvector type
    op.execute(
        "ALTER TABLE conversation_embeddings "
        "ADD COLUMN embedding vector(384) NOT NULL"
    )

    # -------------------------------------------------------------------------
    # 3. Create covering index for pre-filter (tenant + agent + user + time)
    #    Used to scope queries before the ANN operator for isolation + performance.
    #    FIX: the previous form used postgresql_ops={"created_at": "DESC"} —
    #    postgresql_ops maps columns to *operator classes*, and only by accident
    #    rendered as "created_at DESC". The descending sort key is expressed
    #    explicitly as a SQL expression instead.
    # -------------------------------------------------------------------------
    op.create_index(
        "ix_conv_embed_tenant_agent_user_time",
        "conversation_embeddings",
        ["tenant_id", "agent_id", "user_id", sa.text("created_at DESC")],
    )

    # -------------------------------------------------------------------------
    # 4. Create HNSW index for approximate nearest neighbor cosine search
    #    m=16: number of bidirectional links per node (quality vs. memory tradeoff)
    #    ef_construction=64: search width during build (quality vs. speed)
    #    vector_cosine_ops: uses cosine distance (compatible with <=> operator)
    # -------------------------------------------------------------------------
    op.execute("""
        CREATE INDEX ix_conv_embed_hnsw
        ON conversation_embeddings
        USING hnsw (embedding vector_cosine_ops)
        WITH (m = 16, ef_construction = 64)
    """)

    # -------------------------------------------------------------------------
    # 5. Apply Row Level Security
    #    FORCE ensures even the table owner cannot bypass tenant isolation.
    #    NOTE: CREATE POLICY defaults to FOR ALL; with WITH CHECK omitted the
    #    USING expression is also applied to INSERTs, so writes are covered.
    # -------------------------------------------------------------------------
    op.execute("ALTER TABLE conversation_embeddings ENABLE ROW LEVEL SECURITY")
    op.execute("ALTER TABLE conversation_embeddings FORCE ROW LEVEL SECURITY")
    op.execute("""
        CREATE POLICY tenant_isolation ON conversation_embeddings
        USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid)
    """)

    # -------------------------------------------------------------------------
    # 6. Grant permissions to konstruct_app role
    #    SELECT + INSERT only — embeddings are immutable (no UPDATE or DELETE).
    #    This models conversation history as an append-only audit log.
    # -------------------------------------------------------------------------
    op.execute("GRANT SELECT, INSERT ON conversation_embeddings TO konstruct_app")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverse migration 002: revoke grants and drop the embeddings table.

    Indexes and the RLS policy are dropped implicitly with the table.
    The vector extension is deliberately left installed, since other
    tables may depend on it.
    """
    op.execute("REVOKE ALL ON conversation_embeddings FROM konstruct_app")
    op.drop_table("conversation_embeddings")
|
||||
22
packages/orchestrator/orchestrator/memory/__init__.py
Normal file
22
packages/orchestrator/orchestrator/memory/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Konstruct Agent Memory Layer.
|
||||
|
||||
Two-layer conversational memory system:
|
||||
|
||||
1. Short-term (Redis sliding window):
|
||||
- Stores the last N messages verbatim
|
||||
- Low latency — Redis round-trips are sub-millisecond in practice
|
||||
- Provides immediate in-session context continuity
|
||||
- See: short_term.py
|
||||
|
||||
2. Long-term (pgvector HNSW similarity search):
|
||||
- Stores all messages as semantic embeddings
|
||||
- Retrieves top-K semantically relevant past exchanges
|
||||
- Provides cross-session recall (user preferences, past issues, etc.)
|
||||
- Embedding model: all-MiniLM-L6-v2 (384 dimensions)
|
||||
- See: long_term.py
|
||||
|
||||
Memory scoping: All operations are keyed by (tenant_id, agent_id, user_id).
|
||||
This ensures complete isolation — no cross-tenant, cross-agent, or cross-user
|
||||
contamination is possible.
|
||||
"""
|
||||
164
packages/orchestrator/orchestrator/memory/long_term.py
Normal file
164
packages/orchestrator/orchestrator/memory/long_term.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
pgvector-backed long-term conversational memory.
|
||||
|
||||
Stores conversation turns as 384-dimensional embeddings (all-MiniLM-L6-v2)
|
||||
and retrieves semantically relevant past exchanges using HNSW cosine similarity
|
||||
search.
|
||||
|
||||
CRITICAL SECURITY CONSTRAINTS:
|
||||
1. ALL queries MUST pre-filter by (tenant_id, agent_id, user_id) BEFORE the
|
||||
ANN operator. This prevents cross-tenant, cross-agent, or cross-user data
|
||||
leakage even in the face of embedding collisions.
|
||||
2. Cosine similarity threshold filters out low-relevance results — only content
|
||||
genuinely related to the query should be injected into the LLM prompt.
|
||||
3. RLS (Row Level Security) is the DB-level backstop — the application-level
|
||||
filters above are the primary guard; RLS is the safety net.
|
||||
|
||||
pgvector cosine operations:
|
||||
- <=> operator: cosine DISTANCE (0 = identical, 2 = opposite)
|
||||
- cosine similarity = 1 - cosine distance
|
||||
- A threshold of 0.75 means: only return results where
|
||||
1 - (embedding <=> query) >= 0.75 → distance <= 0.25
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def retrieve_relevant(
    session: AsyncSession,
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
    user_id: str,
    query_embedding: list[float],
    top_k: int = 3,
    threshold: float = 0.75,
) -> list[str]:
    """
    Fetch past conversation content semantically similar to the query.

    Runs an HNSW approximate nearest neighbor search strictly scoped to
    (tenant_id, agent_id, user_id); rows whose cosine similarity falls below
    ``threshold`` are excluded.

    SECURITY: the three scoping predicates are applied BEFORE the ANN
    operator — removing any of them would permit cross-tenant, cross-agent,
    or cross-user data leakage. This is not optional.

    Args:
        session: Async SQLAlchemy session (must have RLS configured).
        tenant_id: Tenant UUID — mandatory for isolation.
        agent_id: Agent UUID — mandatory for isolation.
        user_id: End-user identifier — mandatory for isolation.
        query_embedding: 384-dimensional query vector (all-MiniLM-L6-v2).
        top_k: Maximum number of results to return. Default 3.
        threshold: Minimum cosine similarity (0.0–1.0). Default 0.75.
            Lower values broaden recall; higher values tighten precision.

    Returns:
        Content strings ordered most-relevant first. Empty list when nothing
        clears the threshold or the query fails.
    """
    # pgvector accepts vectors as a '[v1,v2,...]' literal string.
    vector_literal = "[" + ",".join(str(float(component)) for component in query_embedding) + "]"

    # <=> is cosine DISTANCE (lower = more similar); the similarity used for
    # the threshold filter is 1 - distance. Isolation predicates come first.
    query = text("""
        SELECT content, 1 - (embedding <=> CAST(:query AS vector)) AS similarity
        FROM conversation_embeddings
        WHERE tenant_id = :tenant_id
          AND agent_id = :agent_id
          AND user_id = :user_id
          AND 1 - (embedding <=> CAST(:query AS vector)) >= :threshold
        ORDER BY embedding <=> CAST(:query AS vector)
        LIMIT :top_k
    """)
    params = {
        "query": vector_literal,
        "tenant_id": str(tenant_id),
        "agent_id": str(agent_id),
        "user_id": user_id,
        "threshold": threshold,
        "top_k": top_k,
    }

    try:
        rows = (await session.execute(query, params)).fetchall()
    except Exception:
        # Memory retrieval is best-effort: log and degrade to "no context"
        # rather than failing the whole request.
        logger.exception(
            "pgvector retrieve_relevant failed for tenant=%s agent=%s user=%s",
            tenant_id,
            agent_id,
            user_id,
        )
        return []

    return [row.content for row in rows]
|
||||
|
||||
|
||||
async def store_embedding(
    session: AsyncSession,
    tenant_id: uuid.UUID,
    agent_id: uuid.UUID,
    user_id: str,
    content: str,
    role: str,
    embedding: list[float],
) -> None:
    """
    Persist one conversation turn's embedding.

    Appends a row to conversation_embeddings. Rows are never updated after
    insertion — conversation history behaves like an audit log.

    Args:
        session: Async SQLAlchemy session (must have RLS configured).
        tenant_id: Tenant UUID for isolation.
        agent_id: Agent UUID for isolation.
        user_id: End-user identifier for isolation.
        content: Original message text.
        role: "user" or "assistant".
        embedding: 384-dimensional float list (all-MiniLM-L6-v2).

    Raises:
        Exception: re-raised after logging if the INSERT fails — unlike
        retrieval, a failed write must not be silently dropped.
    """
    # pgvector accepts vectors as a '[v1,v2,...]' literal string.
    vector_literal = "[" + ",".join(str(float(component)) for component in embedding) + "]"

    insert_stmt = text("""
        INSERT INTO conversation_embeddings
            (id, tenant_id, agent_id, user_id, content, role, embedding)
        VALUES
            (gen_random_uuid(), :tenant_id, :agent_id, :user_id, :content, :role, CAST(:embedding AS vector))
    """)
    params = {
        "tenant_id": str(tenant_id),
        "agent_id": str(agent_id),
        "user_id": user_id,
        "content": content,
        "role": role,
        "embedding": vector_literal,
    }

    try:
        await session.execute(insert_stmt, params)
    except Exception:
        logger.exception(
            "pgvector store_embedding failed for tenant=%s agent=%s user=%s",
            tenant_id,
            agent_id,
            user_id,
        )
        raise
|
||||
111
packages/orchestrator/orchestrator/memory/short_term.py
Normal file
111
packages/orchestrator/orchestrator/memory/short_term.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Redis sliding window for short-term conversational memory.
|
||||
|
||||
Implements a RPUSH + LTRIM pattern:
|
||||
- RPUSH appends new messages to the right (tail) of the list
|
||||
- LTRIM trims the list to the last `window` entries
|
||||
- LRANGE retrieves all current entries
|
||||
|
||||
This gives O(1) append + amortized O(1) trim (at steady state each trim removes at most one element) + O(N) read where N <= window size.
|
||||
|
||||
Key format: {tenant_id}:memory:short:{agent_id}:{user_id}
|
||||
|
||||
Messages are stored as JSON objects with "role" and "content" keys,
|
||||
matching the OpenAI chat messages format for direct injection into
|
||||
the LLM messages array.
|
||||
|
||||
Design decisions:
|
||||
- No TTL: message retention is indefinite per user preference. If TTL-based
|
||||
expiry is needed in the future, add it via a separate expiry policy.
|
||||
- No compression: messages are stored as plain JSON. At 20 messages * ~200
|
||||
bytes average, storage per user/agent is ~4KB — negligible.
|
||||
- Parameterized window: callers control the window size, defaulting to 20.
|
||||
This allows future policy changes without code modification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from shared.redis_keys import memory_short_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_recent_messages(
    redis: object,
    tenant_id: str,
    agent_id: str,
    user_id: str,
    n: int = 20,
) -> list[dict[str, str]]:
    """
    Retrieve the most recent N messages from the sliding window.

    Returns messages in insertion order (oldest first) — this matches the
    expected LLM message array format where conversation flows chronologically.

    Args:
        redis: Redis async client (redis.asyncio.Redis or compatible).
        tenant_id: Konstruct tenant identifier.
        agent_id: Agent UUID string.
        user_id: End-user identifier (channel-native).
        n: Maximum number of messages to retrieve. Default 20.
            Pass a larger value than the window size to get all messages.
            Non-positive values return an empty list.

    Returns:
        List of message dicts with "role" and "content" keys, oldest first.
        Returns empty list if no messages exist for this key.
    """
    # FIX: -0 == 0, so LRANGE(key, -n, -1) with n == 0 would return the
    # ENTIRE list instead of nothing — the opposite of "at most n messages".
    # Guard non-positive n explicitly.
    if n <= 0:
        return []

    key = memory_short_key(tenant_id, agent_id, user_id)

    # LRANGE -n -1 returns the last n items in insertion order
    raw_messages = await redis.lrange(key, -n, -1)  # type: ignore[union-attr]

    messages: list[dict[str, str]] = []
    for raw in raw_messages:
        try:
            msg = json.loads(raw)
            messages.append({"role": str(msg["role"]), "content": str(msg["content"])})
        except (json.JSONDecodeError, KeyError, TypeError):
            # TypeError covers payloads that are valid JSON but not an object
            # (e.g. a bare list or number) — previously an uncaught crash.
            logger.warning("Malformed message in Redis key %s — skipping", key)

    return messages
|
||||
|
||||
|
||||
async def append_message(
    redis: object,
    tenant_id: str,
    agent_id: str,
    user_id: str,
    role: str,
    content: str,
    window: int = 20,
) -> None:
    """
    Append a message to the sliding window, then trim to ``window`` entries.

    RPUSH and LTRIM are issued through a single pipeline so the pair is sent
    atomically — concurrent writers cannot interleave between append and trim.

    Args:
        redis: Redis async client.
        tenant_id: Konstruct tenant identifier.
        agent_id: Agent UUID string.
        user_id: End-user identifier (channel-native).
        role: Message role: "user" or "assistant".
        content: Message text content.
        window: Maximum number of messages to retain. Default 20.
            After this call the list holds at most ``window`` entries
            (the most recent ones).
    """
    key = memory_short_key(tenant_id, agent_id, user_id)
    payload = json.dumps({"role": role, "content": content})

    # One pipeline round-trip: push, then keep only indices -window .. -1
    # (i.e. the most recent `window` entries).
    pipe = redis.pipeline()  # type: ignore[union-attr]
    pipe.rpush(key, payload)
    pipe.ltrim(key, -window, -1)
    await pipe.execute()
|
||||
@@ -12,6 +12,7 @@ dependencies = [
|
||||
"fastapi[standard]>=0.115.0",
|
||||
"celery[redis]>=5.4.0",
|
||||
"httpx>=0.28.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
||||
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"httpx>=0.28.0",
|
||||
"slowapi>=0.1.9",
|
||||
"bcrypt>=4.0.0",
|
||||
"pgvector>=0.3.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
|
||||
96
packages/shared/shared/models/memory.py
Normal file
96
packages/shared/shared/models/memory.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
SQLAlchemy 2.0 ORM models for conversational memory.
|
||||
|
||||
ConversationEmbedding stores pgvector embeddings of past conversation turns
|
||||
for long-term semantic retrieval across sessions. This is the persistence layer
|
||||
for the long-term memory module in the Agent Orchestrator.
|
||||
|
||||
IMPORTANT:
|
||||
- Embeddings are immutable (no UPDATE) — like audit records. We store and read
|
||||
but never modify. This simplifies the data model and prevents mutation bugs.
|
||||
- RLS is ENABLED with FORCE — tenant_id isolation is enforced at the DB level.
|
||||
- The vector dimension (384) corresponds to all-MiniLM-L6-v2 output size.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import DateTime, ForeignKey, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from shared.models.tenant import Base
|
||||
|
||||
|
||||
class ConversationEmbedding(Base):
    """
    A single embedded conversation turn stored for long-term recall.

    Each row represents one message (user or assistant) converted to a
    384-dimensional embedding via all-MiniLM-L6-v2. The Agent Orchestrator
    queries this table at prompt assembly time to inject relevant past context.

    Rows are written once and never updated — conversation history is treated
    as an append-only log (the migration grants SELECT/INSERT only).

    Scoped by:
        - tenant_id: RLS enforced isolation between tenants
        - agent_id: isolation between agents within a tenant
        - user_id: isolation between end-users of the same agent

    RLS policy enforces:
        tenant_id = current_setting('app.current_tenant', TRUE)::uuid

    FORCE ROW LEVEL SECURITY ensures even the table owner cannot bypass this.
    """

    __tablename__ = "conversation_embeddings"

    # Surrogate primary key. Generated client-side (uuid4); the migration also
    # sets a server-side gen_random_uuid() default for raw-SQL inserts.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Tenant owning this row — the RLS isolation key. Cascades on tenant delete.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Agent within the tenant; no FK here — presumably agents live in another
    # service's schema (TODO confirm).
    agent_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )
    # Free-text because the identifier format is channel-specific.
    user_id: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="Channel-native user identifier (e.g. Slack user ID U12345)",
    )
    # Kept alongside the vector so retrieval can return readable text without
    # a join or a decode step.
    content: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="Original message text that was embedded",
    )
    role: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        comment="Message role: 'user' or 'assistant'",
    )
    # Dimension must match the embedding model's output size exactly.
    embedding: Mapped[list[float]] = mapped_column(
        Vector(384),
        nullable=False,
        comment="all-MiniLM-L6-v2 embedding (384 dimensions)",
    )
    # Server-side timestamp so clock skew between app instances is irrelevant.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    def __repr__(self) -> str:
        # Identifying fields only — never the content or the 384-float vector.
        return (
            f"<ConversationEmbedding id={self.id} "
            f"tenant_id={self.tenant_id} agent_id={self.agent_id} "
            f"user_id={self.user_id!r} role={self.role!r}>"
        )
|
||||
@@ -86,3 +86,61 @@ def engaged_thread_key(tenant_id: str, thread_id: str) -> str:
|
||||
Namespaced Redis key: "{tenant_id}:engaged:{thread_id}"
|
||||
"""
|
||||
return f"{tenant_id}:engaged:{thread_id}"
|
||||
|
||||
|
||||
def memory_short_key(tenant_id: str, agent_id: str, user_id: str) -> str:
    """
    Build the Redis key for the short-term conversational memory window.

    Holds the last N messages (JSON-serialized) for one tenant + agent + user
    combination; the Agent Orchestrator reads it to inject recent history into
    every LLM prompt.

    All three discriminators are part of the key so that:
        - two users of the same agent keep separate histories,
        - one user talking to two agents keeps separate histories,
        - identical agent/user IDs under different tenants stay isolated.

    Args:
        tenant_id: Konstruct tenant identifier.
        agent_id: Agent identifier (UUID string).
        user_id: End-user identifier (channel-native, e.g. Slack user ID).

    Returns:
        Namespaced Redis key: "{tenant_id}:memory:short:{agent_id}:{user_id}"
    """
    return ":".join((tenant_id, "memory:short", agent_id, user_id))
|
||||
|
||||
|
||||
def escalation_status_key(tenant_id: str, thread_id: str) -> str:
    """
    Build the Redis key tracking a thread's escalation status.

    The value records whether the conversation has been handed off to a
    human or to another agent.

    Args:
        tenant_id: Konstruct tenant identifier.
        thread_id: Thread identifier.

    Returns:
        Namespaced Redis key: "{tenant_id}:escalation:{thread_id}"
    """
    return ":".join((tenant_id, "escalation", thread_id))
|
||||
|
||||
|
||||
def pending_tool_confirm_key(tenant_id: str, thread_id: str) -> str:
    """
    Build the Redis key for a thread's pending tool confirmation.

    The value holds the tool invocation awaiting explicit user approval
    before execution (e.g. destructive operations).

    Args:
        tenant_id: Konstruct tenant identifier.
        thread_id: Thread identifier.

    Returns:
        Namespaced Redis key: "{tenant_id}:tool_confirm:{thread_id}"
    """
    return ":".join((tenant_id, "tool_confirm", thread_id))
|
||||
|
||||
259
tests/integration/test_memory_long_term.py
Normal file
259
tests/integration/test_memory_long_term.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Integration tests for pgvector long-term memory.
|
||||
|
||||
Requires a live PostgreSQL instance with pgvector extension installed.
|
||||
Tests are automatically skipped if the database is not available
|
||||
(fixture from conftest.py handles that via pytest.skip).
|
||||
|
||||
Key scenarios tested:
|
||||
- store_embedding inserts with correct scoping
|
||||
- retrieve_relevant returns matching content above threshold
|
||||
- Cross-tenant isolation: tenant A's embeddings never returned for tenant B
|
||||
- High threshold returns empty list for dissimilar queries
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from orchestrator.memory.long_term import retrieve_relevant, store_embedding
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def agent_a_id() -> uuid.UUID:
    """Return a stable agent UUID for tenant A tests.

    Plain synchronous fixture — the value involves no awaiting, so the
    pytest_asyncio async-fixture machinery is unnecessary here.
    """
    return uuid.UUID("aaaaaaaa-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
@pytest.fixture
def agent_b_id() -> uuid.UUID:
    """Return a stable agent UUID for tenant B tests.

    Plain synchronous fixture — the value involves no awaiting, so the
    pytest_asyncio async-fixture machinery is unnecessary here.
    """
    return uuid.UUID("bbbbbbbb-0000-0000-0000-000000000002")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_store_embedding_inserts_row(
    db_session: AsyncSession,
    tenant_a: dict,
    agent_a_id: uuid.UUID,
):
    """store_embedding inserts a row into conversation_embeddings."""
    from shared.rls import current_tenant_id

    tenant_id = tenant_a["id"]
    user_id = "user-store-test"
    vec = [0.1] * 384
    content = "I prefer concise answers."

    token = current_tenant_id.set(tenant_id)
    try:
        await store_embedding(db_session, tenant_id, agent_a_id, user_id, content, "user", vec)
        await db_session.commit()

        query = text(
            "SELECT content, role FROM conversation_embeddings WHERE tenant_id = :tid AND user_id = :uid"
        )
        rows = (
            await db_session.execute(query, {"tid": str(tenant_id), "uid": user_id})
        ).fetchall()
    finally:
        current_tenant_id.reset(token)

    # Exactly one row, with the original text and role preserved.
    assert len(rows) == 1
    assert rows[0].content == content
    assert rows[0].role == "user"
|
||||
|
||||
|
||||
async def test_retrieve_relevant_returns_similar_content(
    db_session: AsyncSession,
    tenant_a: dict,
    agent_a_id: uuid.UUID,
):
    """retrieve_relevant returns content above cosine similarity threshold."""
    from shared.rls import current_tenant_id

    tenant_id = tenant_a["id"]
    user_id = "user-retrieve-test"

    # Similarity is simulated with hand-built unit vectors: one identical to
    # the query, one orthogonal to it.
    aligned_vec = [1.0] + [0.0] * 383
    orthogonal_vec = [0.0] * 383 + [1.0]
    query_vec = [1.0] + [0.0] * 383  # identical to aligned_vec

    token = current_tenant_id.set(tenant_id)
    try:
        await store_embedding(
            db_session, tenant_id, agent_a_id, user_id,
            "The user likes Python programming.", "user", aligned_vec,
        )
        await store_embedding(
            db_session, tenant_id, agent_a_id, user_id,
            "This is completely unrelated content.", "user", orthogonal_vec,
        )
        await db_session.commit()

        results = await retrieve_relevant(
            db_session, tenant_id, agent_a_id, user_id, query_vec, top_k=3, threshold=0.5,
        )
    finally:
        current_tenant_id.reset(token)

    # The aligned content must come back.
    assert len(results) >= 1
    assert any("Python" in item for item in results)
|
||||
|
||||
|
||||
async def test_retrieve_relevant_high_threshold_returns_empty(
    db_session: AsyncSession,
    tenant_a: dict,
    agent_a_id: uuid.UUID,
):
    """retrieve_relevant with threshold=0.99 and dissimilar query returns empty list."""
    from shared.rls import current_tenant_id

    tenant_id = tenant_a["id"]
    user_id = "user-threshold-test"

    # Stored and query vectors are orthogonal unit vectors: cosine distance
    # ~= 1.0, similarity ~= 0.0 — far below the 0.99 threshold.
    stored_vec = [1.0] + [0.0] * 383
    query_vec = [0.0] + [1.0] + [0.0] * 382

    token = current_tenant_id.set(tenant_id)
    try:
        await store_embedding(
            db_session, tenant_id, agent_a_id, user_id,
            "Some stored content.", "user", stored_vec,
        )
        await db_session.commit()

        results = await retrieve_relevant(
            db_session, tenant_id, agent_a_id, user_id, query_vec, top_k=3, threshold=0.99,
        )
    finally:
        current_tenant_id.reset(token)

    assert results == []
|
||||
|
||||
|
||||
async def test_cross_tenant_isolation(
    db_session: AsyncSession,
    tenant_a: dict,
    tenant_b: dict,
    agent_a_id: uuid.UUID,
    agent_b_id: uuid.UUID,
):
    """
    retrieve_relevant with tenant_id=A NEVER returns tenant_id=B embeddings.

    This is the critical security test — cross-tenant contamination would be
    a catastrophic data leak in a multi-tenant system.
    """
    from shared.rls import current_tenant_id

    user_id = "shared-user-id"
    id_a = tenant_a["id"]
    id_b = tenant_b["id"]

    # Identical vector on both the write and read side.
    vec = [1.0] + [0.0] * 383

    # Seed an embedding under tenant B.
    token = current_tenant_id.set(id_b)
    try:
        await store_embedding(
            db_session, id_b, agent_b_id, user_id,
            "Tenant B secret information.", "user", vec,
        )
        await db_session.commit()
    finally:
        current_tenant_id.reset(token)

    # Query as tenant A with a wide-open threshold — still must see none of B's data.
    token = current_tenant_id.set(id_a)
    try:
        results = await retrieve_relevant(
            db_session, id_a, agent_a_id, user_id, vec, top_k=10, threshold=0.0,
        )
    finally:
        current_tenant_id.reset(token)

    # Tenant A has no embeddings of its own and MUST NOT see tenant B's.
    for item in results:
        assert "Tenant B" not in item, "Cross-tenant data leakage detected!"
|
||||
|
||||
|
||||
async def test_retrieve_relevant_user_isolation(
    db_session: AsyncSession,
    tenant_a: dict,
    agent_a_id: uuid.UUID,
):
    """retrieve_relevant for user A never returns user B embeddings."""
    from shared.rls import current_tenant_id

    tenant_id = tenant_a["id"]
    embedding = [1.0] + [0.0] * 383

    token = current_tenant_id.set(tenant_id)
    try:
        # Seed a single embedding that belongs to user A only.
        await store_embedding(
            db_session, tenant_id, agent_a_id, "user-A",
            "User A private information.", "user", embedding
        )
        await db_session.commit()

        # Retrieval scoped to user B must not surface user A's content.
        results = await retrieve_relevant(
            db_session, tenant_id, agent_a_id, "user-B", embedding, top_k=10, threshold=0.0
        )
    finally:
        current_tenant_id.reset(token)

    assert all("User A private" not in item for item in results)
async def test_retrieve_relevant_top_k_limits_results(
    db_session: AsyncSession,
    tenant_a: dict,
    agent_a_id: uuid.UUID,
):
    """retrieve_relevant respects top_k limit."""
    from shared.rls import current_tenant_id

    tenant_id = tenant_a["id"]
    user_id = "user-topk-test"
    embedding = [1.0] + [0.0] * 383

    token = current_tenant_id.set(tenant_id)
    try:
        # Insert five near-identical embeddings so every row is a candidate.
        for idx in range(5):
            await store_embedding(
                db_session, tenant_id, agent_a_id, user_id,
                f"Content item {idx}", "user", embedding
            )
        await db_session.commit()

        results = await retrieve_relevant(
            db_session, tenant_id, agent_a_id, user_id, embedding, top_k=2, threshold=0.0
        )
    finally:
        current_tenant_id.reset(token)

    # top_k=2 caps the result count regardless of how many rows matched.
    assert len(results) <= 2
--- new file: tests/unit/test_memory_short_term.py (139 lines) ---
|
||||
"""
|
||||
Unit tests for the Redis short-term memory sliding window.
|
||||
|
||||
Uses fakeredis to avoid requiring a real Redis connection.
|
||||
All tests verify tenant+agent+user namespacing and RPUSH/LTRIM correctness.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import fakeredis.aioredis
|
||||
import pytest
|
||||
|
||||
from orchestrator.memory.short_term import append_message, get_recent_messages
|
||||
|
||||
|
||||
@pytest.fixture
async def redis():
    """Yield an in-memory fakeredis async client, closed after the test."""
    fake = fakeredis.aioredis.FakeRedis()
    try:
        yield fake
    finally:
        await fake.aclose()
# Fixed identifiers shared by the tests below; individual tests substitute
# their own values where tenant/agent/user isolation is being exercised.
TENANT = "tenant-abc"
AGENT = "agent-xyz"
USER = "user-123"
async def test_get_recent_messages_empty(redis):
    """get_recent_messages on empty key returns empty list."""
    assert await get_recent_messages(redis, TENANT, AGENT, USER) == []
async def test_append_and_get_single_message(redis):
    """append_message stores a message; get_recent_messages retrieves it."""
    await append_message(redis, TENANT, AGENT, USER, role="user", content="Hello!")

    messages = await get_recent_messages(redis, TENANT, AGENT, USER)
    # Exactly the one message we wrote, with role and content intact.
    assert messages == [{"role": "user", "content": "Hello!"}]
async def test_append_multiple_messages_ordering(redis):
    """Messages are returned in insertion order (oldest first)."""
    for role, text in (("user", "First"), ("assistant", "Second"), ("user", "Third")):
        await append_message(redis, TENANT, AGENT, USER, role=role, content=text)

    messages = await get_recent_messages(redis, TENANT, AGENT, USER)
    # List equality pins both the count (3) and the oldest-first ordering.
    assert [m["content"] for m in messages] == ["First", "Second", "Third"]
async def test_sliding_window_trims_to_window_size(redis):
    """append_message with window=5 keeps only last 5 messages."""
    for i in range(10):
        await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}", window=5)

    # Request more than the window holds to prove the trim already happened.
    messages = await get_recent_messages(redis, TENANT, AGENT, USER, n=20)
    # Only the newest five survive: msg-5 through msg-9, oldest first.
    assert [m["content"] for m in messages] == ["msg-5", "msg-6", "msg-7", "msg-8", "msg-9"]
async def test_default_window_20(redis):
    """Default window is 20 — 21st message pushes out the first."""
    for i in range(21):
        await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}")

    messages = await get_recent_messages(redis, TENANT, AGENT, USER)
    assert len(messages) == 20
    # msg-0 was evicted; the window now spans msg-1 .. msg-20.
    assert messages[0]["content"] == "msg-1"
    assert messages[-1]["content"] == "msg-20"
async def test_get_recent_messages_n_parameter(redis):
    """get_recent_messages n parameter limits results."""
    for i in range(10):
        await append_message(redis, TENANT, AGENT, USER, role="user", content=f"msg-{i}")

    tail = await get_recent_messages(redis, TENANT, AGENT, USER, n=3)
    # n selects from the end of the list: the 3 most recent messages.
    assert [m["content"] for m in tail] == ["msg-7", "msg-8", "msg-9"]
async def test_key_namespacing_user_isolation(redis):
    """Different users of the same agent have isolated memory."""
    await append_message(redis, TENANT, AGENT, "user-A", role="user", content="User A message")
    await append_message(redis, TENANT, AGENT, "user-B", role="user", content="User B message")

    # Each user's window contains only that user's own single message.
    assert await get_recent_messages(redis, TENANT, AGENT, "user-A") == [
        {"role": "user", "content": "User A message"}
    ]
    assert await get_recent_messages(redis, TENANT, AGENT, "user-B") == [
        {"role": "user", "content": "User B message"}
    ]
async def test_key_namespacing_tenant_isolation(redis):
    """Different tenants with same agent+user IDs have isolated memory."""
    expected = {"tenant-1": "Tenant 1 message", "tenant-2": "Tenant 2 message"}
    for tenant, text in expected.items():
        await append_message(redis, tenant, AGENT, USER, role="user", content=text)

    # Each tenant-scoped key only ever yields its own message.
    for tenant, text in expected.items():
        messages = await get_recent_messages(redis, tenant, AGENT, USER)
        assert messages[0]["content"] == text
async def test_key_namespacing_agent_isolation(redis):
    """Different agents for the same tenant+user have isolated memory."""
    await append_message(redis, TENANT, "agent-1", USER, role="user", content="Agent 1 context")
    await append_message(redis, TENANT, "agent-2", USER, role="user", content="Agent 2 context")

    first = await get_recent_messages(redis, TENANT, "agent-1", USER)
    second = await get_recent_messages(redis, TENANT, "agent-2", USER)

    # No cross-agent bleed: each agent's window holds only its own entry.
    assert first[0]["content"] == "Agent 1 context"
    assert second[0]["content"] == "Agent 2 context"
async def test_message_role_and_content_round_trip(redis):
    """Messages store and retrieve role + content correctly."""
    await append_message(redis, TENANT, AGENT, USER, role="assistant", content="I can help you with that.")

    (msg,) = await get_recent_messages(redis, TENANT, AGENT, USER)
    # Dict equality pins role, content, AND that no extra keys leaked in.
    assert msg == {"role": "assistant", "content": "I can help you with that."}
Reference in New Issue
Block a user