Files
moxie-rag/app/rag_engine.py
Adolfo Delorenzo 76d8f9349e feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: User-isolated observation storage with vector embeddings
- New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline
- Progressive disclosure pattern for token-efficient retrieval
- Updated Dockerfile to ROCm 7.2 nightly
2026-02-09 15:42:43 -06:00

258 lines
9.4 KiB
Python

"""
RAG Engine — ChromaDB + sentence-transformers embedding logic.
Supports multiple collections for tenant isolation.
"""
import hashlib
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import chromadb
from sentence_transformers import SentenceTransformer
from document_processor import chunk_text
logger = logging.getLogger("moxie-rag.engine")
# ---------------------------------------------------------------------------
# Detect best device
# ---------------------------------------------------------------------------
DEVICE = "cpu"
try:
    import torch

    if torch.cuda.is_available():
        DEVICE = "cuda"
        # Only query the device name when at least one device is reported.
        if torch.cuda.device_count() > 0:
            gpu_name = torch.cuda.get_device_name(0)
        else:
            gpu_name = "unknown"
        logger.info(f"GPU detected: {gpu_name}")
except ImportError:
    # torch is optional — without it we embed on CPU.
    pass
logger.info(f"Embedding device: {DEVICE}")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_COLLECTION = "adolfo_docs"
class RAGEngine:
    """Manages embeddings and vector storage with multi-collection support."""

    def __init__(self, data_dir: str = "/app/data/chromadb"):
        """Load the embedding model and open the persistent ChromaDB store.

        Args:
            data_dir: On-disk directory for ChromaDB persistence; created
                (with parents) if it does not exist.
        """
        Path(data_dir).mkdir(parents=True, exist_ok=True)
        # NOTE: SentenceTransformer may download the model on first run.
        logger.info(f"Loading embedding model '{EMBEDDING_MODEL}' on {DEVICE} ...")
        self.embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
        logger.info("Embedding model loaded.")
        self.chroma = chromadb.PersistentClient(path=data_dir)
        # Cache of name -> Collection handle, filled lazily by _get_collection().
        self._collections: Dict[str, Any] = {}
        # Pre-load default collection
        self._get_collection(DEFAULT_COLLECTION)
        logger.info(
            f"ChromaDB collection '{DEFAULT_COLLECTION}' ready — "
            f"{self._collections[DEFAULT_COLLECTION].count()} existing chunks."
        )
# ------------------------------------------------------------------
# Collection management
# ------------------------------------------------------------------
def _get_collection(self, name: str):
"""Get or create a ChromaDB collection by name."""
if name not in self._collections:
self._collections[name] = self.chroma.get_or_create_collection(
name=name,
metadata={"hnsw:space": "cosine"},
)
logger.info(f"Collection '{name}' loaded ({self._collections[name].count()} chunks)")
return self._collections[name]
def list_collections(self) -> List[Dict[str, Any]]:
"""List all collections with their document counts."""
collections = self.chroma.list_collections()
result = []
for coll in collections:
# ChromaDB >= 1.x returns Collection objects, older versions return strings
name = coll if isinstance(coll, str) else coll.name
c = self._get_collection(name)
result.append({"name": name, "chunks": c.count()})
return result
    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @property
    def device(self) -> str:
        """Device the embedder runs on ("cuda" or "cpu")."""
        return DEVICE

    @property
    def model_name(self) -> str:
        """Name of the sentence-transformers embedding model in use."""
        return EMBEDDING_MODEL

    @property
    def doc_count(self) -> int:
        """Number of chunks stored in the default collection."""
        return self._get_collection(DEFAULT_COLLECTION).count()

    def collection_count(self, collection: Optional[str] = None) -> int:
        """Number of chunks in *collection* (default collection when None)."""
        return self._get_collection(collection or DEFAULT_COLLECTION).count()

    @staticmethod
    def _make_doc_id(title: str, source: str) -> str:
        """Derive a short document id (12 hex chars; md5 used non-cryptographically).

        The timestamp salt means re-ingesting the same title/source yields a
        fresh id rather than colliding with the earlier document.
        """
        raw = f"{title}:{source}:{datetime.now().isoformat()}"
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts, returned as plain Python float lists."""
        return self.embedder.encode(texts, show_progress_bar=False).tolist()
# ------------------------------------------------------------------
# Ingest
# ------------------------------------------------------------------
def ingest(
self,
content: str,
title: str = "Untitled",
source: str = "unknown",
date: Optional[str] = None,
doc_type: str = "text",
auto_chunk: bool = True,
collection: Optional[str] = None,
) -> Dict[str, Any]:
"""Chunk, embed, and store content in specified collection."""
coll = self._get_collection(collection or DEFAULT_COLLECTION)
doc_id = self._make_doc_id(title, source)
date = date or datetime.now().isoformat()
chunks = chunk_text(content) if auto_chunk else [content.strip()]
chunks = [c for c in chunks if c]
if not chunks:
raise ValueError("No content to ingest after processing")
embeddings = self._embed(chunks)
ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
metadatas = [
{
"doc_id": doc_id,
"title": title,
"source": source,
"date": date,
"doc_type": doc_type,
"chunk_index": i,
"total_chunks": len(chunks),
}
for i in range(len(chunks))
]
coll.add(
ids=ids,
embeddings=embeddings,
documents=chunks,
metadatas=metadatas,
)
coll_name = collection or DEFAULT_COLLECTION
logger.info(f"Ingested '{title}' ({len(chunks)} chunks) [{doc_id}] → {coll_name}")
return {
"doc_id": doc_id,
"title": title,
"chunks_created": len(chunks),
"total_documents": coll.count(),
"collection": coll_name,
}
# ------------------------------------------------------------------
# Query
# ------------------------------------------------------------------
def query(
self,
question: str,
top_k: int = 5,
filter_type: Optional[str] = None,
collection: Optional[str] = None,
) -> Dict[str, Any]:
"""Semantic search over indexed chunks in specified collection."""
coll = self._get_collection(collection or DEFAULT_COLLECTION)
if coll.count() == 0:
return {"question": question, "results": [], "total_results": 0, "collection": collection or DEFAULT_COLLECTION}
query_emb = self._embed([question])
where = {"doc_type": filter_type} if filter_type else None
results = coll.query(
query_embeddings=query_emb,
n_results=min(top_k, coll.count()),
where=where,
include=["documents", "metadatas", "distances"],
)
formatted = []
if results and results["ids"] and results["ids"][0]:
for i, cid in enumerate(results["ids"][0]):
formatted.append(
{
"chunk_id": cid,
"content": results["documents"][0][i],
"metadata": results["metadatas"][0][i],
"distance": results["distances"][0][i],
}
)
coll_name = collection or DEFAULT_COLLECTION
logger.info(f"Query [{coll_name}]: '{question}'{len(formatted)} results")
return {
"question": question,
"results": formatted,
"total_results": len(formatted),
"collection": coll_name,
}
# ------------------------------------------------------------------
# Document management
# ------------------------------------------------------------------
def list_documents(self, collection: Optional[str] = None) -> Dict[str, Any]:
"""List all indexed documents grouped by doc_id in specified collection."""
coll = self._get_collection(collection or DEFAULT_COLLECTION)
if coll.count() == 0:
return {"documents": [], "total": 0, "collection": collection or DEFAULT_COLLECTION}
all_data = coll.get(include=["metadatas"])
docs: Dict[str, Dict] = {}
for meta in all_data["metadatas"]:
did = meta.get("doc_id", "unknown")
if did not in docs:
docs[did] = {
"doc_id": did,
"title": meta.get("title", "Unknown"),
"source": meta.get("source", "unknown"),
"doc_type": meta.get("doc_type", "text"),
"date": meta.get("date", "unknown"),
"chunk_count": 0,
}
docs[did]["chunk_count"] += 1
return {"documents": list(docs.values()), "total": len(docs), "collection": collection or DEFAULT_COLLECTION}
def delete_document(self, doc_id: str, collection: Optional[str] = None) -> Dict[str, Any]:
"""Delete all chunks belonging to a document in specified collection."""
coll = self._get_collection(collection or DEFAULT_COLLECTION)
all_data = coll.get(include=["metadatas"])
ids_to_delete = [
all_data["ids"][i]
for i, m in enumerate(all_data["metadatas"])
if m.get("doc_id") == doc_id
]
if not ids_to_delete:
raise KeyError(f"Document '{doc_id}' not found")
coll.delete(ids=ids_to_delete)
logger.info(f"Deleted {doc_id} ({len(ids_to_delete)} chunks)")
return {"deleted": doc_id, "chunks_removed": len(ids_to_delete)}