moxie-rag/app/memory_store.py
Adolfo Delorenzo 76d8f9349e feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: User-isolated observation storage with vector embeddings
- New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline
- Progressive disclosure pattern for token-efficient retrieval
- Updated Dockerfile to ROCm 7.2 nightly
2026-02-09 15:42:43 -06:00

#!/usr/bin/env python3
"""
Memory Store — SQLite + ChromaDB hybrid for agent observations.
Provides structured storage with vector search.
"""
import sqlite3
import hashlib
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, List, Dict, Any
from contextlib import contextmanager

logger = logging.getLogger("moxie-rag.memory")


class MemoryStore:
"""SQLite-backed memory store with ChromaDB integration."""
def __init__(self, db_path: str, rag_engine):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.rag_engine = rag_engine
self._init_db()

    @contextmanager
    def _get_conn(self):
        """Yield a short-lived, per-call connection (safe to use across threads)."""
        conn = sqlite3.connect(str(self.db_path), timeout=30.0)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Initialize the SQLite schema."""
        with self._get_conn() as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    type TEXT NOT NULL,
                    title TEXT,
                    content TEXT NOT NULL,
                    content_hash TEXT,
                    embedding_id TEXT,
                    session_id TEXT,
                    tool_name TEXT,
                    importance INTEGER DEFAULT 1,
                    tags TEXT,
                    metadata TEXT,
                    -- Dedup per user so identical content from two users stays isolated
                    UNIQUE(user_id, content_hash)
                );
                CREATE INDEX IF NOT EXISTS idx_obs_user ON observations(user_id);
                CREATE INDEX IF NOT EXISTS idx_obs_type ON observations(type);
                CREATE INDEX IF NOT EXISTS idx_obs_timestamp ON observations(timestamp);
                CREATE INDEX IF NOT EXISTS idx_obs_session ON observations(session_id);
                CREATE TABLE IF NOT EXISTS preferences (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    key TEXT NOT NULL,
                    value TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(user_id, key)
                );
                CREATE INDEX IF NOT EXISTS idx_pref_user ON preferences(user_id);
                CREATE TABLE IF NOT EXISTS relationships (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    observation_id INTEGER,
                    related_id INTEGER,
                    relation_type TEXT,
                    FOREIGN KEY (observation_id) REFERENCES observations(id),
                    FOREIGN KEY (related_id) REFERENCES observations(id)
                );
            """)
            conn.commit()
        logger.info(f"Memory store initialized at {self.db_path}")

    def _content_hash(self, content: str) -> str:
        """Generate hash for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _get_collection_name(self, user_id: str) -> str:
        """Get ChromaDB collection name for user."""
        return f"moxie_memory_{user_id}"

    def save_observation(
        self,
        user_id: str,
        content: str,
        obs_type: str = "general",
        title: Optional[str] = None,
        session_id: Optional[str] = None,
        tool_name: Optional[str] = None,
        importance: int = 1,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Save an observation to SQLite and embed it in ChromaDB.
        Returns the observation ID and embedding status.
        """
        content_hash = self._content_hash(content)
        collection = self._get_collection_name(user_id)
        # Check for a duplicate, scoped to this user
        with self._get_conn() as conn:
            existing = conn.execute(
                "SELECT id FROM observations WHERE user_id = ? AND content_hash = ?",
                (user_id, content_hash),
            ).fetchone()
            if existing:
                return {
                    "status": "duplicate",
                    "observation_id": existing["id"],
                    "message": "Observation already exists",
                }
        # Embed in ChromaDB
        embed_result = self.rag_engine.ingest(
            content=content,
            title=title or f"Observation: {obs_type}",
            source=f"memory:{user_id}:{obs_type}",
            doc_type="observation",
            collection=collection,
        )
        embedding_id = embed_result.get("doc_id")
        # Store in SQLite (metadata as JSON so it round-trips on read)
        tags_str = ",".join(tags) if tags else None
        metadata_str = json.dumps(metadata) if metadata else None
        with self._get_conn() as conn:
            cursor = conn.execute("""
                INSERT INTO observations
                    (user_id, type, title, content, content_hash, embedding_id,
                     session_id, tool_name, importance, tags, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                user_id, obs_type, title, content, content_hash, embedding_id,
                session_id, tool_name, importance, tags_str, metadata_str,
            ))
            conn.commit()
            obs_id = cursor.lastrowid
        logger.info(f"Saved observation #{obs_id} for user {user_id} (type: {obs_type})")
        return {
            "status": "created",
            "observation_id": obs_id,
            "embedding_id": embedding_id,
            "collection": collection,
        }
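
    # Example call (a sketch; the "alice" user and return values are
    # illustrative, and "embedding_id" assumes rag_engine.ingest() returns a
    # "doc_id" key, as the code above expects):
    #
    #   store.save_observation(
    #       user_id="alice",
    #       content="Prefers terse answers with code first",
    #       obs_type="preference",
    #       tags=["style"],
    #   )
    #   # -> {"status": "created", "observation_id": 1,
    #   #     "embedding_id": "...", "collection": "moxie_memory_alice"}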

    def query_memory(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        obs_type: Optional[str] = None,
        since: Optional[str] = None,
        include_content: bool = False,
    ) -> Dict[str, Any]:
        """
        Search memory using hybrid SQLite + vector search.
        Progressive disclosure: returns a lightweight index by default;
        pass include_content=True for full content.
        """
        collection = self._get_collection_name(user_id)
        # Vector search in ChromaDB (over-fetch so SQL-side filters can still
        # leave up to top_k rows)
        vector_results = self.rag_engine.query(
            question=query,
            top_k=top_k * 2,
            collection=collection,
        )
        # Map embedding IDs back to SQLite rows, dropping hits without one
        embedding_ids = [
            eid for eid in (
                r.get("metadata", {}).get("doc_id")
                for r in vector_results.get("results", [])
            ) if eid
        ]
        if not embedding_ids:
            return {"results": [], "total": 0, "query": query}
        # Fetch from SQLite with filters
        placeholders = ",".join("?" for _ in embedding_ids)
        content_col = ", content" if include_content else ""
        sql = f"""
            SELECT id, user_id, timestamp, type, title, importance, tags, tool_name{content_col}
            FROM observations
            WHERE user_id = ? AND embedding_id IN ({placeholders})
        """
        params: List[Any] = [user_id] + embedding_ids
        if obs_type:
            sql += " AND type = ?"
            params.append(obs_type)
        if since:
            sql += " AND timestamp >= ?"
            params.append(since)
        sql += " ORDER BY timestamp DESC LIMIT ?"
        params.append(top_k)
        with self._get_conn() as conn:
            rows = conn.execute(sql, params).fetchall()
        results = []
        for row in rows:
            item = {
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
            }
            if include_content:
                item["content"] = row["content"]
            results.append(item)
        return {
            "results": results,
            "total": len(results),
            "query": query,
            "collection": collection,
        }
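
    # Progressive disclosure in practice (IDs illustrative): query the cheap
    # index first, then hydrate only the hits worth the tokens:
    #
    #   index = store.query_memory("alice", "answer style", top_k=5)
    #   ids = [r["id"] for r in index["results"]]
    #   full = store.get_observations("alice", ids)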

    def get_observations(
        self,
        user_id: str,
        ids: List[int],
    ) -> Dict[str, Any]:
        """Fetch full observation details by IDs."""
        if not ids:
            return {"observations": []}
        placeholders = ",".join("?" for _ in ids)
        sql = f"""
            SELECT * FROM observations
            WHERE user_id = ? AND id IN ({placeholders})
            ORDER BY timestamp DESC
        """
        with self._get_conn() as conn:
            rows = conn.execute(sql, [user_id] + list(ids)).fetchall()
        observations = []
        for row in rows:
            observations.append({
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "content": row["content"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
                "session_id": row["session_id"],
                "metadata": json.loads(row["metadata"]) if row["metadata"] else None,
            })
        return {"observations": observations, "count": len(observations)}

    def get_timeline(
        self,
        user_id: str,
        around_id: Optional[int] = None,
        around_time: Optional[str] = None,
        window_minutes: int = 30,
        limit: int = 20,
    ) -> Dict[str, Any]:
        """Get chronological context around a specific observation or time."""
        with self._get_conn() as conn:
            if around_id is not None:
                # Get the timestamp of the reference observation
                ref = conn.execute(
                    "SELECT timestamp FROM observations WHERE id = ? AND user_id = ?",
                    (around_id, user_id),
                ).fetchone()
                if not ref:
                    return {"error": "Observation not found", "timeline": []}
                center_time = ref["timestamp"]
            elif around_time:
                center_time = around_time
            else:
                # Match CURRENT_TIMESTAMP, which SQLite stores in UTC
                center_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
            # Get observations in the time window
            rows = conn.execute("""
                SELECT id, timestamp, type, title, importance, tool_name
                FROM observations
                WHERE user_id = ?
                  AND datetime(timestamp) BETWEEN
                      datetime(?, '-' || ? || ' minutes')
                      AND datetime(?, '+' || ? || ' minutes')
                ORDER BY timestamp
                LIMIT ?
            """, (user_id, center_time, window_minutes, center_time, window_minutes, limit)).fetchall()
        timeline = [{
            "id": row["id"],
            "timestamp": row["timestamp"],
            "type": row["type"],
            "title": row["title"],
            "importance": row["importance"],
            "tool_name": row["tool_name"],
        } for row in rows]
        return {
            "timeline": timeline,
            "center_time": center_time,
            "window_minutes": window_minutes,
            "count": len(timeline),
        }
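
    # Example (illustrative): everything within 15 minutes of observation #42,
    # e.g. to reconstruct the session context around a tool call:
    #
    #   store.get_timeline("alice", around_id=42, window_minutes=15)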

    def save_preference(
        self,
        user_id: str,
        key: str,
        value: str,
    ) -> Dict[str, Any]:
        """Save or update a user preference."""
        with self._get_conn() as conn:
            conn.execute("""
                INSERT INTO preferences (user_id, key, value)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id, key) DO UPDATE SET
                    value = excluded.value,
                    timestamp = CURRENT_TIMESTAMP
            """, (user_id, key, value))
            conn.commit()
        return {"status": "saved", "user_id": user_id, "key": key}

    def get_preferences(self, user_id: str) -> Dict[str, str]:
        """Get all preferences for a user."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT key, value FROM preferences WHERE user_id = ?",
                (user_id,),
            ).fetchall()
        return {row["key"]: row["value"] for row in rows}
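
    # Example (illustrative): the ON CONFLICT upsert overwrites in place, so
    # the latest value wins:
    #
    #   store.save_preference("alice", "verbosity", "low")
    #   store.save_preference("alice", "verbosity", "high")
    #   store.get_preferences("alice")  # -> {"verbosity": "high"}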

    def get_stats(self, user_id: str) -> Dict[str, Any]:
        """Get memory statistics for a user."""
        with self._get_conn() as conn:
            total = conn.execute(
                "SELECT COUNT(*) as c FROM observations WHERE user_id = ?",
                (user_id,),
            ).fetchone()["c"]
            by_type = conn.execute("""
                SELECT type, COUNT(*) as c
                FROM observations WHERE user_id = ?
                GROUP BY type
            """, (user_id,)).fetchall()
            recent = conn.execute("""
                SELECT COUNT(*) as c FROM observations
                WHERE user_id = ? AND timestamp >= datetime('now', '-7 days')
            """, (user_id,)).fetchone()["c"]
        return {
            "user_id": user_id,
            "total_observations": total,
            "by_type": {row["type"]: row["c"] for row in by_type},
            "last_7_days": recent,
            "collection": self._get_collection_name(user_id),
        }
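

if __name__ == "__main__":
    # Minimal smoke test with a stub standing in for the real RAG engine.
    # The stub mirrors only the ingest()/query() surface this module calls
    # and returns the "doc_id" key save_observation() expects; the real
    # rag_engine contract may differ.
    import tempfile

    class _StubRag:
        def __init__(self):
            self._docs = []  # (doc_id, content, collection)

        def ingest(self, content, title, source, doc_type, collection):
            doc_id = f"doc-{len(self._docs)}"
            self._docs.append((doc_id, content, collection))
            return {"doc_id": doc_id}

        def query(self, question, top_k, collection):
            # No real vector search: return every doc in the collection.
            hits = [{"metadata": {"doc_id": d}} for d, _, c in self._docs if c == collection]
            return {"results": hits[:top_k]}

    store = MemoryStore(
        db_path=str(Path(tempfile.mkdtemp()) / "memory.db"),
        rag_engine=_StubRag(),
    )
    print(store.save_observation("alice", "Prefers dark mode", obs_type="preference"))
    print(store.query_memory("alice", "dark mode"))
    print(store.get_stats("alice"))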