#!/usr/bin/env python3
"""
Moxie RAG Service — FastAPI application.

Multi-collection support for tenant isolation.
"""

import os
import sys
import shutil
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional

import httpx
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from rag_engine import RAGEngine
from document_processor import extract_text_from_file, extract_audio_from_video

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
LOG_DIR = Path(os.environ.get("LOG_DIR", "/app/logs"))
LOG_DIR.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.FileHandler(LOG_DIR / "rag_service.log"),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger("moxie-rag")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
WHISPER_URL = os.environ.get("WHISPER_URL", "http://host.docker.internal:8081/transcribe")
UPLOAD_DIR = Path(os.environ.get("UPLOAD_DIR", "/app/data/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------
engine = RAGEngine(data_dir=os.environ.get("CHROMA_DIR", "/app/data/chromadb"))

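# Every path and URL above can be overridden through the environment. A
# minimal sketch of a local run with custom locations (the values below are
# hypothetical, not defaults defined elsewhere):
#
#   LOG_DIR=/tmp/moxie-logs \
#   UPLOAD_DIR=/tmp/moxie-uploads \
#   CHROMA_DIR=/tmp/moxie-chroma \
#   WHISPER_URL=http://localhost:8081/transcribe \
#   python main.py
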
# ---------------------------------------------------------------------------
# FastAPI
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Moxie RAG Service",
    description="Multi-tenant RAG system for document storage and retrieval",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

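# NOTE: wildcard CORS ("*") is convenient on a private network or behind a
# reverse proxy; tighten allow_origins before exposing this service publicly.
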
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class IngestRequest(BaseModel):
    content: str
    title: Optional[str] = None
    source: Optional[str] = None
    date: Optional[str] = None
    doc_type: Optional[str] = "text"
    auto_chunk: bool = True
    collection: Optional[str] = None


class QueryRequest(BaseModel):
    question: str
    top_k: int = 5
    filter_type: Optional[str] = None
    collection: Optional[str] = None

# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    return {
        "service": "Moxie RAG",
        "version": "2.0.0",
        "device": engine.device,
        "model": engine.model_name,
        "collections": engine.list_collections(),
    }


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "device": engine.device,
        "collections": engine.list_collections(),
    }


@app.get("/collections")
async def list_collections():
    """List all collections and their chunk counts."""
    return {"collections": engine.list_collections()}

@app.post("/ingest")
async def ingest_text(req: IngestRequest):
    """Ingest text content into the vector store."""
    if not req.content.strip():
        raise HTTPException(400, "Content cannot be empty")
    try:
        return engine.ingest(
            content=req.content,
            title=req.title or "Untitled",
            source=req.source or "unknown",
            date=req.date,
            doc_type=req.doc_type or "text",
            auto_chunk=req.auto_chunk,
            collection=req.collection,
        )
    except ValueError as exc:
        raise HTTPException(400, str(exc))

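# A minimal client sketch for /ingest (assumes the service is reachable on
# localhost:8899; the title/source/collection values are illustrative):
#
#   import httpx
#   resp = httpx.post("http://localhost:8899/ingest", json={
#       "content": "Quarterly planning notes ...",
#       "title": "Q3 planning",
#       "source": "notes",
#       "collection": "work",
#   })
#   print(resp.json())  # per RAGEngine.ingest usage elsewhere in this file,
#                       # expected to include doc_id and chunks_created
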
@app.post("/ingest-file")
async def ingest_file(
    file: UploadFile = File(...),
    title: Optional[str] = Form(None),
    source: Optional[str] = Form(None),
    date: Optional[str] = Form(None),
    doc_type: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
):
    """Upload and ingest a document (PDF, DOCX, TXT, MD, XLSX, XLS, CSV)."""
    suffix = Path(file.filename).suffix.lower()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        text = extract_text_from_file(tmp_path, file.filename)
        if not text.strip():
            raise HTTPException(400, "Could not extract text from file")

        # Keep a copy
        dest = UPLOAD_DIR / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{file.filename}"
        shutil.copy2(tmp_path, dest)

        return engine.ingest(
            content=text,
            title=title or file.filename,
            source=source or f"file:{file.filename}",
            date=date,
            doc_type=doc_type or suffix.lstrip("."),
            collection=collection,
        )
    except ValueError as exc:
        raise HTTPException(400, str(exc))
    finally:
        os.unlink(tmp_path)

@app.post("/query")
async def query(req: QueryRequest):
    """Semantic search over indexed documents."""
    if not req.question.strip():
        raise HTTPException(400, "Question cannot be empty")
    return engine.query(
        question=req.question,
        top_k=req.top_k,
        filter_type=req.filter_type,
        collection=req.collection,
    )

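# A minimal /query sketch (request shape from QueryRequest above; the result
# structure is whatever RAGEngine.query returns):
#
#   import httpx
#   hits = httpx.post("http://localhost:8899/query", json={
#       "question": "What did we decide about the Q3 roadmap?",
#       "top_k": 3,
#       "collection": "work",
#   }).json()
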
@app.get("/documents")
async def list_documents(collection: Optional[str] = Query(None)):
    """List all indexed documents."""
    return engine.list_documents(collection=collection)


@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str, collection: Optional[str] = Query(None)):
    """Delete a document and all its chunks."""
    try:
        return engine.delete_document(doc_id, collection=collection)
    except KeyError as exc:
        raise HTTPException(404, str(exc))

@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    auto_ingest: bool = Form(False),
    title: Optional[str] = Form(None),
    source: Optional[str] = Form(None),
    language: Optional[str] = Form(None),  # accepted but not currently forwarded to Whisper
    collection: Optional[str] = Form(None),
):
    """Transcribe audio/video via Whisper, optionally auto-ingest the result."""
    suffix = Path(file.filename).suffix.lower()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    audio_path = None
    try:
        # If video, extract audio first
        video_exts = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv"}
        send_path = tmp_path
        if suffix in video_exts:
            logger.info(f"Extracting audio from video: {file.filename}")
            audio_path = extract_audio_from_video(tmp_path)
            send_path = audio_path

        async with httpx.AsyncClient(timeout=600.0) as client:
            with open(send_path, "rb") as audio_file:
                send_name = file.filename if suffix not in video_exts else Path(file.filename).stem + ".wav"
                files = {"file": (send_name, audio_file)}
                resp = await client.post(WHISPER_URL, files=files)

        if resp.status_code != 200:
            raise HTTPException(502, f"Whisper error: {resp.status_code} — {resp.text}")

        result = resp.json()
        transcription = result.get("text", result.get("transcription", ""))

        if not transcription.strip():
            raise HTTPException(400, "Transcription returned empty text")

        response = {
            "filename": file.filename,
            "transcription": transcription,
            "word_count": len(transcription.split()),
        }

        if auto_ingest:
            ingest_result = engine.ingest(
                content=transcription,
                title=title or f"Transcription: {file.filename}",
                source=source or f"audio:{file.filename}",
                doc_type="transcription",
                collection=collection,
            )
            response["ingested"] = True
            response["doc_id"] = ingest_result["doc_id"]
            response["chunks_created"] = ingest_result["chunks_created"]
            response["collection"] = ingest_result["collection"]
        else:
            response["ingested"] = False

        logger.info(f"Transcribed '{file.filename}' ({response['word_count']} words)")
        return response

    finally:
        os.unlink(tmp_path)
        if audio_path and os.path.exists(audio_path):
            os.unlink(audio_path)

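# A curl sketch for /transcribe with auto-ingest (multipart form fields match
# the Form(...) parameters above; the filename is illustrative):
#
#   curl -X POST http://localhost:8899/transcribe \
#       -F "file=@standup.mp4" \
#       -F "auto_ingest=true" \
#       -F "collection=work"
#
# Containers in video_exts are demuxed to audio first; everything else is
# forwarded to the Whisper service as-is.
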
# ---------------------------------------------------------------------------
# Memory Store Integration
# ---------------------------------------------------------------------------
from memory_store import MemoryStore

MEMORY_DB = os.environ.get("MEMORY_DB", "/app/data/memory.db")
memory = MemoryStore(db_path=MEMORY_DB, rag_engine=engine)
logger.info(f"Memory store initialized: {MEMORY_DB}")

# ---------------------------------------------------------------------------
# Memory Request Models
# ---------------------------------------------------------------------------
class SaveObservationRequest(BaseModel):
    user_id: str
    content: str
    type: str = "general"
    title: Optional[str] = None
    session_id: Optional[str] = None
    tool_name: Optional[str] = None
    importance: int = 1
    tags: Optional[list] = None
    metadata: Optional[dict] = None


class QueryMemoryRequest(BaseModel):
    user_id: str
    query: str
    top_k: int = 10
    type: Optional[str] = None
    since: Optional[str] = None
    include_content: bool = False


class GetObservationsRequest(BaseModel):
    user_id: str
    ids: list


class TimelineRequest(BaseModel):
    user_id: str
    around_id: Optional[int] = None
    around_time: Optional[str] = None
    window_minutes: int = 30
    limit: int = 20


class PreferenceRequest(BaseModel):
    user_id: str
    key: str
    value: str

# ---------------------------------------------------------------------------
# Memory Endpoints
# ---------------------------------------------------------------------------
@app.post("/memory/save")
async def save_observation(req: SaveObservationRequest):
    """Save an observation to memory (SQLite + vector embedding)."""
    if not req.content.strip():
        raise HTTPException(400, "Content cannot be empty")
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    return memory.save_observation(
        user_id=req.user_id,
        content=req.content,
        obs_type=req.type,
        title=req.title,
        session_id=req.session_id,
        tool_name=req.tool_name,
        importance=req.importance,
        tags=req.tags,
        metadata=req.metadata,
    )

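# A minimal /memory/save sketch (field names from SaveObservationRequest;
# the user_id, type, and tags values are illustrative):
#
#   import httpx
#   httpx.post("http://localhost:8899/memory/save", json={
#       "user_id": "alice",
#       "content": "Prefers short answers with code examples.",
#       "type": "preference",
#       "importance": 3,
#       "tags": ["style"],
#   })
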
@app.post("/memory/query")
async def query_memory(req: QueryMemoryRequest):
    """
    Search memory using hybrid vector + structured search.

    Returns an index by default (progressive disclosure).
    Set include_content=true for full observation content.
    """
    if not req.query.strip():
        raise HTTPException(400, "Query cannot be empty")
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    return memory.query_memory(
        user_id=req.user_id,
        query=req.query,
        top_k=req.top_k,
        obs_type=req.type,
        since=req.since,
        include_content=req.include_content,
    )

@app.post("/memory/get")
async def get_observations(req: GetObservationsRequest):
    """Fetch full observation details by IDs."""
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")
    if not req.ids:
        raise HTTPException(400, "ids list cannot be empty")

    return memory.get_observations(
        user_id=req.user_id,
        ids=req.ids,
    )

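# The progressive-disclosure flow pairs /memory/query and /memory/get: query
# first for a compact, token-cheap index, then fetch full content only for
# the IDs that matter. A sketch (user_id and IDs are illustrative; the exact
# index shape is defined by MemoryStore.query_memory):
#
#   import httpx
#   base = "http://localhost:8899"
#   index = httpx.post(f"{base}/memory/query", json={
#       "user_id": "alice",
#       "query": "deployment issues last week",
#       "top_k": 5,
#   }).json()
#   # ...inspect the index, pick the relevant observation IDs...
#   full = httpx.post(f"{base}/memory/get", json={
#       "user_id": "alice",
#       "ids": [12, 47],
#   }).json()
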
@app.post("/memory/timeline")
async def get_timeline(req: TimelineRequest):
    """Get chronological context around a specific observation or time."""
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    return memory.get_timeline(
        user_id=req.user_id,
        around_id=req.around_id,
        around_time=req.around_time,
        window_minutes=req.window_minutes,
        limit=req.limit,
    )

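# A timeline sketch: pull chronological context around one observation, with
# the default 30-minute window (the ID is illustrative; semantics are per
# MemoryStore.get_timeline):
#
#   import httpx
#   httpx.post("http://localhost:8899/memory/timeline", json={
#       "user_id": "alice",
#       "around_id": 47,
#   })
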
@app.post("/memory/preference")
async def save_preference(req: PreferenceRequest):
    """Save or update a user preference."""
    if not req.user_id.strip():
        raise HTTPException(400, "user_id is required")

    return memory.save_preference(
        user_id=req.user_id,
        key=req.key,
        value=req.value,
    )

@app.get("/memory/preferences/{user_id}")
async def get_preferences(user_id: str):
    """Get all preferences for a user."""
    return memory.get_preferences(user_id)


@app.get("/memory/stats/{user_id}")
async def get_memory_stats(user_id: str):
    """Get memory statistics for a user."""
    return memory.get_stats(user_id)


# ---------------------------------------------------------------------------
# Entrypoint (kept last so every route above is registered before serving,
# and so `python main.py` does not re-import the module mid-definition)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8899, log_level="info")