feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: user-isolated observation storage with vector embeddings
- New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline
- Progressive-disclosure pattern for token-efficient retrieval
- Updated Dockerfile to ROCm 7.2 nightly
This commit is contained in:
257
app/rag_engine.py
Normal file
257
app/rag_engine.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
RAG Engine — ChromaDB + sentence-transformers embedding logic.
|
||||
Supports multiple collections for tenant isolation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from document_processor import chunk_text
|
||||
|
||||
logger = logging.getLogger("moxie-rag.engine")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detect best device
|
||||
# ---------------------------------------------------------------------------
|
||||
DEVICE = "cpu"
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = "cuda"
|
||||
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "unknown"
|
||||
logger.info(f"GPU detected: {gpu_name}")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger.info(f"Embedding device: {DEVICE}")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
DEFAULT_COLLECTION = "adolfo_docs"
|
||||
|
||||
|
||||
class RAGEngine:
    """Manages embeddings and vector storage with multi-collection support.

    Wraps a sentence-transformers embedder and a persistent ChromaDB client.
    Collections provide tenant isolation; every public method that accepts an
    optional ``collection`` name falls back to ``DEFAULT_COLLECTION`` (see
    :meth:`_resolve`).
    """

    def __init__(self, data_dir: str = "/app/data/chromadb"):
        """Load the embedding model and open (or create) the ChromaDB store.

        Args:
            data_dir: Directory for ChromaDB's persistent files; created
                (with parents) if missing.
        """
        Path(data_dir).mkdir(parents=True, exist_ok=True)

        logger.info(f"Loading embedding model '{EMBEDDING_MODEL}' on {DEVICE} ...")
        self.embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
        logger.info("Embedding model loaded.")

        self.chroma = chromadb.PersistentClient(path=data_dir)

        # Cache name -> Collection so repeated lookups skip the client call.
        self._collections: Dict[str, Any] = {}
        # Pre-load default collection
        self._get_collection(DEFAULT_COLLECTION)
        logger.info(
            f"ChromaDB collection '{DEFAULT_COLLECTION}' ready — "
            f"{self._collections[DEFAULT_COLLECTION].count()} existing chunks."
        )

    # ------------------------------------------------------------------
    # Collection management
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve(collection: Optional[str]) -> str:
        """Map an optional collection name onto a concrete one.

        Centralizes the ``collection or DEFAULT_COLLECTION`` fallback that
        every public method needs.
        """
        return collection or DEFAULT_COLLECTION

    def _get_collection(self, name: str):
        """Get or create a ChromaDB collection by name (cached)."""
        if name not in self._collections:
            self._collections[name] = self.chroma.get_or_create_collection(
                name=name,
                # Cosine distance matches normalized sentence embeddings.
                metadata={"hnsw:space": "cosine"},
            )
            logger.info(f"Collection '{name}' loaded ({self._collections[name].count()} chunks)")
        return self._collections[name]

    def list_collections(self) -> List[Dict[str, Any]]:
        """List all collections with their document counts."""
        collections = self.chroma.list_collections()
        result = []
        for coll in collections:
            # ChromaDB >= 1.x returns Collection objects, older versions return strings
            name = coll if isinstance(coll, str) else coll.name
            c = self._get_collection(name)
            result.append({"name": name, "chunks": c.count()})
        return result

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @property
    def device(self) -> str:
        """Device the embedder runs on ("cpu" or "cuda")."""
        return DEVICE

    @property
    def model_name(self) -> str:
        """Hugging Face id of the loaded embedding model."""
        return EMBEDDING_MODEL

    @property
    def doc_count(self) -> int:
        """Chunk count of the default collection."""
        return self._get_collection(DEFAULT_COLLECTION).count()

    def collection_count(self, collection: Optional[str] = None) -> int:
        """Chunk count of *collection* (default collection when ``None``)."""
        # Annotation fixed: default is None, so the type is Optional[str].
        return self._get_collection(self._resolve(collection)).count()

    @staticmethod
    def _make_doc_id(title: str, source: str) -> str:
        """Derive a short opaque document id from title, source, and a timestamp.

        md5 is acceptable here — the id is a lookup key, not a security token.
        The timestamp makes repeated ingests of the same title/source distinct.
        """
        raw = f"{title}:{source}:{datetime.now().isoformat()}"
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts*, returning plain Python lists as ChromaDB expects."""
        return self.embedder.encode(texts, show_progress_bar=False).tolist()

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------
    def ingest(
        self,
        content: str,
        title: str = "Untitled",
        source: str = "unknown",
        date: Optional[str] = None,
        doc_type: str = "text",
        auto_chunk: bool = True,
        collection: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Chunk, embed, and store content in the specified collection.

        Args:
            content: Raw text to index.
            title: Human-readable document title (stored as metadata).
            source: Origin of the document (stored as metadata).
            date: ISO date string; defaults to "now" when omitted.
            doc_type: Free-form type tag, used by :meth:`query` filtering.
            auto_chunk: When True, split *content* via ``chunk_text``;
                otherwise store it as a single stripped chunk.
            collection: Target collection; default collection when ``None``.

        Returns:
            Summary dict with ``doc_id``, ``title``, ``chunks_created``,
            ``total_documents``, and ``collection``.

        Raises:
            ValueError: If no non-empty chunks remain after processing.
        """
        coll_name = self._resolve(collection)
        coll = self._get_collection(coll_name)
        doc_id = self._make_doc_id(title, source)
        date = date or datetime.now().isoformat()

        chunks = chunk_text(content) if auto_chunk else [content.strip()]
        chunks = [c for c in chunks if c]
        if not chunks:
            raise ValueError("No content to ingest after processing")

        embeddings = self._embed(chunks)

        ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
        metadatas = [
            {
                "doc_id": doc_id,
                "title": title,
                "source": source,
                "date": date,
                "doc_type": doc_type,
                "chunk_index": i,
                "total_chunks": len(chunks),
            }
            for i in range(len(chunks))
        ]

        coll.add(
            ids=ids,
            embeddings=embeddings,
            documents=chunks,
            metadatas=metadatas,
        )

        logger.info(f"Ingested '{title}' ({len(chunks)} chunks) [{doc_id}] → {coll_name}")
        return {
            "doc_id": doc_id,
            "title": title,
            "chunks_created": len(chunks),
            "total_documents": coll.count(),
            "collection": coll_name,
        }

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------
    def query(
        self,
        question: str,
        top_k: int = 5,
        filter_type: Optional[str] = None,
        collection: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Semantic search over indexed chunks in the specified collection.

        Args:
            question: Natural-language query to embed and search with.
            top_k: Maximum number of results (capped at the collection size).
            filter_type: When given, restrict matches to this ``doc_type``.
            collection: Collection to search; default collection when ``None``.

        Returns:
            Dict with ``question``, ``results`` (chunk id, content, metadata,
            distance per hit), ``total_results``, and ``collection``.
        """
        coll_name = self._resolve(collection)
        coll = self._get_collection(coll_name)

        # Hoisted: count() is a store round-trip; call it once.
        total = coll.count()
        if total == 0:
            return {"question": question, "results": [], "total_results": 0, "collection": coll_name}

        query_emb = self._embed([question])
        where = {"doc_type": filter_type} if filter_type else None

        results = coll.query(
            query_embeddings=query_emb,
            n_results=min(top_k, total),
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        formatted = []
        # ChromaDB returns per-query lists; we sent one query, so index [0].
        if results and results["ids"] and results["ids"][0]:
            for i, cid in enumerate(results["ids"][0]):
                formatted.append(
                    {
                        "chunk_id": cid,
                        "content": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )

        logger.info(f"Query [{coll_name}]: '{question}' → {len(formatted)} results")
        return {
            "question": question,
            "results": formatted,
            "total_results": len(formatted),
            "collection": coll_name,
        }

    # ------------------------------------------------------------------
    # Document management
    # ------------------------------------------------------------------
    def list_documents(self, collection: Optional[str] = None) -> Dict[str, Any]:
        """List all indexed documents grouped by doc_id in the specified collection.

        Returns:
            Dict with ``documents`` (one summary per doc_id, including
            ``chunk_count``), ``total``, and ``collection``.
        """
        coll_name = self._resolve(collection)
        coll = self._get_collection(coll_name)

        if coll.count() == 0:
            return {"documents": [], "total": 0, "collection": coll_name}

        all_data = coll.get(include=["metadatas"])
        docs: Dict[str, Dict] = {}
        for meta in all_data["metadatas"]:
            did = meta.get("doc_id", "unknown")
            if did not in docs:
                docs[did] = {
                    "doc_id": did,
                    "title": meta.get("title", "Unknown"),
                    "source": meta.get("source", "unknown"),
                    "doc_type": meta.get("doc_type", "text"),
                    "date": meta.get("date", "unknown"),
                    "chunk_count": 0,
                }
            docs[did]["chunk_count"] += 1

        return {"documents": list(docs.values()), "total": len(docs), "collection": coll_name}

    def delete_document(self, doc_id: str, collection: Optional[str] = None) -> Dict[str, Any]:
        """Delete all chunks belonging to a document in the specified collection.

        Args:
            doc_id: The short id assigned at ingest time.
            collection: Collection to delete from; default when ``None``.

        Returns:
            Dict with ``deleted`` (the doc_id) and ``chunks_removed``.

        Raises:
            KeyError: If no chunk carries the given ``doc_id``.
        """
        coll = self._get_collection(self._resolve(collection))
        all_data = coll.get(include=["metadatas"])
        ids_to_delete = [
            all_data["ids"][i]
            for i, m in enumerate(all_data["metadatas"])
            if m.get("doc_id") == doc_id
        ]

        if not ids_to_delete:
            raise KeyError(f"Document '{doc_id}' not found")

        coll.delete(ids=ids_to_delete)
        logger.info(f"Deleted {doc_id} ({len(ids_to_delete)} chunks)")
        return {"deleted": doc_id, "chunks_removed": len(ids_to_delete)}
|
||||
Reference in New Issue
Block a user