feat: Add memory system with SQLite + ChromaDB hybrid storage
- memory_store.py: user-isolated observation storage with vector embeddings
- New endpoints: /memory/save, /memory/query, /memory/get, /memory/timeline
- Progressive-disclosure pattern for token-efficient retrieval
- Updated Dockerfile to ROCm 7.2 nightly
This commit is contained in:
433
app/main.py
Normal file
433
app/main.py
Normal file
@@ -0,0 +1,433 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Moxie RAG Service — FastAPI application.
|
||||
Multi-collection support for tenant isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from rag_engine import RAGEngine
|
||||
from document_processor import extract_text_from_file, extract_audio_from_video
|
||||
|
||||
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Log directory is configurable via the LOG_DIR environment variable; it is
# created eagerly so the FileHandler below never fails on a missing path.
LOG_DIR = Path(os.environ.get("LOG_DIR", "/app/logs"))
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Emit to both a log file and stdout, so container log collectors see output
# while a persistent copy lands in the mounted log volume.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.FileHandler(LOG_DIR / "rag_service.log"),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger("moxie-rag")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Endpoint of the external Whisper transcription service used by /transcribe.
WHISPER_URL = os.environ.get("WHISPER_URL", "http://host.docker.internal:8081/transcribe")
# Directory where a timestamped copy of each uploaded file is archived
# (see /ingest-file); created eagerly so copies never fail on a missing path.
UPLOAD_DIR = Path(os.environ.get("UPLOAD_DIR", "/app/data/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------
# Single shared RAG engine instance for the whole process; tenant isolation
# is handled per-call via the `collection` argument on each endpoint.
engine = RAGEngine(data_dir=os.environ.get("CHROMA_DIR", "/app/data/chromadb"))
|
||||
|
||||
# ---------------------------------------------------------------------------
# FastAPI
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Moxie RAG Service",
    description="Multi-tenant RAG system for document storage and retrieval",
    version="2.0.0",
)

# NOTE(review): wildcard CORS allows any origin/method/header. Acceptable on a
# trusted internal network, but tighten allow_origins before exposing this
# service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class IngestRequest(BaseModel):
    """Payload for POST /ingest: raw text plus optional metadata."""

    content: str                      # required document text
    title: Optional[str] = None       # display title; endpoint defaults to "Untitled"
    source: Optional[str] = None      # provenance label; endpoint defaults to "unknown"
    date: Optional[str] = None        # free-form date string, passed through to the engine
    doc_type: Optional[str] = "text"  # document category tag
    auto_chunk: bool = True           # let the engine split long content into chunks
    collection: Optional[str] = None  # target collection; None presumably selects the engine default
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    """Payload for POST /query: a question plus optional search filters."""

    question: str                      # natural-language query text
    top_k: int = 5                     # maximum number of results to return
    filter_type: Optional[str] = None  # restrict results to one doc_type
    collection: Optional[str] = None   # collection to search; None presumably selects the engine default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    """Service banner: identity, version, compute device, model, collections."""
    banner = {
        "service": "Moxie RAG",
        "version": "2.0.0",
        "device": engine.device,
        "model": engine.model_name,
        "collections": engine.list_collections(),
    }
    return banner
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness/readiness probe reporting device and known collections."""
    payload = {"status": "ok"}
    payload["device"] = engine.device
    payload["collections"] = engine.list_collections()
    return payload
|
||||
|
||||
|
||||
@app.get("/collections")
async def list_collections():
    """Return every collection together with its chunk count."""
    collections = engine.list_collections()
    return {"collections": collections}
|
||||
|
||||
|
||||
@app.post("/ingest")
async def ingest_text(req: IngestRequest):
    """Ingest text content into the vector store.

    Returns the engine's ingest result (doc id, chunks created, collection).
    Raises HTTP 400 when the content is blank or the engine rejects the
    request with a ValueError.
    """
    if not req.content.strip():
        raise HTTPException(400, "Content cannot be empty")
    try:
        return engine.ingest(
            content=req.content,
            title=req.title or "Untitled",
            source=req.source or "unknown",
            date=req.date,
            doc_type=req.doc_type or "text",
            auto_chunk=req.auto_chunk,
            collection=req.collection,
        )
    except ValueError as exc:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(400, str(exc)) from exc
|
||||
|
||||
|
||||
@app.post("/ingest-file")
async def ingest_file(
    file: UploadFile = File(...),
    title: Optional[str] = Form(None),
    source: Optional[str] = Form(None),
    date: Optional[str] = Form(None),
    doc_type: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
):
    """Upload and ingest a document (PDF, DOCX, TXT, MD, XLSX, XLS, CSV).

    The upload is spooled to a temp file, text is extracted, a timestamped
    copy is archived under UPLOAD_DIR, and the text is handed to the engine.
    Raises HTTP 400 when no text can be extracted or the engine rejects the
    request with a ValueError.
    """
    # UploadFile.filename may be None on some clients; fall back so Path()
    # and the metadata strings below never see None.
    filename = file.filename or "upload"
    suffix = Path(filename).suffix.lower()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    try:
        text = extract_text_from_file(tmp_path, filename)
        if not text.strip():
            raise HTTPException(400, "Could not extract text from file")

        # Keep a timestamped archive copy of the original upload.
        dest = UPLOAD_DIR / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{filename}"
        shutil.copy2(tmp_path, dest)

        return engine.ingest(
            content=text,
            title=title or filename,
            source=source or f"file:{filename}",
            date=date,
            doc_type=doc_type or suffix.lstrip("."),
            collection=collection,
        )
    except ValueError as exc:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(400, str(exc)) from exc
    finally:
        # Always remove the temp spool file, even on failure paths.
        os.unlink(tmp_path)
|
||||
|
||||
|
||||
@app.post("/query")
async def query(req: QueryRequest):
    """Run a semantic search over the indexed documents."""
    if not req.question.strip():
        raise HTTPException(400, "Question cannot be empty")
    search_args = {
        "question": req.question,
        "top_k": req.top_k,
        "filter_type": req.filter_type,
        "collection": req.collection,
    }
    return engine.query(**search_args)
|
||||
|
||||
|
||||
@app.get("/documents")
async def list_documents(collection: Optional[str] = Query(None)):
    """Enumerate indexed documents, optionally scoped to one collection."""
    docs = engine.list_documents(collection=collection)
    return docs
|
||||
|
||||
|
||||
@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str, collection: Optional[str] = Query(None)):
    """Delete a document and all its chunks.

    Maps the engine's KeyError for an unknown doc_id to HTTP 404.
    """
    try:
        return engine.delete_document(doc_id, collection=collection)
    except KeyError as exc:
        # NOTE: str() of a KeyError wraps the message in quotes; chain the
        # cause so server logs keep the original traceback.
        raise HTTPException(404, str(exc)) from exc
|
||||
|
||||
|
||||
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    auto_ingest: bool = Form(False),
    title: Optional[str] = Form(None),
    source: Optional[str] = Form(None),
    language: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
):
    """Transcribe audio/video via Whisper, optionally auto-ingest the result.

    Flow: spool the upload to a temp file; if it is a video, extract the audio
    track first; POST the audio to the external Whisper service; optionally
    ingest the transcription into the vector store. Temp files are always
    cleaned up in the finally block.

    NOTE(review): `language` is accepted but never forwarded to the Whisper
    request below — confirm whether it should be included in the POST.
    """
    suffix = Path(file.filename).suffix.lower()
    # Spool the upload to disk so ffmpeg/httpx can read it from a real path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    audio_path = None
    try:
        # If video, extract audio first
        video_exts = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv"}
        send_path = tmp_path
        if suffix in video_exts:
            logger.info(f"Extracting audio from video: {file.filename}")
            audio_path = extract_audio_from_video(tmp_path)
            send_path = audio_path

        # Long timeout: Whisper transcription of large files can take minutes.
        async with httpx.AsyncClient(timeout=600.0) as client:
            with open(send_path, "rb") as audio_file:
                # For videos, present the extracted track as a .wav to Whisper.
                send_name = file.filename if suffix not in video_exts else Path(file.filename).stem + ".wav"
                files = {"file": (send_name, audio_file)}
                resp = await client.post(WHISPER_URL, files=files)

        if resp.status_code != 200:
            raise HTTPException(502, f"Whisper error: {resp.status_code} — {resp.text}")

        result = resp.json()
        # Accept either response key; Whisper servers vary ("text" vs "transcription").
        transcription = result.get("text", result.get("transcription", ""))

        if not transcription.strip():
            raise HTTPException(400, "Transcription returned empty text")

        response = {
            "filename": file.filename,
            "transcription": transcription,
            "word_count": len(transcription.split()),
        }

        if auto_ingest:
            # Store the transcription in the vector store under doc_type
            # "transcription" and surface the ingest identifiers to the caller.
            ingest_result = engine.ingest(
                content=transcription,
                title=title or f"Transcription: {file.filename}",
                source=source or f"audio:{file.filename}",
                doc_type="transcription",
                collection=collection,
            )
            response["ingested"] = True
            response["doc_id"] = ingest_result["doc_id"]
            response["chunks_created"] = ingest_result["chunks_created"]
            response["collection"] = ingest_result["collection"]
        else:
            response["ingested"] = False

        logger.info(f"Transcribed '{file.filename}' ({response['word_count']} words)")
        return response

    finally:
        # Remove the upload spool and any extracted audio track.
        os.unlink(tmp_path)
        if audio_path and os.path.exists(audio_path):
            os.unlink(audio_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Entrypoint
# ---------------------------------------------------------------------------
# Run standalone with uvicorn. Because the app is passed as the import string
# "main:app", uvicorn re-imports this module, so the memory-store setup that
# appears further below still executes before requests are served.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8899, log_level="info")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Memory Store Integration
# ---------------------------------------------------------------------------
# NOTE(review): imported here, after `engine` exists, rather than at the top
# of the file; MemoryStore receives the shared RAG engine (presumably for
# vector embeddings — confirm against memory_store.py).
from memory_store import MemoryStore

# Path of the SQLite database backing structured memory storage.
MEMORY_DB = os.environ.get("MEMORY_DB", "/app/data/memory.db")
memory = MemoryStore(db_path=MEMORY_DB, rag_engine=engine)
logger.info(f"Memory store initialized: {MEMORY_DB}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Memory Request Models
# ---------------------------------------------------------------------------
class SaveObservationRequest(BaseModel):
    """Payload for POST /memory/save: one observation for one user."""

    user_id: str                      # owner of the observation (isolation key)
    content: str                      # required observation text
    type: str = "general"             # observation category
    title: Optional[str] = None       # optional short label
    session_id: Optional[str] = None  # conversation/session the observation came from
    tool_name: Optional[str] = None   # tool that produced the observation, if any
    importance: int = 1               # relevance weight; semantics defined by MemoryStore
    tags: Optional[list] = None       # free-form tag list
    metadata: Optional[dict] = None   # arbitrary extra fields, passed through
|
||||
|
||||
|
||||
class QueryMemoryRequest(BaseModel):
    """Payload for POST /memory/query: hybrid memory search parameters."""

    user_id: str                    # whose memory to search (isolation key)
    query: str                      # natural-language search text
    top_k: int = 10                 # maximum number of hits
    type: Optional[str] = None      # restrict to one observation type
    since: Optional[str] = None     # only observations after this time; format defined by MemoryStore
    include_content: bool = False   # False = index only (progressive disclosure)
|
||||
|
||||
|
||||
class GetObservationsRequest(BaseModel):
    """Payload for POST /memory/get: fetch full observations by ID."""

    user_id: str  # owner; requests are scoped to this user's observations
    ids: list     # observation IDs to fetch (must be non-empty)
|
||||
|
||||
|
||||
class TimelineRequest(BaseModel):
    """Payload for POST /memory/timeline: chronological context window."""

    user_id: str                     # whose timeline to read (isolation key)
    around_id: Optional[int] = None  # center the window on this observation ID
    around_time: Optional[str] = None  # or center on this timestamp string
    window_minutes: int = 30         # half-width of the time window, in minutes
    limit: int = 20                  # maximum observations returned
|
||||
|
||||
|
||||
class PreferenceRequest(BaseModel):
    """Payload for POST /memory/preference: one key/value preference."""

    user_id: str  # owner of the preference
    key: str      # preference name
    value: str    # preference value (stored as a string)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Memory Endpoints
# ---------------------------------------------------------------------------
@app.post("/memory/save")
async def save_observation(req: SaveObservationRequest):
    """Persist one observation to memory (SQLite row + vector embedding)."""
    if not req.content.strip():
        raise HTTPException(400, "Content cannot be empty")
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    observation = dict(
        user_id=req.user_id,
        content=req.content,
        obs_type=req.type,
        title=req.title,
        session_id=req.session_id,
        tool_name=req.tool_name,
        importance=req.importance,
        tags=req.tags,
        metadata=req.metadata,
    )
    return memory.save_observation(**observation)
|
||||
|
||||
|
||||
@app.post("/memory/query")
async def query_memory(req: QueryMemoryRequest):
    """Hybrid vector + structured search over a user's memory.

    Returns only an index of hits by default (progressive disclosure);
    set include_content=true to receive full observation content.
    """
    if not req.query.strip():
        raise HTTPException(400, "Query cannot be empty")
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    search = {
        "user_id": req.user_id,
        "query": req.query,
        "top_k": req.top_k,
        "obs_type": req.type,
        "since": req.since,
        "include_content": req.include_content,
    }
    return memory.query_memory(**search)
|
||||
|
||||
|
||||
@app.post("/memory/get")
async def get_observations(req: GetObservationsRequest):
    """Fetch full observation records for the given IDs."""
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")
    if not req.ids:
        raise HTTPException(400, "ids list cannot be empty")

    return memory.get_observations(user_id=req.user_id, ids=req.ids)
|
||||
|
||||
|
||||
@app.post("/memory/timeline")
async def get_timeline(req: TimelineRequest):
    """Return chronological context around an observation or point in time."""
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    window = dict(
        around_id=req.around_id,
        around_time=req.around_time,
        window_minutes=req.window_minutes,
        limit=req.limit,
    )
    return memory.get_timeline(user_id=req.user_id, **window)
|
||||
|
||||
|
||||
@app.post("/memory/preference")
async def save_preference(req: PreferenceRequest):
    """Save or update a user preference.

    Raises HTTP 400 when user_id or key is blank — the key check brings this
    endpoint in line with the validation done by the other memory endpoints.
    """
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")
    if not req.key.strip():
        raise HTTPException(400, "key is required")

    return memory.save_preference(
        user_id=req.user_id,
        key=req.key,
        value=req.value,
    )
|
||||
|
||||
|
||||
@app.get("/memory/preferences/{user_id}")
async def get_preferences(user_id: str):
    """Return every stored preference for the given user."""
    prefs = memory.get_preferences(user_id)
    return prefs
|
||||
|
||||
|
||||
@app.get("/memory/stats/{user_id}")
async def get_memory_stats(user_id: str):
    """Return memory usage statistics for the given user."""
    stats = memory.get_stats(user_id)
    return stats
|
||||
Reference in New Issue
Block a user