- memory_store.py: User-isolated observation storage with vector embeddings
- New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline (wiring sketched below)
- Progressive disclosure pattern for token-efficient retrieval
- Updated Dockerfile to ROCm 7.2 nightly
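A minimal sketch of how those endpoints might wire onto the MemoryStore class in memory_store.py below. The FastAPI app, request models, the _NoopRAG placeholder, and the database path are illustrative assumptions, not part of this change; only MemoryStore and its method names come from the file.

from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel

from memory_store import MemoryStore


class _NoopRAG:
    """Placeholder so the sketch runs standalone; the real service injects its RAG engine."""

    def ingest(self, **kwargs):
        return {"doc_id": None}

    def query(self, **kwargs):
        return {"results": []}


app = FastAPI()
store = MemoryStore(db_path="/tmp/moxie_memory.db", rag_engine=_NoopRAG())  # illustrative path


class SaveRequest(BaseModel):
    user_id: str
    content: str
    type: str = "general"
    title: Optional[str] = None


class QueryRequest(BaseModel):
    user_id: str
    query: str
    top_k: int = 10
    include_content: bool = False  # progressive disclosure: lightweight index by default


class GetRequest(BaseModel):
    user_id: str
    ids: List[int]


@app.post("/memory/save")
def memory_save(req: SaveRequest):
    return store.save_observation(user_id=req.user_id, content=req.content,
                                  obs_type=req.type, title=req.title)


@app.post("/memory/query")
def memory_query(req: QueryRequest):
    return store.query_memory(user_id=req.user_id, query=req.query,
                              top_k=req.top_k, include_content=req.include_content)


@app.post("/memory/get")
def memory_get(req: GetRequest):
    return store.get_observations(user_id=req.user_id, ids=req.ids)


@app.get("/memory/timeline")
def memory_timeline(user_id: str, around_id: Optional[int] = None, window_minutes: int = 30):
    return store.get_timeline(user_id=user_id, around_id=around_id, window_minutes=window_minutes)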
#!/usr/bin/env python3
"""
Memory Store — SQLite + ChromaDB hybrid for agent observations.
Provides structured storage with vector search.
"""

import hashlib
import json
import logging
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger("moxie-rag.memory")


class MemoryStore:
    """SQLite-backed memory store with ChromaDB integration."""

    def __init__(self, db_path: str, rag_engine):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.rag_engine = rag_engine
        self._init_db()

    @contextmanager
    def _get_conn(self):
        """Thread-safe connection context manager."""
        conn = sqlite3.connect(str(self.db_path), timeout=30.0)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()
    def _init_db(self):
        """Initialize the SQLite schema."""
        with self._get_conn() as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    type TEXT NOT NULL,
                    title TEXT,
                    content TEXT NOT NULL,
                    content_hash TEXT,
                    embedding_id TEXT,
                    session_id TEXT,
                    tool_name TEXT,
                    importance INTEGER DEFAULT 1,
                    tags TEXT,
                    metadata TEXT,
                    UNIQUE(user_id, content_hash)  -- dedupe per user, keeping observations user-isolated
                );

                CREATE INDEX IF NOT EXISTS idx_obs_user ON observations(user_id);
                CREATE INDEX IF NOT EXISTS idx_obs_type ON observations(type);
                CREATE INDEX IF NOT EXISTS idx_obs_timestamp ON observations(timestamp);
                CREATE INDEX IF NOT EXISTS idx_obs_session ON observations(session_id);

                CREATE TABLE IF NOT EXISTS preferences (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    key TEXT NOT NULL,
                    value TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(user_id, key)
                );

                CREATE INDEX IF NOT EXISTS idx_pref_user ON preferences(user_id);

                CREATE TABLE IF NOT EXISTS relationships (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    observation_id INTEGER,
                    related_id INTEGER,
                    relation_type TEXT,
                    FOREIGN KEY (observation_id) REFERENCES observations(id),
                    FOREIGN KEY (related_id) REFERENCES observations(id)
                );
            """)
            conn.commit()
        logger.info(f"Memory store initialized at {self.db_path}")
    def _content_hash(self, content: str) -> str:
        """Generate hash for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _get_collection_name(self, user_id: str) -> str:
        """Get ChromaDB collection name for user."""
        return f"moxie_memory_{user_id}"
    def save_observation(
        self,
        user_id: str,
        content: str,
        obs_type: str = "general",
        title: Optional[str] = None,
        session_id: Optional[str] = None,
        tool_name: Optional[str] = None,
        importance: int = 1,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Save an observation to SQLite and embed in ChromaDB.
        Returns the observation ID and embedding status.
        """
        content_hash = self._content_hash(content)
        collection = self._get_collection_name(user_id)

        # Check for a duplicate within this user's observations
        with self._get_conn() as conn:
            existing = conn.execute(
                "SELECT id FROM observations WHERE user_id = ? AND content_hash = ?",
                (user_id, content_hash),
            ).fetchone()

        if existing:
            return {
                "status": "duplicate",
                "observation_id": existing["id"],
                "message": "Observation already exists",
            }

        # Embed in ChromaDB
        embed_result = self.rag_engine.ingest(
            content=content,
            title=title or f"Observation: {obs_type}",
            source=f"memory:{user_id}:{obs_type}",
            doc_type="observation",
            collection=collection,
        )
        embedding_id = embed_result.get("doc_id")

        # Store in SQLite
        tags_str = ",".join(tags) if tags else None
        metadata_str = json.dumps(metadata) if metadata else None

        with self._get_conn() as conn:
            cursor = conn.execute("""
                INSERT INTO observations
                    (user_id, type, title, content, content_hash, embedding_id,
                     session_id, tool_name, importance, tags, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                user_id, obs_type, title, content, content_hash, embedding_id,
                session_id, tool_name, importance, tags_str, metadata_str,
            ))
            conn.commit()
            obs_id = cursor.lastrowid

        logger.info(f"Saved observation #{obs_id} for user {user_id} (type: {obs_type})")
        return {
            "status": "created",
            "observation_id": obs_id,
            "embedding_id": embedding_id,
            "collection": collection,
        }
    def query_memory(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        obs_type: Optional[str] = None,
        since: Optional[str] = None,
        include_content: bool = False,
    ) -> Dict[str, Any]:
        """
        Search memory using hybrid SQLite + vector search.
        Progressive disclosure: returns index by default, full content if requested.
        """
        collection = self._get_collection_name(user_id)

        # Vector search in ChromaDB
        vector_results = self.rag_engine.query(
            question=query,
            top_k=top_k * 2,  # Get more for filtering
            collection=collection,
        )

        # Get observation IDs from embedding IDs
        embedding_ids = [
            r.get("metadata", {}).get("doc_id")
            for r in vector_results.get("results", [])
        ]

        if not embedding_ids:
            return {"results": [], "total": 0, "query": query}

        # Fetch from SQLite with filters
        placeholders = ",".join(["?" for _ in embedding_ids])
        sql = f"""
            SELECT id, user_id, timestamp, type, title, importance, tags, tool_name
                   {", content" if include_content else ""}
            FROM observations
            WHERE user_id = ? AND embedding_id IN ({placeholders})
        """
        params = [user_id] + embedding_ids

        if obs_type:
            sql += " AND type = ?"
            params.append(obs_type)

        if since:
            sql += " AND timestamp >= ?"
            params.append(since)

        sql += " ORDER BY timestamp DESC LIMIT ?"
        params.append(top_k)

        with self._get_conn() as conn:
            rows = conn.execute(sql, params).fetchall()

        results = []
        for row in rows:
            item = {
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
            }
            if include_content:
                item["content"] = row["content"]
            results.append(item)

        return {
            "results": results,
            "total": len(results),
            "query": query,
            "collection": collection,
        }
    def get_observations(
        self,
        user_id: str,
        ids: List[int],
    ) -> Dict[str, Any]:
        """Fetch full observation details by IDs."""
        if not ids:
            return {"observations": []}

        placeholders = ",".join(["?" for _ in ids])
        sql = f"""
            SELECT * FROM observations
            WHERE user_id = ? AND id IN ({placeholders})
            ORDER BY timestamp DESC
        """

        with self._get_conn() as conn:
            rows = conn.execute(sql, [user_id] + ids).fetchall()

        observations = []
        for row in rows:
            observations.append({
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "content": row["content"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
                "session_id": row["session_id"],
                "metadata": row["metadata"],
            })

        return {"observations": observations, "count": len(observations)}
    def get_timeline(
        self,
        user_id: str,
        around_id: Optional[int] = None,
        around_time: Optional[str] = None,
        window_minutes: int = 30,
        limit: int = 20,
    ) -> Dict[str, Any]:
        """Get chronological context around a specific observation or time."""
        with self._get_conn() as conn:
            if around_id:
                # Get timestamp of reference observation
                ref = conn.execute(
                    "SELECT timestamp FROM observations WHERE id = ? AND user_id = ?",
                    (around_id, user_id),
                ).fetchone()
                if not ref:
                    return {"error": "Observation not found", "timeline": []}
                center_time = ref["timestamp"]
            elif around_time:
                center_time = around_time
            else:
                # Stored timestamps come from CURRENT_TIMESTAMP, which is UTC
                center_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

            # Get observations in time window
            rows = conn.execute("""
                SELECT id, timestamp, type, title, importance, tool_name
                FROM observations
                WHERE user_id = ?
                  AND datetime(timestamp) BETWEEN
                      datetime(?, '-' || ? || ' minutes')
                      AND datetime(?, '+' || ? || ' minutes')
                ORDER BY timestamp
                LIMIT ?
            """, (user_id, center_time, window_minutes, center_time, window_minutes, limit)).fetchall()

        timeline = [{
            "id": row["id"],
            "timestamp": row["timestamp"],
            "type": row["type"],
            "title": row["title"],
            "importance": row["importance"],
            "tool_name": row["tool_name"],
        } for row in rows]

        return {
            "timeline": timeline,
            "center_time": center_time,
            "window_minutes": window_minutes,
            "count": len(timeline),
        }
    def save_preference(
        self,
        user_id: str,
        key: str,
        value: str,
    ) -> Dict[str, Any]:
        """Save or update a user preference."""
        with self._get_conn() as conn:
            conn.execute("""
                INSERT INTO preferences (user_id, key, value)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id, key) DO UPDATE SET
                    value = excluded.value,
                    timestamp = CURRENT_TIMESTAMP
            """, (user_id, key, value))
            conn.commit()

        return {"status": "saved", "user_id": user_id, "key": key}
    def get_preferences(self, user_id: str) -> Dict[str, str]:
        """Get all preferences for a user."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT key, value FROM preferences WHERE user_id = ?",
                (user_id,),
            ).fetchall()

        return {row["key"]: row["value"] for row in rows}
    def get_stats(self, user_id: str) -> Dict[str, Any]:
        """Get memory statistics for a user."""
        with self._get_conn() as conn:
            total = conn.execute(
                "SELECT COUNT(*) as c FROM observations WHERE user_id = ?",
                (user_id,),
            ).fetchone()["c"]

            by_type = conn.execute("""
                SELECT type, COUNT(*) as c
                FROM observations WHERE user_id = ?
                GROUP BY type
            """, (user_id,)).fetchall()

            recent = conn.execute("""
                SELECT COUNT(*) as c FROM observations
                WHERE user_id = ? AND timestamp >= datetime('now', '-7 days')
            """, (user_id,)).fetchone()["c"]

        return {
            "user_id": user_id,
            "total_observations": total,
            "by_type": {row["type"]: row["c"] for row in by_type},
            "last_7_days": recent,
            "collection": self._get_collection_name(user_id),
        }
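A short usage sketch of the progressive-disclosure flow against the class above. The stub engine, database path, and user ID are illustrative assumptions; a real deployment injects the service's own RAG engine.

class _StubRAG:
    """Stands in for the injected RAG engine; mimics only the ingest()/query() calls used above."""

    def ingest(self, content, title, source, doc_type, collection):
        return {"doc_id": f"stub-{hash(content) & 0xFFFF}"}

    def query(self, question, top_k, collection):
        return {"results": []}  # a real engine returns matches carrying metadata["doc_id"]


store = MemoryStore(db_path="/tmp/moxie_memory.db", rag_engine=_StubRAG())

# 1. Save an observation.
store.save_observation(
    user_id="alice",
    content="Prefers concise answers with code examples",
    obs_type="preference_note",
    importance=3,
    tags=["style"],
)

# 2. Progressive disclosure: query for a lightweight index (no content field)...
index = store.query_memory(user_id="alice", query="answer style", top_k=5)

# 3. ...then fetch full content only for the IDs that look relevant.
relevant_ids = [item["id"] for item in index["results"]]
full = store.get_observations(user_id="alice", ids=relevant_ids)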