# NOTE(review): the notes below were pasted in from a commit message (they
# describe memory_store.py, not this module); kept as a comment so the file parses.
# - memory_store.py: user-isolated observation storage with vector embeddings
# - New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline
# - Progressive disclosure pattern for token-efficient retrieval
# - Updated Dockerfile to ROCm 7.2 nightly
"""
|
|
RAG Engine — ChromaDB + sentence-transformers embedding logic.
|
|
Supports multiple collections for tenant isolation.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import chromadb
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
from document_processor import chunk_text
|
|
|
|
logger = logging.getLogger("moxie-rag.engine")
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Detect best device
|
|
# ---------------------------------------------------------------------------
|
|
DEVICE = "cpu"
|
|
try:
|
|
import torch
|
|
|
|
if torch.cuda.is_available():
|
|
DEVICE = "cuda"
|
|
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "unknown"
|
|
logger.info(f"GPU detected: {gpu_name}")
|
|
except ImportError:
|
|
pass
|
|
|
|
logger.info(f"Embedding device: {DEVICE}")
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
|
DEFAULT_COLLECTION = "adolfo_docs"
|
|
|
|
|
|
class RAGEngine:
    """Manages embeddings and vector storage with multi-collection support.

    Every collection is created with cosine distance (``hnsw:space``) and its
    handle is cached in ``self._collections``, so later lookups by name reuse
    the same object instead of asking the ChromaDB client again.
    """

    def __init__(self, data_dir: str = "/app/data/chromadb"):
        """Load the embedding model and open the persistent ChromaDB store.

        Args:
            data_dir: Directory for ChromaDB's on-disk data; created if missing.
        """
        Path(data_dir).mkdir(parents=True, exist_ok=True)

        logger.info("Loading embedding model '%s' on %s ...", EMBEDDING_MODEL, DEVICE)
        self.embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
        logger.info("Embedding model loaded.")

        self.chroma = chromadb.PersistentClient(path=data_dir)

        # Collection name -> collection handle. The default collection is
        # opened eagerly so startup logs how much data is already indexed.
        self._collections: Dict[str, Any] = {}
        default_coll = self._get_collection(DEFAULT_COLLECTION)
        logger.info(
            "ChromaDB collection '%s' ready — %d existing chunks.",
            DEFAULT_COLLECTION,
            default_coll.count(),
        )

    # ------------------------------------------------------------------
    # Collection management
    # ------------------------------------------------------------------
    def _get_collection(self, name: str):
        """Return the cached collection ``name``, creating it on first use."""
        if name not in self._collections:
            coll = self.chroma.get_or_create_collection(
                name=name,
                metadata={"hnsw:space": "cosine"},
            )
            self._collections[name] = coll
            logger.info("Collection '%s' loaded (%d chunks)", name, coll.count())
        return self._collections[name]

    def list_collections(self) -> List[Dict[str, Any]]:
        """List all collections with their chunk counts.

        Returns:
            A list of ``{"name": str, "chunks": int}`` dicts.
        """
        result = []
        for coll in self.chroma.list_collections():
            # Depending on the ChromaDB version, list_collections() yields
            # either Collection objects or plain collection names; accept both.
            name = coll if isinstance(coll, str) else coll.name
            result.append({"name": name, "chunks": self._get_collection(name).count()})
        return result

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @property
    def device(self) -> str:
        """Device the embedding model runs on (``"cpu"`` or ``"cuda"``)."""
        return DEVICE

    @property
    def model_name(self) -> str:
        """Identifier of the sentence-transformers embedding model."""
        return EMBEDDING_MODEL

    @property
    def doc_count(self) -> int:
        """Number of chunks stored in the default collection."""
        return self._get_collection(DEFAULT_COLLECTION).count()

    def collection_count(self, collection: Optional[str] = None) -> int:
        """Number of chunks in ``collection`` (the default collection if None)."""
        return self._get_collection(collection or DEFAULT_COLLECTION).count()

    @staticmethod
    def _make_doc_id(title: str, source: str) -> str:
        """Build a 12-hex-char document id from title, source and current time.

        The timestamp makes ids differ between ingests of the same title and
        source; md5 is used only as a compact fingerprint, not for security.
        """
        raw = f"{title}:{source}:{datetime.now().isoformat()}"
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` and return plain Python lists (one vector per text)."""
        return self.embedder.encode(texts, show_progress_bar=False).tolist()

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------
    def ingest(
        self,
        content: str,
        title: str = "Untitled",
        source: str = "unknown",
        date: Optional[str] = None,
        doc_type: str = "text",
        auto_chunk: bool = True,
        collection: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Chunk, embed, and store content in the specified collection.

        Args:
            content: Raw document text.
            title, source, doc_type: Stored verbatim in every chunk's metadata.
            date: Document date; defaults to the current local time (ISO 8601).
            auto_chunk: Split with ``chunk_text``; otherwise store one chunk.
            collection: Target collection; the default collection if None.

        Returns:
            Summary dict with ``doc_id``, ``title``, ``chunks_created``,
            ``total_documents`` (chunk count of the collection) and ``collection``.

        Raises:
            ValueError: If no non-blank chunk remains after processing.
        """
        coll_name = collection or DEFAULT_COLLECTION
        coll = self._get_collection(coll_name)
        doc_id = self._make_doc_id(title, source)
        date = date or datetime.now().isoformat()

        chunks = chunk_text(content) if auto_chunk else [content.strip()]
        # Whitespace-only chunks carry nothing worth embedding; drop them too.
        chunks = [c for c in chunks if c and c.strip()]
        if not chunks:
            raise ValueError("No content to ingest after processing")

        embeddings = self._embed(chunks)

        total = len(chunks)
        ids = [f"{doc_id}_chunk_{i}" for i in range(total)]
        metadatas = [
            {
                "doc_id": doc_id,
                "title": title,
                "source": source,
                "date": date,
                "doc_type": doc_type,
                "chunk_index": i,
                "total_chunks": total,
            }
            for i in range(total)
        ]

        coll.add(
            ids=ids,
            embeddings=embeddings,
            documents=chunks,
            metadatas=metadatas,
        )

        logger.info("Ingested '%s' (%d chunks) [%s] → %s", title, total, doc_id, coll_name)
        return {
            "doc_id": doc_id,
            "title": title,
            "chunks_created": total,
            "total_documents": coll.count(),
            "collection": coll_name,
        }

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------
    def query(
        self,
        question: str,
        top_k: int = 5,
        filter_type: Optional[str] = None,
        collection: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Semantic search over indexed chunks in the specified collection.

        Args:
            question: Natural-language query to embed and search with.
            top_k: Maximum number of chunks to return; values < 1 yield no results.
            filter_type: If set, only chunks whose ``doc_type`` equals it match.
            collection: Collection to search; the default collection if None.

        Returns:
            Dict with ``question``, ``results`` (each with ``chunk_id``,
            ``content``, ``metadata``, ``distance``), ``total_results`` and
            ``collection``.
        """
        coll_name = collection or DEFAULT_COLLECTION
        coll = self._get_collection(coll_name)

        available = coll.count()
        # ChromaDB rejects n_results < 1, so an empty collection or a
        # non-positive top_k short-circuits to an empty answer.
        if available == 0 or top_k < 1:
            return {"question": question, "results": [], "total_results": 0, "collection": coll_name}

        where = {"doc_type": filter_type} if filter_type else None
        results = coll.query(
            query_embeddings=self._embed([question]),
            n_results=min(top_k, available),
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        formatted: List[Dict[str, Any]] = []
        # Results are nested per query embedding; we sent exactly one, hence [0].
        if results and results["ids"] and results["ids"][0]:
            formatted = [
                {
                    "chunk_id": chunk_id,
                    "content": document,
                    "metadata": metadata,
                    "distance": distance,
                }
                for chunk_id, document, metadata, distance in zip(
                    results["ids"][0],
                    results["documents"][0],
                    results["metadatas"][0],
                    results["distances"][0],
                )
            ]

        logger.info("Query [%s]: '%s' → %d results", coll_name, question, len(formatted))
        return {
            "question": question,
            "results": formatted,
            "total_results": len(formatted),
            "collection": coll_name,
        }

    # ------------------------------------------------------------------
    # Document management
    # ------------------------------------------------------------------
    def list_documents(self, collection: Optional[str] = None) -> Dict[str, Any]:
        """List all indexed documents, grouped by ``doc_id``, in a collection.

        Returns:
            Dict with ``documents`` (one summary per doc_id including its
            ``chunk_count``), ``total`` and ``collection``.
        """
        coll_name = collection or DEFAULT_COLLECTION
        coll = self._get_collection(coll_name)

        if coll.count() == 0:
            return {"documents": [], "total": 0, "collection": coll_name}

        all_data = coll.get(include=["metadatas"])
        docs: Dict[str, Dict[str, Any]] = {}
        for meta in all_data["metadatas"]:
            # Records stored without metadata may come back as None.
            meta = meta or {}
            did = meta.get("doc_id", "unknown")
            if did not in docs:
                docs[did] = {
                    "doc_id": did,
                    "title": meta.get("title", "Unknown"),
                    "source": meta.get("source", "unknown"),
                    "doc_type": meta.get("doc_type", "text"),
                    "date": meta.get("date", "unknown"),
                    "chunk_count": 0,
                }
            docs[did]["chunk_count"] += 1

        return {"documents": list(docs.values()), "total": len(docs), "collection": coll_name}

    def delete_document(self, doc_id: str, collection: Optional[str] = None) -> Dict[str, Any]:
        """Delete all chunks belonging to a document in the specified collection.

        Returns:
            Dict with ``deleted`` (the doc_id), ``chunks_removed`` and ``collection``.

        Raises:
            KeyError: If no chunk with that ``doc_id`` exists in the collection.
        """
        coll_name = collection or DEFAULT_COLLECTION
        coll = self._get_collection(coll_name)

        # Let ChromaDB filter by doc_id instead of fetching every record's
        # metadata and scanning it in Python; ids are always returned.
        matches = coll.get(where={"doc_id": doc_id}, include=["metadatas"])
        ids_to_delete = list(matches["ids"])

        if not ids_to_delete:
            raise KeyError(f"Document '{doc_id}' not found")

        coll.delete(ids=ids_to_delete)
        logger.info("Deleted %s (%d chunks) from %s", doc_id, len(ids_to_delete), coll_name)
        return {"deleted": doc_id, "chunks_removed": len(ids_to_delete), "collection": coll_name}