feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: User-isolated observation storage with vector embeddings - New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline - Progressive disclosure pattern for token-efficient retrieval - Updated Dockerfile to ROCm 7.2 nightly
This commit is contained in:
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
data/
|
||||
logs/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
35
Dockerfile
Normal file
35
Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
# System deps: pdfplumber, ffmpeg for video audio extraction, build tools
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
curl \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install PyTorch with ROCm support first (big layer, cache it)
|
||||
RUN pip install --no-cache-dir \
|
||||
torch torchvision torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/nightly/rocm7.2/
|
||||
|
||||
# Install remaining Python dependencies
|
||||
COPY app/requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY app/ .
|
||||
|
||||
# Pre-download the embedding model at build time so startup is fast
|
||||
RUN python -c "\
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
m = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'); \
|
||||
print('Model cached:', m.encode(['test']).shape)"
|
||||
|
||||
EXPOSE 8899
|
||||
|
||||
VOLUME ["/app/data", "/app/logs"]
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8899", "--log-level", "info"]
|
||||
168
app/document_processor.py
Normal file
168
app/document_processor.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Document processing utilities for the RAG service.
|
||||
Handles text chunking and extraction from various file formats.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger("moxie-rag.processor")
|
||||
|
||||
# Approximate chars per token for multilingual text
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
chunk_size: int = 500,
|
||||
overlap: int = 50,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split text into chunks of approximately chunk_size tokens with overlap.
|
||||
"""
|
||||
char_size = chunk_size * CHARS_PER_TOKEN
|
||||
char_overlap = overlap * CHARS_PER_TOKEN
|
||||
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
if len(text) <= char_size:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
|
||||
while start < len(text):
|
||||
end = start + char_size
|
||||
|
||||
if end < len(text):
|
||||
window = text[start:end]
|
||||
best_break = -1
|
||||
for separator in ["\n\n", ".\n", ". ", "?\n", "? ", "!\n", "! ", "\n", ", ", " "]:
|
||||
pos = window.rfind(separator)
|
||||
if pos > char_size // 2:
|
||||
best_break = pos + len(separator)
|
||||
break
|
||||
if best_break > 0:
|
||||
end = start + best_break
|
||||
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
next_start = end - char_overlap
|
||||
if next_start <= start:
|
||||
next_start = end
|
||||
start = next_start
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def extract_text_from_pdf(file_path: str) -> str:
|
||||
"""Extract text from a PDF file using pdfplumber."""
|
||||
import pdfplumber
|
||||
|
||||
text_parts = []
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for i, page in enumerate(pdf.pages):
|
||||
page_text = page.extract_text()
|
||||
if page_text:
|
||||
text_parts.append(page_text)
|
||||
else:
|
||||
logger.debug(f"Page {i + 1}: no text extracted")
|
||||
|
||||
result = "\n\n".join(text_parts)
|
||||
logger.info(f"Extracted {len(result)} chars from PDF ({len(text_parts)} pages)")
|
||||
return result
|
||||
|
||||
|
||||
def extract_text_from_docx(file_path: str) -> str:
|
||||
"""Extract text from a DOCX file using python-docx."""
|
||||
from docx import Document
|
||||
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
result = "\n\n".join(paragraphs)
|
||||
logger.info(f"Extracted {len(result)} chars from DOCX ({len(paragraphs)} paragraphs)")
|
||||
return result
|
||||
|
||||
|
||||
def extract_text_from_excel(file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xlsx, .xls) using openpyxl/pandas."""
|
||||
import pandas as pd
|
||||
|
||||
text_parts = []
|
||||
xls = pd.ExcelFile(file_path)
|
||||
|
||||
for sheet_name in xls.sheet_names:
|
||||
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
text_parts.append(f"--- Sheet: {sheet_name} ---")
|
||||
|
||||
# Include column headers
|
||||
headers = " | ".join(str(c) for c in df.columns)
|
||||
text_parts.append(f"Columns: {headers}")
|
||||
|
||||
# Convert rows to readable text
|
||||
for idx, row in df.iterrows():
|
||||
row_text = " | ".join(
|
||||
f"{col}: {val}" for col, val in row.items()
|
||||
if pd.notna(val) and str(val).strip()
|
||||
)
|
||||
if row_text:
|
||||
text_parts.append(row_text)
|
||||
|
||||
result = "\n".join(text_parts)
|
||||
logger.info(f"Extracted {len(result)} chars from Excel ({len(xls.sheet_names)} sheets)")
|
||||
return result
|
||||
|
||||
|
||||
def extract_audio_from_video(video_path: str) -> str:
|
||||
"""Extract audio track from video file using ffmpeg. Returns path to wav file."""
|
||||
audio_path = tempfile.mktemp(suffix=".wav")
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-i", video_path,
|
||||
"-vn", "-acodec", "pcm_s16le",
|
||||
"-ar", "16000", "-ac", "1",
|
||||
"-y", audio_path,
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=600,
|
||||
)
|
||||
logger.info(f"Extracted audio from video to {audio_path}")
|
||||
return audio_path
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"ffmpeg failed: {e.stderr.decode()}")
|
||||
raise ValueError(f"Could not extract audio from video: {e.stderr.decode()[:200]}")
|
||||
|
||||
|
||||
def extract_text_from_file(file_path: str, filename: str) -> str:
|
||||
"""
|
||||
Extract text from a file based on its extension.
|
||||
|
||||
Supported: .pdf, .docx, .doc, .txt, .md, .csv, .json, .html, .xlsx, .xls
|
||||
"""
|
||||
ext = Path(filename).suffix.lower()
|
||||
|
||||
if ext == ".pdf":
|
||||
return extract_text_from_pdf(file_path)
|
||||
elif ext in (".docx", ".doc"):
|
||||
return extract_text_from_docx(file_path)
|
||||
elif ext in (".xlsx", ".xls"):
|
||||
return extract_text_from_excel(file_path)
|
||||
elif ext in (".txt", ".md", ".csv", ".json", ".html", ".xml", ".rst"):
|
||||
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
logger.info(f"Read {len(content)} chars from {ext} file")
|
||||
return content
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {ext}")
|
||||
395
app/email_poller.py
Normal file
395
app/email_poller.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import re
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Email poller for Zeus RAG — checks zeus@zz11.net via IMAP,
|
||||
downloads attachments, and ingests them into the RAG service.
|
||||
Also ingests email body text.
|
||||
"""
|
||||
|
||||
import email
|
||||
import email.header
|
||||
import imaplib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from email.message import Message
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
IMAP_HOST = os.environ.get("IMAP_HOST", "mail.oe74.net")
|
||||
IMAP_PORT = int(os.environ.get("IMAP_PORT", "993"))
|
||||
IMAP_USER = os.environ.get("IMAP_USER", "zeus@zz11.net")
|
||||
IMAP_PASS = os.environ.get("IMAP_PASS", "")
|
||||
RAG_URL = os.environ.get("RAG_URL", "http://moxie-rag:8899")
|
||||
RAG_COLLECTION = os.environ.get("RAG_COLLECTION", "") # empty = default collection
|
||||
POLL_INTERVAL = int(os.environ.get("POLL_INTERVAL", "60")) # seconds
|
||||
STATE_FILE = os.environ.get("STATE_FILE", "/app/data/email_state.json")
|
||||
|
||||
# Whitelist of allowed senders (comma-separated email addresses)
|
||||
ALLOWED_SENDERS = os.environ.get("ALLOWED_SENDERS", "")
|
||||
ALLOWED_SENDERS_LIST = [s.strip().lower() for s in ALLOWED_SENDERS.split(",") if s.strip()]
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
".pdf", ".docx", ".doc", ".txt", ".md", ".csv", ".json",
|
||||
".xlsx", ".xls", ".html", ".xml",
|
||||
}
|
||||
MEDIA_EXTENSIONS = {
|
||||
".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac",
|
||||
}
|
||||
|
||||
LOG_DIR = Path(os.environ.get("LOG_DIR", "/app/logs"))
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(LOG_DIR / "email_poller.log"),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger("zeus-email-poller")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State management (track processed emails)
|
||||
# ---------------------------------------------------------------------------
|
||||
def load_state() -> dict:
|
||||
if os.path.exists(STATE_FILE):
|
||||
with open(STATE_FILE) as f:
|
||||
return json.load(f)
|
||||
return {"processed_uids": [], "last_check": None}
|
||||
|
||||
|
||||
def save_state(state: dict):
|
||||
Path(STATE_FILE).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(STATE_FILE, "w") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Email processing
|
||||
# ---------------------------------------------------------------------------
|
||||
def decode_header_value(value: str) -> str:
|
||||
"""Decode MIME encoded header value."""
|
||||
if not value:
|
||||
return ""
|
||||
parts = email.header.decode_header(value)
|
||||
decoded = []
|
||||
for part, charset in parts:
|
||||
if isinstance(part, bytes):
|
||||
decoded.append(part.decode(charset or "utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(part)
|
||||
return " ".join(decoded)
|
||||
|
||||
|
||||
def get_email_body(msg: Message) -> str:
|
||||
"""Extract plain text body from email message."""
|
||||
body_parts = []
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
ctype = part.get_content_type()
|
||||
if ctype == "text/plain":
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
body_parts.append(payload.decode(charset, errors="replace"))
|
||||
elif ctype == "text/html" and not body_parts:
|
||||
# Fallback to HTML if no plain text
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
body_parts.append(payload.decode(charset, errors="replace"))
|
||||
else:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
body_parts.append(payload.decode(charset, errors="replace"))
|
||||
return "\n".join(body_parts).strip()
|
||||
|
||||
|
||||
def get_attachments(msg: Message) -> list:
|
||||
"""Extract attachments from email message."""
|
||||
attachments = []
|
||||
for part in msg.walk():
|
||||
if part.get_content_maintype() == "multipart":
|
||||
continue
|
||||
filename = part.get_filename()
|
||||
if filename:
|
||||
filename = decode_header_value(filename)
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
attachments.append({"filename": filename, "data": payload})
|
||||
return attachments
|
||||
|
||||
|
||||
def ingest_text(content: str, title: str, source: str, doc_type: str = "email"):
|
||||
"""Send text to the RAG ingest endpoint."""
|
||||
try:
|
||||
payload = {
|
||||
"content": content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
"doc_type": doc_type,
|
||||
"date": datetime.now().isoformat(),
|
||||
}
|
||||
if RAG_COLLECTION:
|
||||
payload["collection"] = RAG_COLLECTION
|
||||
resp = httpx.post(
|
||||
f"{RAG_URL}/ingest",
|
||||
json=payload,
|
||||
timeout=120.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
result = resp.json()
|
||||
logger.info(f"Ingested text '{title}': {result.get('chunks_created', 0)} chunks")
|
||||
return result
|
||||
else:
|
||||
logger.error(f"Ingest failed ({resp.status_code}): {resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error ingesting text: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def ingest_file(filepath: str, filename: str, source: str, doc_type: str = None):
|
||||
"""Send a file to the RAG ingest-file endpoint."""
|
||||
ext = Path(filename).suffix.lower()
|
||||
try:
|
||||
form_data = {
|
||||
"title": filename,
|
||||
"source": source,
|
||||
"doc_type": doc_type or ext.lstrip("."),
|
||||
}
|
||||
if RAG_COLLECTION:
|
||||
form_data["collection"] = RAG_COLLECTION
|
||||
with open(filepath, "rb") as f:
|
||||
resp = httpx.post(
|
||||
f"{RAG_URL}/ingest-file",
|
||||
files={"file": (filename, f)},
|
||||
data=form_data,
|
||||
timeout=300.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
result = resp.json()
|
||||
logger.info(f"Ingested file '{filename}': {result.get('chunks_created', 0)} chunks")
|
||||
return result
|
||||
else:
|
||||
logger.error(f"File ingest failed ({resp.status_code}): {resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error ingesting file '{filename}': {e}")
|
||||
return None
|
||||
|
||||
|
||||
def transcribe_and_ingest(filepath: str, filename: str, source: str):
|
||||
"""Send audio/video to transcribe endpoint with auto_ingest=true."""
|
||||
try:
|
||||
form_data = {
|
||||
"auto_ingest": "true",
|
||||
"title": f"Transcription: {filename}",
|
||||
"source": source,
|
||||
}
|
||||
if RAG_COLLECTION:
|
||||
form_data["collection"] = RAG_COLLECTION
|
||||
with open(filepath, "rb") as f:
|
||||
resp = httpx.post(
|
||||
f"{RAG_URL}/transcribe",
|
||||
files={"file": (filename, f)},
|
||||
data=form_data,
|
||||
timeout=600.0,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
result = resp.json()
|
||||
logger.info(
|
||||
f"Transcribed+ingested '{filename}': "
|
||||
f"{result.get('word_count', 0)} words, "
|
||||
f"{result.get('chunks_created', 0)} chunks"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
logger.error(f"Transcribe failed ({resp.status_code}): {resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error transcribing '{filename}': {e}")
|
||||
return None
|
||||
|
||||
|
||||
def process_email(uid: str, msg: Message) -> dict:
|
||||
"""Process a single email: extract body and attachments, ingest everything."""
|
||||
subject = decode_header_value(msg.get("Subject", "No Subject"))
|
||||
sender = decode_header_value(msg.get("From", "Unknown"))
|
||||
date_str = msg.get("Date", datetime.now().isoformat())
|
||||
source = f"email:{sender}"
|
||||
|
||||
logger.info(f"Processing email UID={uid}: '{subject}' from {sender}")
|
||||
|
||||
# Check sender whitelist
|
||||
if ALLOWED_SENDERS_LIST:
|
||||
sender_email = sender.lower()
|
||||
# Extract email from "Name <email@domain.com>" format
|
||||
email_match = re.search(r'<([^>]+)>', sender_email)
|
||||
if email_match:
|
||||
sender_email = email_match.group(1)
|
||||
|
||||
if sender_email not in ALLOWED_SENDERS_LIST:
|
||||
logger.warning(f"Rejecting email from {sender}: not in whitelist")
|
||||
return {"uid": uid, "subject": subject, "sender": sender, "rejected": True, "reason": "sender_not_allowed"}
|
||||
|
||||
results = {"uid": uid, "subject": subject, "sender": sender, "ingested": []}
|
||||
|
||||
# 1. Ingest email body
|
||||
body = get_email_body(msg)
|
||||
if body and len(body.strip()) > 20:
|
||||
title = f"Email: {subject}"
|
||||
content = f"From: {sender}\nDate: {date_str}\nSubject: {subject}\n\n{body}"
|
||||
r = ingest_text(content, title, source, doc_type="email")
|
||||
if r:
|
||||
results["ingested"].append({"type": "body", "title": title, **r})
|
||||
|
||||
# 2. Process attachments
|
||||
attachments = get_attachments(msg)
|
||||
for att in attachments:
|
||||
filename = att["filename"]
|
||||
ext = Path(filename).suffix.lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
|
||||
tmp.write(att["data"])
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
att_source = f"email-attachment:{sender}:{filename}"
|
||||
|
||||
if ext in SUPPORTED_EXTENSIONS:
|
||||
r = ingest_file(tmp_path, filename, att_source)
|
||||
if r:
|
||||
results["ingested"].append({"type": "file", "filename": filename, **r})
|
||||
|
||||
elif ext in MEDIA_EXTENSIONS:
|
||||
r = transcribe_and_ingest(tmp_path, filename, att_source)
|
||||
if r:
|
||||
results["ingested"].append({"type": "media", "filename": filename, **r})
|
||||
|
||||
else:
|
||||
logger.warning(f"Skipping unsupported attachment: {filename} ({ext})")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def check_emails():
|
||||
"""Connect to IMAP, fetch unread emails, process them."""
|
||||
state = load_state()
|
||||
processed = set(state.get("processed_uids", []))
|
||||
|
||||
logger.info(f"Connecting to {IMAP_HOST}:{IMAP_PORT} as {IMAP_USER}...")
|
||||
|
||||
try:
|
||||
imap = imaplib.IMAP4_SSL(IMAP_HOST, IMAP_PORT)
|
||||
imap.login(IMAP_USER, IMAP_PASS)
|
||||
imap.select("INBOX")
|
||||
|
||||
# Search for UNSEEN messages
|
||||
status, data = imap.search(None, "UNSEEN")
|
||||
if status != "OK":
|
||||
logger.error(f"IMAP search failed: {status}")
|
||||
return
|
||||
|
||||
message_nums = data[0].split()
|
||||
if not message_nums:
|
||||
logger.info("No new emails.")
|
||||
imap.logout()
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(message_nums)} unread email(s)")
|
||||
|
||||
for num in message_nums:
|
||||
# Get UID
|
||||
status, uid_data = imap.fetch(num, "(UID)")
|
||||
if status != "OK":
|
||||
continue
|
||||
uid = uid_data[0].decode().split("UID ")[1].split(")")[0].strip()
|
||||
|
||||
if uid in processed:
|
||||
logger.info(f"Skipping already-processed UID={uid}")
|
||||
continue
|
||||
|
||||
# Fetch full message
|
||||
status, msg_data = imap.fetch(num, "(RFC822)")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_email)
|
||||
|
||||
try:
|
||||
result = process_email(uid, msg)
|
||||
processed.add(uid)
|
||||
total_ingested = len(result.get("ingested", []))
|
||||
logger.info(
|
||||
f"Email UID={uid} processed: "
|
||||
f"{total_ingested} item(s) ingested"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing UID={uid}: {e}", exc_info=True)
|
||||
|
||||
imap.logout()
|
||||
|
||||
except imaplib.IMAP4.error as e:
|
||||
logger.error(f"IMAP error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||
|
||||
# Save state
|
||||
state["processed_uids"] = list(processed)[-500:] # Keep last 500
|
||||
state["last_check"] = datetime.now().isoformat()
|
||||
save_state(state)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
def main():
|
||||
if not IMAP_PASS:
|
||||
logger.error("IMAP_PASS not set! Cannot connect to email.")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Email Poller starting — checking {IMAP_USER} every {POLL_INTERVAL}s")
|
||||
logger.info(f"RAG endpoint: {RAG_URL}")
|
||||
if RAG_COLLECTION:
|
||||
logger.info(f"Target collection: {RAG_COLLECTION}")
|
||||
else:
|
||||
logger.info("Target collection: default")
|
||||
|
||||
# Wait for RAG service to be ready
|
||||
for attempt in range(30):
|
||||
try:
|
||||
resp = httpx.get(f"{RAG_URL}/health", timeout=5.0)
|
||||
if resp.status_code == 200:
|
||||
logger.info("RAG service is ready!")
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Waiting for RAG service... (attempt {attempt + 1}/30)")
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.error("RAG service not available after 150s, starting anyway")
|
||||
|
||||
while True:
|
||||
try:
|
||||
check_emails()
|
||||
except Exception as e:
|
||||
logger.error(f"Poll cycle error: {e}", exc_info=True)
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
433
app/main.py
Normal file
433
app/main.py
Normal file
@@ -0,0 +1,433 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Moxie RAG Service — FastAPI application.
|
||||
Multi-collection support for tenant isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from rag_engine import RAGEngine
|
||||
from document_processor import extract_text_from_file, extract_audio_from_video
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging
|
||||
# ---------------------------------------------------------------------------
|
||||
LOG_DIR = Path(os.environ.get("LOG_DIR", "/app/logs"))
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(LOG_DIR / "rag_service.log"),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger("moxie-rag")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
WHISPER_URL = os.environ.get("WHISPER_URL", "http://host.docker.internal:8081/transcribe")
|
||||
UPLOAD_DIR = Path(os.environ.get("UPLOAD_DIR", "/app/data/uploads"))
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Engine
|
||||
# ---------------------------------------------------------------------------
|
||||
engine = RAGEngine(data_dir=os.environ.get("CHROMA_DIR", "/app/data/chromadb"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI
|
||||
# ---------------------------------------------------------------------------
|
||||
app = FastAPI(
|
||||
title="Moxie RAG Service",
|
||||
description="Multi-tenant RAG system for document storage and retrieval",
|
||||
version="2.0.0",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
class IngestRequest(BaseModel):
|
||||
content: str
|
||||
title: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
date: Optional[str] = None
|
||||
doc_type: Optional[str] = "text"
|
||||
auto_chunk: bool = True
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
top_k: int = 5
|
||||
filter_type: Optional[str] = None
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "Moxie RAG",
|
||||
"version": "2.0.0",
|
||||
"device": engine.device,
|
||||
"model": engine.model_name,
|
||||
"collections": engine.list_collections(),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": engine.device,
|
||||
"collections": engine.list_collections(),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/collections")
|
||||
async def list_collections():
|
||||
"""List all collections and their chunk counts."""
|
||||
return {"collections": engine.list_collections()}
|
||||
|
||||
|
||||
@app.post("/ingest")
|
||||
async def ingest_text(req: IngestRequest):
|
||||
"""Ingest text content into the vector store."""
|
||||
if not req.content.strip():
|
||||
raise HTTPException(400, "Content cannot be empty")
|
||||
try:
|
||||
return engine.ingest(
|
||||
content=req.content,
|
||||
title=req.title or "Untitled",
|
||||
source=req.source or "unknown",
|
||||
date=req.date,
|
||||
doc_type=req.doc_type or "text",
|
||||
auto_chunk=req.auto_chunk,
|
||||
collection=req.collection,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
|
||||
|
||||
@app.post("/ingest-file")
|
||||
async def ingest_file(
|
||||
file: UploadFile = File(...),
|
||||
title: Optional[str] = Form(None),
|
||||
source: Optional[str] = Form(None),
|
||||
date: Optional[str] = Form(None),
|
||||
doc_type: Optional[str] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
):
|
||||
"""Upload and ingest a document (PDF, DOCX, TXT, MD, XLSX, XLS, CSV)."""
|
||||
suffix = Path(file.filename).suffix.lower()
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
text = extract_text_from_file(tmp_path, file.filename)
|
||||
if not text.strip():
|
||||
raise HTTPException(400, "Could not extract text from file")
|
||||
|
||||
# Keep a copy
|
||||
dest = UPLOAD_DIR / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{file.filename}"
|
||||
shutil.copy2(tmp_path, dest)
|
||||
|
||||
return engine.ingest(
|
||||
content=text,
|
||||
title=title or file.filename,
|
||||
source=source or f"file:{file.filename}",
|
||||
date=date,
|
||||
doc_type=doc_type or suffix.lstrip("."),
|
||||
collection=collection,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
@app.post("/query")
|
||||
async def query(req: QueryRequest):
|
||||
"""Semantic search over indexed documents."""
|
||||
if not req.question.strip():
|
||||
raise HTTPException(400, "Question cannot be empty")
|
||||
return engine.query(
|
||||
question=req.question,
|
||||
top_k=req.top_k,
|
||||
filter_type=req.filter_type,
|
||||
collection=req.collection,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/documents")
|
||||
async def list_documents(collection: Optional[str] = Query(None)):
|
||||
"""List all indexed documents."""
|
||||
return engine.list_documents(collection=collection)
|
||||
|
||||
|
||||
@app.delete("/documents/{doc_id}")
|
||||
async def delete_document(doc_id: str, collection: Optional[str] = Query(None)):
|
||||
"""Delete a document and all its chunks."""
|
||||
try:
|
||||
return engine.delete_document(doc_id, collection=collection)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(404, str(exc))
|
||||
|
||||
|
||||
@app.post("/transcribe")
|
||||
async def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
auto_ingest: bool = Form(False),
|
||||
title: Optional[str] = Form(None),
|
||||
source: Optional[str] = Form(None),
|
||||
language: Optional[str] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
):
|
||||
"""Transcribe audio/video via Whisper, optionally auto-ingest the result."""
|
||||
suffix = Path(file.filename).suffix.lower()
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
audio_path = None
|
||||
try:
|
||||
# If video, extract audio first
|
||||
video_exts = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv"}
|
||||
send_path = tmp_path
|
||||
if suffix in video_exts:
|
||||
logger.info(f"Extracting audio from video: {file.filename}")
|
||||
audio_path = extract_audio_from_video(tmp_path)
|
||||
send_path = audio_path
|
||||
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
with open(send_path, "rb") as audio_file:
|
||||
send_name = file.filename if suffix not in video_exts else Path(file.filename).stem + ".wav"
|
||||
files = {"file": (send_name, audio_file)}
|
||||
resp = await client.post(WHISPER_URL, files=files)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(502, f"Whisper error: {resp.status_code} — {resp.text}")
|
||||
|
||||
result = resp.json()
|
||||
transcription = result.get("text", result.get("transcription", ""))
|
||||
|
||||
if not transcription.strip():
|
||||
raise HTTPException(400, "Transcription returned empty text")
|
||||
|
||||
response = {
|
||||
"filename": file.filename,
|
||||
"transcription": transcription,
|
||||
"word_count": len(transcription.split()),
|
||||
}
|
||||
|
||||
if auto_ingest:
|
||||
ingest_result = engine.ingest(
|
||||
content=transcription,
|
||||
title=title or f"Transcription: {file.filename}",
|
||||
source=source or f"audio:{file.filename}",
|
||||
doc_type="transcription",
|
||||
collection=collection,
|
||||
)
|
||||
response["ingested"] = True
|
||||
response["doc_id"] = ingest_result["doc_id"]
|
||||
response["chunks_created"] = ingest_result["chunks_created"]
|
||||
response["collection"] = ingest_result["collection"]
|
||||
else:
|
||||
response["ingested"] = False
|
||||
|
||||
logger.info(f"Transcribed '{file.filename}' ({response['word_count']} words)")
|
||||
return response
|
||||
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
if audio_path and os.path.exists(audio_path):
|
||||
os.unlink(audio_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entrypoint
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8899, log_level="info")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory Store Integration
|
||||
# ---------------------------------------------------------------------------
|
||||
from memory_store import MemoryStore
|
||||
|
||||
MEMORY_DB = os.environ.get("MEMORY_DB", "/app/data/memory.db")
|
||||
memory = MemoryStore(db_path=MEMORY_DB, rag_engine=engine)
|
||||
logger.info(f"Memory store initialized: {MEMORY_DB}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory Request Models
|
||||
# ---------------------------------------------------------------------------
|
||||
class SaveObservationRequest(BaseModel):
|
||||
user_id: str
|
||||
content: str
|
||||
type: str = "general"
|
||||
title: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
importance: int = 1
|
||||
tags: Optional[list] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class QueryMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
query: str
|
||||
top_k: int = 10
|
||||
type: Optional[str] = None
|
||||
since: Optional[str] = None
|
||||
include_content: bool = False
|
||||
|
||||
|
||||
class GetObservationsRequest(BaseModel):
|
||||
user_id: str
|
||||
ids: list
|
||||
|
||||
|
||||
class TimelineRequest(BaseModel):
|
||||
user_id: str
|
||||
around_id: Optional[int] = None
|
||||
around_time: Optional[str] = None
|
||||
window_minutes: int = 30
|
||||
limit: int = 20
|
||||
|
||||
|
||||
class PreferenceRequest(BaseModel):
|
||||
user_id: str
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.post("/memory/save")
|
||||
async def save_observation(req: SaveObservationRequest):
|
||||
"""Save an observation to memory (SQLite + vector embedding)."""
|
||||
if not req.content.strip():
|
||||
raise HTTPException(400, "Content cannot be empty")
|
||||
if not req.user_id.strip():
|
||||
raise HTTPException(400, "user_id is required")
|
||||
|
||||
return memory.save_observation(
|
||||
user_id=req.user_id,
|
||||
content=req.content,
|
||||
obs_type=req.type,
|
||||
title=req.title,
|
||||
session_id=req.session_id,
|
||||
tool_name=req.tool_name,
|
||||
importance=req.importance,
|
||||
tags=req.tags,
|
||||
metadata=req.metadata,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/memory/query")
|
||||
async def query_memory(req: QueryMemoryRequest):
|
||||
"""
|
||||
Search memory using hybrid vector + structured search.
|
||||
Returns index by default (progressive disclosure).
|
||||
Set include_content=true for full observation content.
|
||||
"""
|
||||
if not req.query.strip():
|
||||
raise HTTPException(400, "Query cannot be empty")
|
||||
if not req.user_id.strip():
|
||||
raise HTTPException(400, "user_id is required")
|
||||
|
||||
return memory.query_memory(
|
||||
user_id=req.user_id,
|
||||
query=req.query,
|
||||
top_k=req.top_k,
|
||||
obs_type=req.type,
|
||||
since=req.since,
|
||||
include_content=req.include_content,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/memory/get")
|
||||
async def get_observations(req: GetObservationsRequest):
|
||||
"""Fetch full observation details by IDs."""
|
||||
if not req.user_id.strip():
|
||||
raise HTTPException(400, "user_id is required")
|
||||
if not req.ids:
|
||||
raise HTTPException(400, "ids list cannot be empty")
|
||||
|
||||
return memory.get_observations(
|
||||
user_id=req.user_id,
|
||||
ids=req.ids,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/memory/timeline")
|
||||
async def get_timeline(req: TimelineRequest):
|
||||
"""Get chronological context around a specific observation or time."""
|
||||
if not req.user_id.strip():
|
||||
raise HTTPException(400, "user_id is required")
|
||||
|
||||
return memory.get_timeline(
|
||||
user_id=req.user_id,
|
||||
around_id=req.around_id,
|
||||
around_time=req.around_time,
|
||||
window_minutes=req.window_minutes,
|
||||
limit=req.limit,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/memory/preference")
|
||||
async def save_preference(req: PreferenceRequest):
|
||||
"""Save or update a user preference."""
|
||||
if not req.user_id.strip():
|
||||
raise HTTPException(400, "user_id is required")
|
||||
|
||||
return memory.save_preference(
|
||||
user_id=req.user_id,
|
||||
key=req.key,
|
||||
value=req.value,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/memory/preferences/{user_id}")
|
||||
async def get_preferences(user_id: str):
|
||||
"""Get all preferences for a user."""
|
||||
return memory.get_preferences(user_id)
|
||||
|
||||
|
||||
@app.get("/memory/stats/{user_id}")
|
||||
async def get_memory_stats(user_id: str):
|
||||
"""Get memory statistics for a user."""
|
||||
return memory.get_stats(user_id)
|
||||
378
app/memory_store.py
Normal file
378
app/memory_store.py
Normal file
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Memory Store — SQLite + ChromaDB hybrid for agent observations.
|
||||
Provides structured storage with vector search.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger("moxie-rag.memory")
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""SQLite-backed memory store with ChromaDB integration."""
|
||||
|
||||
def __init__(self, db_path: str, rag_engine):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.rag_engine = rag_engine
|
||||
self._init_db()
|
||||
|
||||
@contextmanager
|
||||
def _get_conn(self):
|
||||
"""Thread-safe connection context manager."""
|
||||
conn = sqlite3.connect(str(self.db_path), timeout=30.0)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize the SQLite schema."""
|
||||
with self._get_conn() as conn:
|
||||
conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS observations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
type TEXT NOT NULL,
|
||||
title TEXT,
|
||||
content TEXT NOT NULL,
|
||||
content_hash TEXT UNIQUE,
|
||||
embedding_id TEXT,
|
||||
session_id TEXT,
|
||||
tool_name TEXT,
|
||||
importance INTEGER DEFAULT 1,
|
||||
tags TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_obs_user ON observations(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_obs_type ON observations(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_obs_timestamp ON observations(timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_obs_session ON observations(session_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, key)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_pref_user ON preferences(user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS relationships (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
observation_id INTEGER,
|
||||
related_id INTEGER,
|
||||
relation_type TEXT,
|
||||
FOREIGN KEY (observation_id) REFERENCES observations(id),
|
||||
FOREIGN KEY (related_id) REFERENCES observations(id)
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
logger.info(f"Memory store initialized at {self.db_path}")
|
||||
|
||||
def _content_hash(self, content: str) -> str:
|
||||
"""Generate hash for deduplication."""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
||||
|
||||
def _get_collection_name(self, user_id: str) -> str:
|
||||
"""Get ChromaDB collection name for user."""
|
||||
return f"moxie_memory_{user_id}"
|
||||
|
||||
def save_observation(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
obs_type: str = "general",
|
||||
title: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
importance: int = 1,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Save an observation to SQLite and embed in ChromaDB.
|
||||
Returns the observation ID and embedding status.
|
||||
"""
|
||||
content_hash = self._content_hash(content)
|
||||
collection = self._get_collection_name(user_id)
|
||||
|
||||
# Check for duplicate
|
||||
with self._get_conn() as conn:
|
||||
existing = conn.execute(
|
||||
"SELECT id FROM observations WHERE content_hash = ?",
|
||||
(content_hash,)
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
return {
|
||||
"status": "duplicate",
|
||||
"observation_id": existing["id"],
|
||||
"message": "Observation already exists"
|
||||
}
|
||||
|
||||
# Embed in ChromaDB
|
||||
embed_result = self.rag_engine.ingest(
|
||||
content=content,
|
||||
title=title or f"Observation: {obs_type}",
|
||||
source=f"memory:{user_id}:{obs_type}",
|
||||
doc_type="observation",
|
||||
collection=collection,
|
||||
)
|
||||
embedding_id = embed_result.get("doc_id")
|
||||
|
||||
# Store in SQLite
|
||||
tags_str = ",".join(tags) if tags else None
|
||||
metadata_str = str(metadata) if metadata else None
|
||||
|
||||
with self._get_conn() as conn:
|
||||
cursor = conn.execute("""
|
||||
INSERT INTO observations
|
||||
(user_id, type, title, content, content_hash, embedding_id,
|
||||
session_id, tool_name, importance, tags, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
user_id, obs_type, title, content, content_hash, embedding_id,
|
||||
session_id, tool_name, importance, tags_str, metadata_str
|
||||
))
|
||||
conn.commit()
|
||||
obs_id = cursor.lastrowid
|
||||
|
||||
logger.info(f"Saved observation #{obs_id} for user {user_id} (type: {obs_type})")
|
||||
return {
|
||||
"status": "created",
|
||||
"observation_id": obs_id,
|
||||
"embedding_id": embedding_id,
|
||||
"collection": collection,
|
||||
}
|
||||
|
||||
def query_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
obs_type: Optional[str] = None,
|
||||
since: Optional[str] = None,
|
||||
include_content: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Search memory using hybrid SQLite + vector search.
|
||||
Progressive disclosure: returns index by default, full content if requested.
|
||||
"""
|
||||
collection = self._get_collection_name(user_id)
|
||||
|
||||
# Vector search in ChromaDB
|
||||
vector_results = self.rag_engine.query(
|
||||
question=query,
|
||||
top_k=top_k * 2, # Get more for filtering
|
||||
collection=collection,
|
||||
)
|
||||
|
||||
# Get observation IDs from embedding IDs
|
||||
embedding_ids = [r.get("metadata", {}).get("doc_id") for r in vector_results.get("results", [])]
|
||||
|
||||
if not embedding_ids:
|
||||
return {"results": [], "total": 0, "query": query}
|
||||
|
||||
# Fetch from SQLite with filters
|
||||
placeholders = ",".join(["?" for _ in embedding_ids])
|
||||
sql = f"""
|
||||
SELECT id, user_id, timestamp, type, title, importance, tags, tool_name
|
||||
{"" if not include_content else ", content"}
|
||||
FROM observations
|
||||
WHERE user_id = ? AND embedding_id IN ({placeholders})
|
||||
"""
|
||||
params = [user_id] + embedding_ids
|
||||
|
||||
if obs_type:
|
||||
sql += " AND type = ?"
|
||||
params.append(obs_type)
|
||||
|
||||
if since:
|
||||
sql += " AND timestamp >= ?"
|
||||
params.append(since)
|
||||
|
||||
sql += " ORDER BY timestamp DESC LIMIT ?"
|
||||
params.append(top_k)
|
||||
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
item = {
|
||||
"id": row["id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"type": row["type"],
|
||||
"title": row["title"],
|
||||
"importance": row["importance"],
|
||||
"tags": row["tags"].split(",") if row["tags"] else [],
|
||||
"tool_name": row["tool_name"],
|
||||
}
|
||||
if include_content:
|
||||
item["content"] = row["content"]
|
||||
results.append(item)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": len(results),
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
}
|
||||
|
||||
def get_observations(
|
||||
self,
|
||||
user_id: str,
|
||||
ids: List[int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch full observation details by IDs."""
|
||||
if not ids:
|
||||
return {"observations": []}
|
||||
|
||||
placeholders = ",".join(["?" for _ in ids])
|
||||
sql = f"""
|
||||
SELECT * FROM observations
|
||||
WHERE user_id = ? AND id IN ({placeholders})
|
||||
ORDER BY timestamp DESC
|
||||
"""
|
||||
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(sql, [user_id] + ids).fetchall()
|
||||
|
||||
observations = []
|
||||
for row in rows:
|
||||
observations.append({
|
||||
"id": row["id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"type": row["type"],
|
||||
"title": row["title"],
|
||||
"content": row["content"],
|
||||
"importance": row["importance"],
|
||||
"tags": row["tags"].split(",") if row["tags"] else [],
|
||||
"tool_name": row["tool_name"],
|
||||
"session_id": row["session_id"],
|
||||
"metadata": row["metadata"],
|
||||
})
|
||||
|
||||
return {"observations": observations, "count": len(observations)}
|
||||
|
||||
def get_timeline(
|
||||
self,
|
||||
user_id: str,
|
||||
around_id: Optional[int] = None,
|
||||
around_time: Optional[str] = None,
|
||||
window_minutes: int = 30,
|
||||
limit: int = 20,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get chronological context around a specific observation or time."""
|
||||
with self._get_conn() as conn:
|
||||
if around_id:
|
||||
# Get timestamp of reference observation
|
||||
ref = conn.execute(
|
||||
"SELECT timestamp FROM observations WHERE id = ? AND user_id = ?",
|
||||
(around_id, user_id)
|
||||
).fetchone()
|
||||
if not ref:
|
||||
return {"error": "Observation not found", "timeline": []}
|
||||
center_time = ref["timestamp"]
|
||||
elif around_time:
|
||||
center_time = around_time
|
||||
else:
|
||||
center_time = datetime.now().isoformat()
|
||||
|
||||
# Get observations in time window
|
||||
rows = conn.execute("""
|
||||
SELECT id, timestamp, type, title, importance, tool_name
|
||||
FROM observations
|
||||
WHERE user_id = ?
|
||||
AND datetime(timestamp) BETWEEN
|
||||
datetime(?, '-' || ? || ' minutes')
|
||||
AND datetime(?, '+' || ? || ' minutes')
|
||||
ORDER BY timestamp
|
||||
LIMIT ?
|
||||
""", (user_id, center_time, window_minutes, center_time, window_minutes, limit)).fetchall()
|
||||
|
||||
timeline = [{
|
||||
"id": row["id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"type": row["type"],
|
||||
"title": row["title"],
|
||||
"importance": row["importance"],
|
||||
"tool_name": row["tool_name"],
|
||||
} for row in rows]
|
||||
|
||||
return {
|
||||
"timeline": timeline,
|
||||
"center_time": center_time,
|
||||
"window_minutes": window_minutes,
|
||||
"count": len(timeline),
|
||||
}
|
||||
|
||||
def save_preference(
|
||||
self,
|
||||
user_id: str,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Save or update a user preference."""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO preferences (user_id, key, value)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE SET
|
||||
value = excluded.value,
|
||||
timestamp = CURRENT_TIMESTAMP
|
||||
""", (user_id, key, value))
|
||||
conn.commit()
|
||||
|
||||
return {"status": "saved", "user_id": user_id, "key": key}
|
||||
|
||||
def get_preferences(self, user_id: str) -> Dict[str, str]:
|
||||
"""Get all preferences for a user."""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT key, value FROM preferences WHERE user_id = ?",
|
||||
(user_id,)
|
||||
).fetchall()
|
||||
|
||||
return {row["key"]: row["value"] for row in rows}
|
||||
|
||||
def get_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get memory statistics for a user."""
|
||||
with self._get_conn() as conn:
|
||||
total = conn.execute(
|
||||
"SELECT COUNT(*) as c FROM observations WHERE user_id = ?",
|
||||
(user_id,)
|
||||
).fetchone()["c"]
|
||||
|
||||
by_type = conn.execute("""
|
||||
SELECT type, COUNT(*) as c
|
||||
FROM observations WHERE user_id = ?
|
||||
GROUP BY type
|
||||
""", (user_id,)).fetchall()
|
||||
|
||||
recent = conn.execute("""
|
||||
SELECT COUNT(*) as c FROM observations
|
||||
WHERE user_id = ? AND timestamp >= datetime('now', '-7 days')
|
||||
""", (user_id,)).fetchone()["c"]
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"total_observations": total,
|
||||
"by_type": {row["type"]: row["c"] for row in by_type},
|
||||
"last_7_days": recent,
|
||||
"collection": self._get_collection_name(user_id),
|
||||
}
|
||||
257
app/rag_engine.py
Normal file
257
app/rag_engine.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
RAG Engine — ChromaDB + sentence-transformers embedding logic.
|
||||
Supports multiple collections for tenant isolation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from document_processor import chunk_text
|
||||
|
||||
logger = logging.getLogger("moxie-rag.engine")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detect best device
|
||||
# ---------------------------------------------------------------------------
|
||||
DEVICE = "cpu"
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = "cuda"
|
||||
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "unknown"
|
||||
logger.info(f"GPU detected: {gpu_name}")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger.info(f"Embedding device: {DEVICE}")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
DEFAULT_COLLECTION = "adolfo_docs"
|
||||
|
||||
|
||||
class RAGEngine:
|
||||
"""Manages embeddings and vector storage with multi-collection support."""
|
||||
|
||||
def __init__(self, data_dir: str = "/app/data/chromadb"):
|
||||
Path(data_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Loading embedding model '{EMBEDDING_MODEL}' on {DEVICE} ...")
|
||||
self.embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
|
||||
logger.info("Embedding model loaded.")
|
||||
|
||||
self.chroma = chromadb.PersistentClient(path=data_dir)
|
||||
|
||||
# Pre-load default collection
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._get_collection(DEFAULT_COLLECTION)
|
||||
logger.info(
|
||||
f"ChromaDB collection '{DEFAULT_COLLECTION}' ready — "
|
||||
f"{self._collections[DEFAULT_COLLECTION].count()} existing chunks."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Collection management
|
||||
# ------------------------------------------------------------------
|
||||
def _get_collection(self, name: str):
|
||||
"""Get or create a ChromaDB collection by name."""
|
||||
if name not in self._collections:
|
||||
self._collections[name] = self.chroma.get_or_create_collection(
|
||||
name=name,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
logger.info(f"Collection '{name}' loaded ({self._collections[name].count()} chunks)")
|
||||
return self._collections[name]
|
||||
|
||||
def list_collections(self) -> List[Dict[str, Any]]:
|
||||
"""List all collections with their document counts."""
|
||||
collections = self.chroma.list_collections()
|
||||
result = []
|
||||
for coll in collections:
|
||||
# ChromaDB >= 1.x returns Collection objects, older versions return strings
|
||||
name = coll if isinstance(coll, str) else coll.name
|
||||
c = self._get_collection(name)
|
||||
result.append({"name": name, "chunks": c.count()})
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def device(self) -> str:
|
||||
return DEVICE
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return EMBEDDING_MODEL
|
||||
|
||||
@property
|
||||
def doc_count(self) -> int:
|
||||
return self._get_collection(DEFAULT_COLLECTION).count()
|
||||
|
||||
def collection_count(self, collection: str = None) -> int:
|
||||
return self._get_collection(collection or DEFAULT_COLLECTION).count()
|
||||
|
||||
@staticmethod
|
||||
def _make_doc_id(title: str, source: str) -> str:
|
||||
raw = f"{title}:{source}:{datetime.now().isoformat()}"
|
||||
return hashlib.md5(raw.encode()).hexdigest()[:12]
|
||||
|
||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||
return self.embedder.encode(texts, show_progress_bar=False).tolist()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Ingest
|
||||
# ------------------------------------------------------------------
|
||||
def ingest(
|
||||
self,
|
||||
content: str,
|
||||
title: str = "Untitled",
|
||||
source: str = "unknown",
|
||||
date: Optional[str] = None,
|
||||
doc_type: str = "text",
|
||||
auto_chunk: bool = True,
|
||||
collection: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Chunk, embed, and store content in specified collection."""
|
||||
coll = self._get_collection(collection or DEFAULT_COLLECTION)
|
||||
doc_id = self._make_doc_id(title, source)
|
||||
date = date or datetime.now().isoformat()
|
||||
|
||||
chunks = chunk_text(content) if auto_chunk else [content.strip()]
|
||||
chunks = [c for c in chunks if c]
|
||||
if not chunks:
|
||||
raise ValueError("No content to ingest after processing")
|
||||
|
||||
embeddings = self._embed(chunks)
|
||||
|
||||
ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
|
||||
metadatas = [
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"title": title,
|
||||
"source": source,
|
||||
"date": date,
|
||||
"doc_type": doc_type,
|
||||
"chunk_index": i,
|
||||
"total_chunks": len(chunks),
|
||||
}
|
||||
for i in range(len(chunks))
|
||||
]
|
||||
|
||||
coll.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=chunks,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
coll_name = collection or DEFAULT_COLLECTION
|
||||
logger.info(f"Ingested '{title}' ({len(chunks)} chunks) [{doc_id}] → {coll_name}")
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
"title": title,
|
||||
"chunks_created": len(chunks),
|
||||
"total_documents": coll.count(),
|
||||
"collection": coll_name,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query
|
||||
# ------------------------------------------------------------------
|
||||
def query(
|
||||
self,
|
||||
question: str,
|
||||
top_k: int = 5,
|
||||
filter_type: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Semantic search over indexed chunks in specified collection."""
|
||||
coll = self._get_collection(collection or DEFAULT_COLLECTION)
|
||||
|
||||
if coll.count() == 0:
|
||||
return {"question": question, "results": [], "total_results": 0, "collection": collection or DEFAULT_COLLECTION}
|
||||
|
||||
query_emb = self._embed([question])
|
||||
where = {"doc_type": filter_type} if filter_type else None
|
||||
|
||||
results = coll.query(
|
||||
query_embeddings=query_emb,
|
||||
n_results=min(top_k, coll.count()),
|
||||
where=where,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
formatted = []
|
||||
if results and results["ids"] and results["ids"][0]:
|
||||
for i, cid in enumerate(results["ids"][0]):
|
||||
formatted.append(
|
||||
{
|
||||
"chunk_id": cid,
|
||||
"content": results["documents"][0][i],
|
||||
"metadata": results["metadatas"][0][i],
|
||||
"distance": results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
|
||||
coll_name = collection or DEFAULT_COLLECTION
|
||||
logger.info(f"Query [{coll_name}]: '{question}' → {len(formatted)} results")
|
||||
return {
|
||||
"question": question,
|
||||
"results": formatted,
|
||||
"total_results": len(formatted),
|
||||
"collection": coll_name,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document management
|
||||
# ------------------------------------------------------------------
|
||||
def list_documents(self, collection: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""List all indexed documents grouped by doc_id in specified collection."""
|
||||
coll = self._get_collection(collection or DEFAULT_COLLECTION)
|
||||
|
||||
if coll.count() == 0:
|
||||
return {"documents": [], "total": 0, "collection": collection or DEFAULT_COLLECTION}
|
||||
|
||||
all_data = coll.get(include=["metadatas"])
|
||||
docs: Dict[str, Dict] = {}
|
||||
for meta in all_data["metadatas"]:
|
||||
did = meta.get("doc_id", "unknown")
|
||||
if did not in docs:
|
||||
docs[did] = {
|
||||
"doc_id": did,
|
||||
"title": meta.get("title", "Unknown"),
|
||||
"source": meta.get("source", "unknown"),
|
||||
"doc_type": meta.get("doc_type", "text"),
|
||||
"date": meta.get("date", "unknown"),
|
||||
"chunk_count": 0,
|
||||
}
|
||||
docs[did]["chunk_count"] += 1
|
||||
|
||||
return {"documents": list(docs.values()), "total": len(docs), "collection": collection or DEFAULT_COLLECTION}
|
||||
|
||||
def delete_document(self, doc_id: str, collection: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Delete all chunks belonging to a document in specified collection."""
|
||||
coll = self._get_collection(collection or DEFAULT_COLLECTION)
|
||||
all_data = coll.get(include=["metadatas"])
|
||||
ids_to_delete = [
|
||||
all_data["ids"][i]
|
||||
for i, m in enumerate(all_data["metadatas"])
|
||||
if m.get("doc_id") == doc_id
|
||||
]
|
||||
|
||||
if not ids_to_delete:
|
||||
raise KeyError(f"Document '{doc_id}' not found")
|
||||
|
||||
coll.delete(ids=ids_to_delete)
|
||||
logger.info(f"Deleted {doc_id} ({len(ids_to_delete)} chunks)")
|
||||
return {"deleted": doc_id, "chunks_removed": len(ids_to_delete)}
|
||||
11
app/requirements.txt
Normal file
11
app/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
# PyTorch ROCm installed separately in Dockerfile
|
||||
fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
chromadb>=0.4.0
|
||||
sentence-transformers>=2.2.0
|
||||
pdfplumber>=0.10.0
|
||||
python-docx>=1.0.0
|
||||
python-multipart>=0.0.6
|
||||
httpx>=0.25.0
|
||||
pandas>=2.1.0
|
||||
openpyxl>=3.1.0
|
||||
69
docker-compose.yml
Normal file
69
docker-compose.yml
Normal file
@@ -0,0 +1,69 @@
|
||||
services:
|
||||
rag:
|
||||
build: .
|
||||
container_name: moxie-rag
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8899:8899"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- WHISPER_URL=http://host.docker.internal:8081/transcribe
|
||||
- CHROMA_DIR=/app/data/chromadb
|
||||
- UPLOAD_DIR=/app/data/uploads
|
||||
- LOG_DIR=/app/logs
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
group_add:
|
||||
- "44"
|
||||
- "992"
|
||||
security_opt:
|
||||
- seccomp=unconfined
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
ipc: host
|
||||
|
||||
poller-zeus:
|
||||
build: .
|
||||
container_name: zeus-email-poller
|
||||
restart: unless-stopped
|
||||
command: python email_poller.py
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- IMAP_HOST=mail.oe74.net
|
||||
- IMAP_PORT=993
|
||||
- IMAP_USER=zeus@zz11.net
|
||||
- IMAP_PASS=#!nvo@uHR6493
|
||||
- RAG_URL=http://moxie-rag:8899
|
||||
- RAG_COLLECTION=zeus_docs
|
||||
- ALLOWED_SENDERS=isabella.isg@gmail.com
|
||||
- POLL_INTERVAL=60
|
||||
- STATE_FILE=/app/data/zeus_email_state.json
|
||||
- LOG_DIR=/app/logs
|
||||
depends_on:
|
||||
- rag
|
||||
|
||||
poller-moxie:
|
||||
build: .
|
||||
container_name: moxie-email-poller
|
||||
restart: unless-stopped
|
||||
command: python email_poller.py
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- IMAP_HOST=mail.oe74.net
|
||||
- IMAP_PORT=993
|
||||
- IMAP_USER=moxie@zz11.net
|
||||
- IMAP_PASS=Xn1R#JThrcn0k
|
||||
- RAG_URL=http://moxie-rag:8899
|
||||
- RAG_COLLECTION=adolfo_docs
|
||||
- POLL_INTERVAL=60
|
||||
- STATE_FILE=/app/data/moxie_email_state.json
|
||||
- LOG_DIR=/app/logs
|
||||
depends_on:
|
||||
- rag
|
||||
Reference in New Issue
Block a user