feat(06-01): WebSocket endpoint, chat REST API, orchestrator wiring, gateway mounting
- Create gateway/channels/web.py with normalize_web_event() and /chat/ws/{conversation_id}
WebSocket endpoint (auth via first JSON message, typing indicator, Redis pub-sub response)
- Create shared/api/chat.py with GET/POST/DELETE /api/portal/chat/conversations* REST API
with require_tenant_member RBAC enforcement and RLS context var setup
- Add chat_router to shared/api/__init__.py exports
- Mount chat_router and web_chat_router in gateway/main.py (Phase 6 Web Chat routers)
- All 19 unit tests pass; full 313-test suite green
This commit is contained in:
340
packages/gateway/gateway/channels/web.py
Normal file
340
packages/gateway/gateway/channels/web.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Web Channel Adapter — WebSocket endpoint and message normalizer.
|
||||
|
||||
The web channel lets portal users chat with AI employees directly from
|
||||
the Konstruct portal UI. Messages flow through the same agent pipeline
|
||||
as Slack and WhatsApp — the only difference is the transport layer.
|
||||
|
||||
Message flow:
|
||||
1. Browser opens WebSocket at /chat/ws/{conversation_id}
|
||||
2. Client sends {"type": "auth", "userId": ..., "role": ..., "tenantId": ...}
|
||||
NOTE: Browsers cannot set custom HTTP headers on WebSocket connections,
|
||||
so auth credentials are sent as the first JSON message (Pitfall 1).
|
||||
3. For each user message (type="message"):
|
||||
a. Server immediately sends {"type": "typing"} to client (CHAT-05)
|
||||
b. normalize_web_event() converts to KonstructMessage (channel=WEB)
|
||||
c. User message saved to web_conversation_messages
|
||||
d. handle_message.delay(msg | extras) dispatches to Celery pipeline
|
||||
e. Server subscribes to Redis pub-sub channel for the response
|
||||
f. When orchestrator publishes the response:
|
||||
- Save assistant message to web_conversation_messages
|
||||
- Send {"type": "response", "text": ..., "conversation_id": ...} to client
|
||||
4. On disconnect: unsubscribe and close all Redis connections
|
||||
|
||||
Design notes:
|
||||
- thread_id = conversation_id — scopes agent memory to one conversation (Pitfall 3)
|
||||
- Redis pub-sub connections closed in try/finally to prevent leaks (Pitfall 2)
|
||||
- DB access uses configure_rls_hook + current_tenant_id context var per project pattern
|
||||
- WebSocket is a long-lived connection; each message/response cycle is synchronous
|
||||
within the connection but non-blocking for other connections
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from orchestrator.tasks import handle_message
|
||||
from shared.config import settings
|
||||
from shared.db import async_session_factory, engine
|
||||
from shared.models.chat import WebConversation, WebConversationMessage
|
||||
from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo
|
||||
from shared.redis_keys import webchat_response_key
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router — mounted in gateway/main.py
|
||||
# ---------------------------------------------------------------------------
|
||||
web_chat_router = APIRouter(tags=["web-chat"])
|
||||
|
||||
# Timeout for waiting for an agent response via Redis pub-sub (seconds)
|
||||
_RESPONSE_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
def normalize_web_event(event: dict[str, Any]) -> KonstructMessage:
|
||||
"""
|
||||
Normalize a web channel event dict into a KonstructMessage.
|
||||
|
||||
The web channel normalizer sets thread_id = conversation_id so that
|
||||
the agent memory pipeline scopes context to this conversation (Pitfall 3).
|
||||
|
||||
Args:
|
||||
event: Dict with keys: text, tenant_id, agent_id, user_id,
|
||||
display_name, conversation_id.
|
||||
|
||||
Returns:
|
||||
KonstructMessage with channel=WEB, thread_id=conversation_id.
|
||||
"""
|
||||
tenant_id: str = event.get("tenant_id", "") or ""
|
||||
user_id: str = event.get("user_id", "") or ""
|
||||
display_name: str = event.get("display_name", "Portal User")
|
||||
conversation_id: str = event.get("conversation_id", "") or ""
|
||||
text_content: str = event.get("text", "") or ""
|
||||
|
||||
return KonstructMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant_id,
|
||||
channel=ChannelType.WEB,
|
||||
channel_metadata={
|
||||
"portal_user_id": user_id,
|
||||
"tenant_id": tenant_id,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
sender=SenderInfo(
|
||||
user_id=user_id,
|
||||
display_name=display_name,
|
||||
),
|
||||
content=MessageContent(
|
||||
text=text_content,
|
||||
),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
thread_id=conversation_id,
|
||||
reply_to=None,
|
||||
context={},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_websocket_connection(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Core WebSocket connection handler — separated for testability.
|
||||
|
||||
Lifecycle:
|
||||
1. Accept connection
|
||||
2. Wait for auth message (browser cannot send custom headers)
|
||||
3. Loop: receive messages → type indicator → Celery dispatch → Redis subscribe → response
|
||||
|
||||
Args:
|
||||
websocket: The FastAPI WebSocket connection.
|
||||
conversation_id: The conversation UUID from the URL path.
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Step 1: Auth handshake
|
||||
# Browsers cannot send custom HTTP headers on WebSocket connections.
|
||||
# Auth credentials are sent as the first JSON message.
|
||||
# -------------------------------------------------------------------------
|
||||
try:
|
||||
auth_msg = await websocket.receive_json()
|
||||
except WebSocketDisconnect:
|
||||
return
|
||||
|
||||
if auth_msg.get("type") != "auth":
|
||||
await websocket.send_json({"type": "error", "message": "First message must be auth"})
|
||||
await websocket.close(code=4001)
|
||||
return
|
||||
|
||||
user_id_str: str = auth_msg.get("userId", "") or ""
|
||||
user_role: str = auth_msg.get("role", "") or ""
|
||||
tenant_id_str: str = auth_msg.get("tenantId", "") or ""
|
||||
|
||||
if not user_id_str or not tenant_id_str:
|
||||
await websocket.send_json({"type": "error", "message": "Missing userId or tenantId in auth"})
|
||||
await websocket.close(code=4001)
|
||||
return
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(user_id_str)
|
||||
tenant_uuid = uuid.UUID(tenant_id_str)
|
||||
except (ValueError, AttributeError):
|
||||
await websocket.send_json({"type": "error", "message": "Invalid UUID format in auth"})
|
||||
await websocket.close(code=4001)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"WebSocket auth: user=%s role=%s tenant=%s conversation=%s",
|
||||
user_id_str, user_role, tenant_id_str, conversation_id,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Step 2: Message loop
|
||||
# -------------------------------------------------------------------------
|
||||
while True:
|
||||
try:
|
||||
msg_data = await websocket.receive_json()
|
||||
except (WebSocketDisconnect, Exception):
|
||||
break
|
||||
|
||||
if msg_data.get("type") != "message":
|
||||
continue
|
||||
|
||||
text_content: str = msg_data.get("text", "") or ""
|
||||
agent_id_str: str = msg_data.get("agentId", "") or ""
|
||||
msg_conversation_id: str = msg_data.get("conversationId", conversation_id) or conversation_id
|
||||
display_name: str = msg_data.get("displayName", "Portal User")
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# a. Send typing indicator IMMEDIATELY — before any DB or Celery work
|
||||
# -------------------------------------------------------------------
|
||||
await websocket.send_json({"type": "typing"})
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# b. Save user message to web_conversation_messages
|
||||
# -------------------------------------------------------------------
|
||||
configure_rls_hook(engine)
|
||||
rls_token = current_tenant_id.set(tenant_uuid)
|
||||
saved_conversation_id = msg_conversation_id
|
||||
|
||||
try:
|
||||
async with async_session_factory() as session:
|
||||
# Look up the conversation to get tenant-scoped context
|
||||
conv_stmt = select(WebConversation).where(
|
||||
WebConversation.id == uuid.UUID(msg_conversation_id)
|
||||
)
|
||||
conv_result = await session.execute(conv_stmt)
|
||||
conversation = conv_result.scalar_one_or_none()
|
||||
|
||||
if conversation is not None:
|
||||
# Save user message
|
||||
user_msg = WebConversationMessage(
|
||||
conversation_id=uuid.UUID(msg_conversation_id),
|
||||
tenant_id=tenant_uuid,
|
||||
role="user",
|
||||
content=text_content,
|
||||
)
|
||||
session.add(user_msg)
|
||||
|
||||
# Update conversation timestamp
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
|
||||
),
|
||||
{"conv_id": str(msg_conversation_id)},
|
||||
)
|
||||
await session.commit()
|
||||
saved_conversation_id = msg_conversation_id
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to save user message for conversation=%s", msg_conversation_id
|
||||
)
|
||||
finally:
|
||||
current_tenant_id.reset(rls_token)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# c. Normalize and dispatch to Celery pipeline
|
||||
# -------------------------------------------------------------------
|
||||
event = {
|
||||
"text": text_content,
|
||||
"tenant_id": tenant_id_str,
|
||||
"agent_id": agent_id_str,
|
||||
"user_id": user_id_str,
|
||||
"display_name": display_name,
|
||||
"conversation_id": saved_conversation_id,
|
||||
}
|
||||
normalized_msg = normalize_web_event(event)
|
||||
|
||||
extras = {
|
||||
"conversation_id": saved_conversation_id,
|
||||
"portal_user_id": user_id_str,
|
||||
}
|
||||
task_payload = normalized_msg.model_dump(mode="json") | extras
|
||||
handle_message.delay(task_payload)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# d. Subscribe to Redis pub-sub and wait for agent response
|
||||
# -------------------------------------------------------------------
|
||||
response_channel = webchat_response_key(tenant_id_str, saved_conversation_id)
|
||||
subscribe_redis = aioredis.from_url(settings.redis_url)
|
||||
try:
|
||||
pubsub = subscribe_redis.pubsub()
|
||||
await pubsub.subscribe(response_channel)
|
||||
|
||||
response_text: str = ""
|
||||
deadline = asyncio.get_event_loop().time() + _RESPONSE_TIMEOUT_SECONDS
|
||||
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
if message and message.get("type") == "message":
|
||||
try:
|
||||
payload = json.loads(message["data"])
|
||||
response_text = payload.get("text", "")
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
await pubsub.unsubscribe(response_channel)
|
||||
finally:
|
||||
await subscribe_redis.aclose()
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# e. Save assistant message and send response to client
|
||||
# -------------------------------------------------------------------
|
||||
if response_text:
|
||||
rls_token2 = current_tenant_id.set(tenant_uuid)
|
||||
try:
|
||||
async with async_session_factory() as session:
|
||||
assistant_msg = WebConversationMessage(
|
||||
conversation_id=uuid.UUID(saved_conversation_id),
|
||||
tenant_id=tenant_uuid,
|
||||
role="assistant",
|
||||
content=response_text,
|
||||
)
|
||||
session.add(assistant_msg)
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"
|
||||
),
|
||||
{"conv_id": str(saved_conversation_id)},
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to save assistant message for conversation=%s", saved_conversation_id
|
||||
)
|
||||
finally:
|
||||
current_tenant_id.reset(rls_token2)
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "response",
|
||||
"text": response_text,
|
||||
"conversation_id": saved_conversation_id,
|
||||
})
|
||||
else:
|
||||
logger.warning(
|
||||
"No response received within %ds for conversation=%s",
|
||||
_RESPONSE_TIMEOUT_SECONDS,
|
||||
saved_conversation_id,
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Agent did not respond in time. Please try again.",
|
||||
})
|
||||
|
||||
|
||||
@web_chat_router.websocket("/chat/ws/{conversation_id}")
|
||||
async def chat_websocket(websocket: WebSocket, conversation_id: str) -> None:
|
||||
"""
|
||||
WebSocket endpoint for web chat.
|
||||
|
||||
URL: /chat/ws/{conversation_id}
|
||||
|
||||
Protocol:
|
||||
1. Connect
|
||||
2. Send: {"type": "auth", "userId": "...", "role": "...", "tenantId": "..."}
|
||||
3. Send: {"type": "message", "text": "...", "agentId": "...", "conversationId": "..."}
|
||||
4. Receive: {"type": "typing"}
|
||||
5. Receive: {"type": "response", "text": "...", "conversation_id": "..."}
|
||||
|
||||
Closes with code 4001 on auth failure.
|
||||
"""
|
||||
try:
|
||||
await _handle_websocket_connection(websocket, conversation_id)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected for conversation=%s", conversation_id)
|
||||
except Exception:
|
||||
logger.exception("Unhandled error in WebSocket handler for conversation=%s", conversation_id)
|
||||
@@ -39,10 +39,12 @@ from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
|
||||
from gateway.channels.slack import register_slack_handlers
|
||||
from gateway.channels.web import web_chat_router
|
||||
from gateway.channels.whatsapp import whatsapp_router
|
||||
from shared.api import (
|
||||
billing_router,
|
||||
channels_router,
|
||||
chat_router,
|
||||
invitations_router,
|
||||
llm_keys_router,
|
||||
portal_router,
|
||||
@@ -146,6 +148,12 @@ app.include_router(invitations_router)
|
||||
# ---------------------------------------------------------------------------
|
||||
app.include_router(templates_router)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 6 Web Chat routers
|
||||
# ---------------------------------------------------------------------------
|
||||
app.include_router(chat_router) # REST: /api/portal/chat/*
|
||||
app.include_router(web_chat_router) # WebSocket: /chat/ws/{conversation_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
|
||||
@@ -6,6 +6,7 @@ Import and mount these routers in service main.py files.
|
||||
|
||||
from shared.api.billing import billing_router, webhook_router
|
||||
from shared.api.channels import channels_router
|
||||
from shared.api.chat import chat_router
|
||||
from shared.api.invitations import invitations_router
|
||||
from shared.api.llm_keys import llm_keys_router
|
||||
from shared.api.portal import portal_router
|
||||
@@ -21,4 +22,5 @@ __all__ = [
|
||||
"usage_router",
|
||||
"invitations_router",
|
||||
"templates_router",
|
||||
"chat_router",
|
||||
]
|
||||
|
||||
356
packages/shared/shared/api/chat.py
Normal file
356
packages/shared/shared/api/chat.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
FastAPI chat REST API — conversation CRUD with RBAC.
|
||||
|
||||
Provides conversation management for the Phase 6 web chat feature.
|
||||
All endpoints require portal authentication via X-Portal-User-Id headers
|
||||
and enforce tenant membership (or platform_admin bypass).
|
||||
|
||||
Endpoints:
|
||||
GET /api/portal/chat/conversations — list conversations
|
||||
POST /api/portal/chat/conversations — create or get-or-create
|
||||
GET /api/portal/chat/conversations/{id}/messages — paginated history
|
||||
DELETE /api/portal/chat/conversations/{id} — reset conversation
|
||||
|
||||
RBAC:
|
||||
- platform_admin: can access any tenant's conversations
|
||||
- customer_admin / customer_operator: must be a member of the target tenant
|
||||
- Other roles: 403
|
||||
|
||||
RLS:
|
||||
All DB queries set current_tenant_id context var before executing so
|
||||
PostgreSQL's FORCE ROW LEVEL SECURITY policy is applied automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, select, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.api.rbac import PortalCaller, get_portal_caller, require_tenant_member
|
||||
from shared.db import get_session, engine
|
||||
from shared.models.chat import WebConversation, WebConversationMessage
|
||||
from shared.models.tenant import Agent
|
||||
from shared.rls import configure_rls_hook, current_tenant_id
|
||||
|
||||
chat_router = APIRouter(prefix="/api/portal/chat", tags=["chat"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConversationOut(BaseModel):
|
||||
id: str
|
||||
tenant_id: str
|
||||
agent_id: str
|
||||
agent_name: str | None = None
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_message_preview: str | None = None
|
||||
|
||||
|
||||
class ConversationCreate(BaseModel):
|
||||
tenant_id: uuid.UUID
|
||||
agent_id: uuid.UUID
|
||||
|
||||
|
||||
class MessageOut(BaseModel):
|
||||
id: str
|
||||
role: str
|
||||
content: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class DeleteResult(BaseModel):
|
||||
deleted: bool
|
||||
conversation_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: configure RLS and set context var
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _rls_set(engine_: Any, tenant_uuid: uuid.UUID) -> Any:
|
||||
"""Configure RLS hook and set the tenant context variable."""
|
||||
configure_rls_hook(engine_)
|
||||
return current_tenant_id.set(tenant_uuid)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/portal/chat/conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@chat_router.get("/conversations", response_model=list[ConversationOut])
|
||||
async def list_conversations(
|
||||
tenant_id: uuid.UUID = Query(...),
|
||||
caller: PortalCaller = Depends(get_portal_caller),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[ConversationOut]:
|
||||
"""
|
||||
List conversations for the authenticated user within a tenant.
|
||||
|
||||
Platform admins can see all conversations for the tenant.
|
||||
Other users see only their own conversations.
|
||||
"""
|
||||
# RBAC — raises 403 if caller is not a member (platform_admin bypasses)
|
||||
await require_tenant_member(tenant_id=tenant_id, caller=caller, session=session)
|
||||
|
||||
token = _rls_set(engine, tenant_id)
|
||||
try:
|
||||
stmt = (
|
||||
select(WebConversation, Agent.name.label("agent_name"))
|
||||
.join(Agent, WebConversation.agent_id == Agent.id, isouter=True)
|
||||
.where(WebConversation.tenant_id == tenant_id)
|
||||
)
|
||||
# Non-admins only see their own conversations
|
||||
if caller.role != "platform_admin":
|
||||
stmt = stmt.where(WebConversation.user_id == caller.user_id)
|
||||
|
||||
stmt = stmt.order_by(WebConversation.updated_at.desc())
|
||||
result = await session.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
conversations: list[ConversationOut] = []
|
||||
for row in rows:
|
||||
conv = row[0]
|
||||
agent_name = row[1] if len(row) > 1 else None
|
||||
conversations.append(
|
||||
ConversationOut(
|
||||
id=str(conv.id),
|
||||
tenant_id=str(conv.tenant_id),
|
||||
agent_id=str(conv.agent_id),
|
||||
agent_name=agent_name,
|
||||
user_id=str(conv.user_id),
|
||||
created_at=conv.created_at,
|
||||
updated_at=conv.updated_at,
|
||||
)
|
||||
)
|
||||
return conversations
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/portal/chat/conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@chat_router.post("/conversations", response_model=ConversationOut, status_code=status.HTTP_200_OK)
|
||||
async def create_conversation(
|
||||
body: ConversationCreate,
|
||||
caller: PortalCaller = Depends(get_portal_caller),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ConversationOut:
|
||||
"""
|
||||
Create or get an existing conversation for the caller + agent pair.
|
||||
|
||||
Uses get-or-create semantics: if a conversation already exists for this
|
||||
(tenant_id, agent_id, user_id) triple, it is returned rather than creating
|
||||
a duplicate.
|
||||
"""
|
||||
# RBAC
|
||||
await require_tenant_member(tenant_id=body.tenant_id, caller=caller, session=session)
|
||||
|
||||
token = _rls_set(engine, body.tenant_id)
|
||||
try:
|
||||
# Check for existing conversation
|
||||
existing_stmt = select(WebConversation).where(
|
||||
WebConversation.tenant_id == body.tenant_id,
|
||||
WebConversation.agent_id == body.agent_id,
|
||||
WebConversation.user_id == caller.user_id,
|
||||
)
|
||||
existing_result = await session.execute(existing_stmt)
|
||||
existing = existing_result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
return ConversationOut(
|
||||
id=str(existing.id),
|
||||
tenant_id=str(existing.tenant_id),
|
||||
agent_id=str(existing.agent_id),
|
||||
user_id=str(existing.user_id),
|
||||
created_at=existing.created_at,
|
||||
updated_at=existing.updated_at,
|
||||
)
|
||||
|
||||
# Create new conversation
|
||||
new_conv = WebConversation(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=body.tenant_id,
|
||||
agent_id=body.agent_id,
|
||||
user_id=caller.user_id,
|
||||
)
|
||||
session.add(new_conv)
|
||||
try:
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
await session.refresh(new_conv)
|
||||
except IntegrityError:
|
||||
# Race condition: another request created it between our SELECT and INSERT
|
||||
await session.rollback()
|
||||
existing_result2 = await session.execute(existing_stmt)
|
||||
existing2 = existing_result2.scalar_one_or_none()
|
||||
if existing2 is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create conversation",
|
||||
)
|
||||
return ConversationOut(
|
||||
id=str(existing2.id),
|
||||
tenant_id=str(existing2.tenant_id),
|
||||
agent_id=str(existing2.agent_id),
|
||||
user_id=str(existing2.user_id),
|
||||
created_at=existing2.created_at,
|
||||
updated_at=existing2.updated_at,
|
||||
)
|
||||
|
||||
return ConversationOut(
|
||||
id=str(new_conv.id),
|
||||
tenant_id=str(new_conv.tenant_id),
|
||||
agent_id=str(new_conv.agent_id),
|
||||
user_id=str(new_conv.user_id),
|
||||
created_at=new_conv.created_at,
|
||||
updated_at=new_conv.updated_at,
|
||||
)
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/portal/chat/conversations/{id}/messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@chat_router.get("/conversations/{conversation_id}/messages", response_model=list[MessageOut])
|
||||
async def list_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before: str | None = Query(default=None),
|
||||
caller: PortalCaller = Depends(get_portal_caller),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[MessageOut]:
|
||||
"""
|
||||
Return paginated message history for a conversation.
|
||||
|
||||
Messages ordered by created_at ASC (oldest first).
|
||||
Cursor pagination via `before` parameter (message ID).
|
||||
|
||||
Ownership enforced: caller must own the conversation OR be platform_admin.
|
||||
"""
|
||||
# Fetch conversation first to verify ownership and get tenant_id
|
||||
conv_stmt = select(WebConversation).where(WebConversation.id == conversation_id)
|
||||
conv_result = await session.execute(conv_stmt)
|
||||
conversation = conv_result.scalar_one_or_none()
|
||||
|
||||
if conversation is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
|
||||
|
||||
# Ownership check: caller owns the conversation or is platform_admin
|
||||
if caller.role != "platform_admin" and conversation.user_id != caller.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have access to this conversation",
|
||||
)
|
||||
|
||||
token = _rls_set(engine, conversation.tenant_id)
|
||||
try:
|
||||
msg_stmt = (
|
||||
select(WebConversationMessage)
|
||||
.where(WebConversationMessage.conversation_id == conversation_id)
|
||||
.order_by(WebConversationMessage.created_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if before:
|
||||
try:
|
||||
before_uuid = uuid.UUID(before)
|
||||
# Get the cursor message's created_at
|
||||
cursor_stmt = select(WebConversationMessage.created_at).where(
|
||||
WebConversationMessage.id == before_uuid
|
||||
)
|
||||
cursor_result = await session.execute(cursor_stmt)
|
||||
cursor_ts = cursor_result.scalar_one_or_none()
|
||||
if cursor_ts is not None:
|
||||
msg_stmt = msg_stmt.where(WebConversationMessage.created_at < cursor_ts)
|
||||
except (ValueError, AttributeError):
|
||||
pass # Invalid cursor — ignore and return from start
|
||||
|
||||
msg_result = await session.execute(msg_stmt)
|
||||
messages = msg_result.scalars().all()
|
||||
|
||||
return [
|
||||
MessageOut(
|
||||
id=str(m.id),
|
||||
role=m.role,
|
||||
content=m.content,
|
||||
created_at=m.created_at,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/portal/chat/conversations/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@chat_router.delete("/conversations/{conversation_id}", response_model=DeleteResult)
|
||||
async def reset_conversation(
|
||||
conversation_id: uuid.UUID,
|
||||
caller: PortalCaller = Depends(get_portal_caller),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> DeleteResult:
|
||||
"""
|
||||
Reset a conversation by deleting all messages.
|
||||
|
||||
The conversation row is kept but all messages are deleted.
|
||||
Updates updated_at on the conversation.
|
||||
|
||||
Ownership enforced: caller must own the conversation OR be platform_admin.
|
||||
"""
|
||||
# Fetch conversation
|
||||
conv_stmt = select(WebConversation).where(WebConversation.id == conversation_id)
|
||||
conv_result = await session.execute(conv_stmt)
|
||||
conversation = conv_result.scalar_one_or_none()
|
||||
|
||||
if conversation is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
|
||||
|
||||
# Ownership check
|
||||
if caller.role != "platform_admin" and conversation.user_id != caller.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have access to this conversation",
|
||||
)
|
||||
|
||||
token = _rls_set(engine, conversation.tenant_id)
|
||||
try:
|
||||
# Delete all messages for this conversation
|
||||
delete_stmt = delete(WebConversationMessage).where(
|
||||
WebConversationMessage.conversation_id == conversation_id
|
||||
)
|
||||
await session.execute(delete_stmt)
|
||||
|
||||
# Update conversation timestamp
|
||||
await session.execute(
|
||||
text("UPDATE web_conversations SET updated_at = NOW() WHERE id = :conv_id"),
|
||||
{"conv_id": str(conversation_id)},
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return DeleteResult(deleted=True, conversation_id=str(conversation_id))
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
Reference in New Issue
Block a user