feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: User-isolated observation storage with vector embeddings - New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline - Progressive disclosure pattern for token-efficient retrieval - Updated Dockerfile to ROCm 7.2 nightly
This commit is contained in:
378
app/memory_store.py
Normal file
378
app/memory_store.py
Normal file
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Memory Store — SQLite + ChromaDB hybrid for agent observations.
|
||||
Provides structured storage with vector search.
|
||||
"""
|
||||
|
||||
import hashlib
import json
import logging
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Module-level logger, parented under the application's "moxie-rag" logger.
logger = logging.getLogger("moxie-rag.memory")
|
||||
|
||||
|
||||
class MemoryStore:
    """SQLite + ChromaDB hybrid store for per-user agent observations.

    Structured rows (observations, preferences, relationships) live in
    SQLite; observation text is additionally embedded into a per-user
    ChromaDB collection through the injected ``rag_engine`` so it can be
    recalled with vector search.  All read paths filter by ``user_id``.
    """

    def __init__(self, db_path: str, rag_engine):
        """
        Args:
            db_path: Filesystem path of the SQLite database file; parent
                directories are created if missing.
            rag_engine: Object exposing ``ingest(...)`` and ``query(...)``
                used for vector embedding and retrieval.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.rag_engine = rag_engine
        self._init_db()

    @contextmanager
    def _get_conn(self):
        """Yield a short-lived connection with name-based row access.

        Thread-safe by construction: every call opens its own connection
        (30 s busy timeout) and always closes it on exit.
        """
        conn = sqlite3.connect(str(self.db_path), timeout=30.0)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create tables and indexes if they do not exist (idempotent)."""
        with self._get_conn() as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    type TEXT NOT NULL,
                    title TEXT,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    embedding_id TEXT,
                    session_id TEXT,
                    tool_name TEXT,
                    importance INTEGER DEFAULT 1,
                    tags TEXT,
                    metadata TEXT
                );

                CREATE INDEX IF NOT EXISTS idx_obs_user ON observations(user_id);
                CREATE INDEX IF NOT EXISTS idx_obs_type ON observations(type);
                CREATE INDEX IF NOT EXISTS idx_obs_timestamp ON observations(timestamp);
                CREATE INDEX IF NOT EXISTS idx_obs_session ON observations(session_id);

                CREATE TABLE IF NOT EXISTS preferences (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    key TEXT NOT NULL,
                    value TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(user_id, key)
                );

                CREATE INDEX IF NOT EXISTS idx_pref_user ON preferences(user_id);

                CREATE TABLE IF NOT EXISTS relationships (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    observation_id INTEGER,
                    related_id INTEGER,
                    relation_type TEXT,
                    FOREIGN KEY (observation_id) REFERENCES observations(id),
                    FOREIGN KEY (related_id) REFERENCES observations(id)
                );
            """)
            conn.commit()
            logger.info("Memory store initialized at %s", self.db_path)

    def _content_hash(self, content: str) -> str:
        """Return a 16-hex-char SHA-256 prefix used for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _get_collection_name(self, user_id: str) -> str:
        """Return the per-user ChromaDB collection name."""
        return f"moxie_memory_{user_id}"

    @staticmethod
    def _parse_metadata(raw: Optional[str]) -> Any:
        """Decode stored metadata JSON.

        Falls back to the raw string for legacy rows that were written
        with ``str(dict)`` instead of JSON.
        """
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            return raw

    def save_observation(
        self,
        user_id: str,
        content: str,
        obs_type: str = "general",
        title: Optional[str] = None,
        session_id: Optional[str] = None,
        tool_name: Optional[str] = None,
        importance: int = 1,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Save an observation to SQLite and embed it in ChromaDB.

        Duplicate content is detected via a content hash.  Returns a dict
        with ``status`` ("created" or "duplicate") and, when created,
        ``observation_id``, ``embedding_id`` and ``collection``.
        """
        content_hash = self._content_hash(content)
        collection = self._get_collection_name(user_id)

        # Dedup check is scoped to this user, so one user's save can
        # never surface another user's observation id.
        with self._get_conn() as conn:
            existing = conn.execute(
                "SELECT id FROM observations "
                "WHERE content_hash = ? AND user_id = ?",
                (content_hash, user_id),
            ).fetchone()

        if existing:
            return {
                "status": "duplicate",
                "observation_id": existing["id"],
                "message": "Observation already exists"
            }

        # Embed first so the embedding_id can be stored with the row.
        # (If the insert below turns out to be a cross-user duplicate,
        # the embedding is wasted but harmless.)
        embed_result = self.rag_engine.ingest(
            content=content,
            title=title or f"Observation: {obs_type}",
            source=f"memory:{user_id}:{obs_type}",
            doc_type="observation",
            collection=collection,
        )
        embedding_id = embed_result.get("doc_id")

        # Serialize metadata as JSON (str(dict) would not round-trip).
        tags_str = ",".join(tags) if tags else None
        metadata_str = json.dumps(metadata) if metadata else None

        try:
            with self._get_conn() as conn:
                cursor = conn.execute("""
                    INSERT INTO observations
                    (user_id, type, title, content, content_hash, embedding_id,
                     session_id, tool_name, importance, tags, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    user_id, obs_type, title, content, content_hash, embedding_id,
                    session_id, tool_name, importance, tags_str, metadata_str
                ))
                conn.commit()
                obs_id = cursor.lastrowid
        except sqlite3.IntegrityError:
            # content_hash is globally UNIQUE in the schema: identical
            # content saved by a *different* user lands here.  Report a
            # duplicate without leaking the other user's row id.
            return {
                "status": "duplicate",
                "observation_id": None,
                "message": "Observation already exists"
            }

        logger.info(
            "Saved observation #%s for user %s (type: %s)",
            obs_id, user_id, obs_type,
        )
        return {
            "status": "created",
            "observation_id": obs_id,
            "embedding_id": embedding_id,
            "collection": collection,
        }

    def query_memory(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        obs_type: Optional[str] = None,
        since: Optional[str] = None,
        include_content: bool = False,
    ) -> Dict[str, Any]:
        """Search memory using hybrid vector (ChromaDB) + SQL filtering.

        Progressive disclosure: by default only an index (id, title,
        timestamp, type, ...) is returned; pass ``include_content=True``
        to include the full observation text.

        Args:
            user_id: Owner whose memory is searched (rows are isolated).
            query: Natural-language search string.
            top_k: Maximum number of results to return.
            obs_type: Optional observation-type filter.
            since: Optional inclusive timestamp lower bound.
            include_content: Include full ``content`` per result.
        """
        collection = self._get_collection_name(user_id)

        # Over-fetch from the vector index so SQL-side filters still
        # leave enough candidates.
        vector_results = self.rag_engine.query(
            question=query,
            top_k=top_k * 2,
            collection=collection,
        )

        # Map vector hits back to SQLite rows via embedding ids; drop
        # hits with no doc_id (None could never match a row).
        embedding_ids = [
            doc_id
            for doc_id in (
                r.get("metadata", {}).get("doc_id")
                for r in vector_results.get("results", [])
            )
            if doc_id
        ]

        if not embedding_ids:
            return {"results": [], "total": 0, "query": query}

        placeholders = ",".join(["?" for _ in embedding_ids])
        sql = f"""
            SELECT id, user_id, timestamp, type, title, importance, tags, tool_name
            {"" if not include_content else ", content"}
            FROM observations
            WHERE user_id = ? AND embedding_id IN ({placeholders})
        """
        params: List[Any] = [user_id] + embedding_ids

        if obs_type:
            sql += " AND type = ?"
            params.append(obs_type)

        if since:
            sql += " AND timestamp >= ?"
            params.append(since)

        sql += " ORDER BY timestamp DESC LIMIT ?"
        params.append(top_k)

        with self._get_conn() as conn:
            rows = conn.execute(sql, params).fetchall()

        results = []
        for row in rows:
            item = {
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
            }
            if include_content:
                item["content"] = row["content"]
            results.append(item)

        return {
            "results": results,
            "total": len(results),
            "query": query,
            "collection": collection,
        }

    def get_observations(
        self,
        user_id: str,
        ids: List[int],
    ) -> Dict[str, Any]:
        """Fetch full observation details by IDs, scoped to the user."""
        if not ids:
            # Consistent shape with the non-empty path.
            return {"observations": [], "count": 0}

        placeholders = ",".join(["?" for _ in ids])
        sql = f"""
            SELECT * FROM observations
            WHERE user_id = ? AND id IN ({placeholders})
            ORDER BY timestamp DESC
        """

        with self._get_conn() as conn:
            rows = conn.execute(sql, [user_id] + list(ids)).fetchall()

        observations = []
        for row in rows:
            observations.append({
                "id": row["id"],
                "timestamp": row["timestamp"],
                "type": row["type"],
                "title": row["title"],
                "content": row["content"],
                "importance": row["importance"],
                "tags": row["tags"].split(",") if row["tags"] else [],
                "tool_name": row["tool_name"],
                "session_id": row["session_id"],
                # Decoded (JSON) rather than the raw stored string.
                "metadata": self._parse_metadata(row["metadata"]),
            })

        return {"observations": observations, "count": len(observations)}

    def get_timeline(
        self,
        user_id: str,
        around_id: Optional[int] = None,
        around_time: Optional[str] = None,
        window_minutes: int = 30,
        limit: int = 20,
    ) -> Dict[str, Any]:
        """Return chronological context around an observation or time.

        Args:
            around_id: Center on this observation's timestamp (takes
                precedence over ``around_time``).
            around_time: Timestamp string to center on.
            window_minutes: Half-width of the window on each side.
            limit: Maximum number of rows returned.
        """
        with self._get_conn() as conn:
            if around_id:
                # Resolve the reference observation's timestamp.
                ref = conn.execute(
                    "SELECT timestamp FROM observations WHERE id = ? AND user_id = ?",
                    (around_id, user_id)
                ).fetchone()
                if not ref:
                    return {"error": "Observation not found", "timeline": []}
                center_time = ref["timestamp"]
            elif around_time:
                center_time = around_time
            else:
                # Rows are stamped with CURRENT_TIMESTAMP, which SQLite
                # evaluates in UTC — so "now" must be UTC as well (naive
                # local time would shift the window by the UTC offset).
                center_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

            # Fetch observations inside the +/- window around center_time.
            rows = conn.execute("""
                SELECT id, timestamp, type, title, importance, tool_name
                FROM observations
                WHERE user_id = ?
                  AND datetime(timestamp) BETWEEN
                      datetime(?, '-' || ? || ' minutes')
                      AND datetime(?, '+' || ? || ' minutes')
                ORDER BY timestamp
                LIMIT ?
            """, (user_id, center_time, window_minutes, center_time, window_minutes, limit)).fetchall()

        timeline = [{
            "id": row["id"],
            "timestamp": row["timestamp"],
            "type": row["type"],
            "title": row["title"],
            "importance": row["importance"],
            "tool_name": row["tool_name"],
        } for row in rows]

        return {
            "timeline": timeline,
            "center_time": center_time,
            "window_minutes": window_minutes,
            "count": len(timeline),
        }

    def save_preference(
        self,
        user_id: str,
        key: str,
        value: str,
    ) -> Dict[str, Any]:
        """Save or update a user preference (upsert on (user_id, key))."""
        with self._get_conn() as conn:
            conn.execute("""
                INSERT INTO preferences (user_id, key, value)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id, key) DO UPDATE SET
                    value = excluded.value,
                    timestamp = CURRENT_TIMESTAMP
            """, (user_id, key, value))
            conn.commit()

        return {"status": "saved", "user_id": user_id, "key": key}

    def get_preferences(self, user_id: str) -> Dict[str, str]:
        """Return all preferences for a user as a key -> value dict."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT key, value FROM preferences WHERE user_id = ?",
                (user_id,)
            ).fetchall()

        return {row["key"]: row["value"] for row in rows}

    def get_stats(self, user_id: str) -> Dict[str, Any]:
        """Return memory statistics for a user (totals, per-type, recent)."""
        with self._get_conn() as conn:
            total = conn.execute(
                "SELECT COUNT(*) as c FROM observations WHERE user_id = ?",
                (user_id,)
            ).fetchone()["c"]

            by_type = conn.execute("""
                SELECT type, COUNT(*) as c
                FROM observations WHERE user_id = ?
                GROUP BY type
            """, (user_id,)).fetchall()

            recent = conn.execute("""
                SELECT COUNT(*) as c FROM observations
                WHERE user_id = ? AND timestamp >= datetime('now', '-7 days')
            """, (user_id,)).fetchone()["c"]

        return {
            "user_id": user_id,
            "total_observations": total,
            "by_type": {row["type"]: row["c"] for row in by_type},
            "last_7_days": recent,
            "collection": self._get_collection_name(user_id),
        }
|
||||
Reference in New Issue
Block a user