add websocket request/message schemas

This commit is contained in:
Timothy Pogue
2022-09-07 15:08:01 -06:00
parent 8bfaf0a998
commit 5934495dda
5 changed files with 165 additions and 82 deletions

View File

@@ -8,6 +8,8 @@ from starlette.websockets import WebSocketState
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC
@@ -18,6 +20,9 @@ router = APIRouter()
async def is_websocket_alive(ws: WebSocket) -> bool:
"""
Check if a FastAPI Websocket is still open
"""
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
@@ -31,7 +36,17 @@ async def _process_consumer_request(
channel: WebSocketChannel,
rpc: RPC
):
type, data = request.get('type'), request.get('data')
"""
Validate and handle a request from a websocket consumer
"""
# Validate the request, makes sure it matches the schema
try:
websocket_request = WSRequestSchema.parse_obj(request)
except ValidationError as e:
logger.error(f"Invalid request from {channel}: {e}")
return
type, data = websocket_request.type, websocket_request.data
logger.debug(f"Request of type {type} from {channel}")
@@ -41,35 +56,35 @@ async def _process_consumer_request(
if not data:
return
if not isinstance(data, list):
logger.error(f"Improper subscribe request from channel: {channel} - {request}")
return
# If all topics passed are a valid RPCMessageType, set subscriptions on channel
if all([any(x.value == topic for x in RPCMessageType) for topic in data]):
logger.debug(f"{channel} subscribed to topics: {data}")
channel.set_subscriptions(data)
# We don't send a response for subscriptions
elif type == RPCRequestType.WHITELIST:
# They requested the whitelist
# Get whitelist
whitelist = rpc._ws_request_whitelist()
await channel.send({"type": RPCMessageType.WHITELIST, "data": whitelist})
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
if data:
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = max(data.get('limit', 500), 1500)
limit = max(data.get('limit', 1500), 1500)
# They requested the full historical analyzed dataframes
analyzed_df = rpc._ws_request_analyzed_df(limit)
# For every dataframe, send as a separate message
for _, message in analyzed_df.items():
await channel.send({"type": RPCMessageType.ANALYZED_DF, "data": message})
response = WSAnalyzedDFMessage(data=message)
await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws")
@@ -78,6 +93,9 @@ async def message_endpoint(
rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager),
):
"""
Message WebSocket endpoint, facilitates sending RPC messages
"""
try:
if is_websocket_alive(ws):
# TODO: