From 5934495dda06c4c62950c8eebe77fb431d394eb9 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Wed, 7 Sep 2022 15:08:01 -0600 Subject: [PATCH] add websocket request/message schemas --- freqtrade/data/dataprovider.py | 10 +- freqtrade/rpc/api_server/api_ws.py | 40 ++++++-- freqtrade/rpc/api_server/ws/schema.py | 78 +++++++++++++++ freqtrade/rpc/external_message_consumer.py | 111 +++++++++------------ freqtrade/rpc/rpc.py | 8 +- 5 files changed, 165 insertions(+), 82 deletions(-) create mode 100644 freqtrade/rpc/api_server/ws/schema.py diff --git a/freqtrade/data/dataprovider.py b/freqtrade/data/dataprovider.py index 44296ab40..4b5494e97 100644 --- a/freqtrade/data/dataprovider.py +++ b/freqtrade/data/dataprovider.py @@ -121,7 +121,8 @@ class DataProvider: 'type': RPCMessageType.ANALYZED_DF, 'data': { 'key': pair_key, - 'value': (dataframe, datetime.now(timezone.utc)) + 'df': dataframe, + 'la': datetime.now(timezone.utc) } } ) @@ -130,7 +131,7 @@ class DataProvider: self, pair: str, dataframe: DataFrame, - last_analyzed: Optional[str] = None, + last_analyzed: Optional[datetime] = None, timeframe: Optional[str] = None, candle_type: Optional[CandleType] = None, producer_name: str = "default" @@ -150,10 +151,7 @@ class DataProvider: if producer_name not in self.__producer_pairs_df: self.__producer_pairs_df[producer_name] = {} - if not last_analyzed: - _last_analyzed = datetime.now(timezone.utc) - else: - _last_analyzed = datetime.fromisoformat(last_analyzed) + _last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed) logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.") diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index aaa526401..64c1cebb5 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -8,6 +8,8 @@ from starlette.websockets import WebSocketState from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel +from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, + WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC @@ -18,6 +20,9 @@ router = APIRouter() async def is_websocket_alive(ws: WebSocket) -> bool: + """ + Check if a FastAPI Websocket is still open + """ if ( ws.application_state == WebSocketState.CONNECTED and ws.client_state == WebSocketState.CONNECTED @@ -31,7 +36,17 @@ async def _process_consumer_request( channel: WebSocketChannel, rpc: RPC ): - type, data = request.get('type'), request.get('data') + """ + Validate and handle a request from a websocket consumer + """ + # Validate the request, makes sure it matches the schema + try: + websocket_request = WSRequestSchema.parse_obj(request) + except ValidationError as e: + logger.error(f"Invalid request from {channel}: {e}") + return + + type, data = websocket_request.type, websocket_request.data logger.debug(f"Request of type {type} from {channel}") @@ -41,35 +56,35 @@ async def _process_consumer_request( if not data: return - if not isinstance(data, list): - logger.error(f"Improper subscribe request from channel: {channel} - {request}") - return - # If all topics passed are a valid RPCMessageType, set subscriptions on channel if all([any(x.value == topic for x in RPCMessageType) for topic in data]): - - logger.debug(f"{channel} subscribed to topics: {data}") channel.set_subscriptions(data) + # We don't send a response for subscriptions + elif type == RPCRequestType.WHITELIST: - # They requested the whitelist + # Get whitelist whitelist = rpc._ws_request_whitelist() - await channel.send({"type": RPCMessageType.WHITELIST, "data": whitelist}) + # Format response + response = WSWhitelistMessage(data=whitelist) + # Send it back + await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: limit = None if data: # Limit the amount of candles per dataframe to 'limit' or 1500 - limit = max(data.get('limit', 500), 1500) + limit = max(data.get('limit', 1500), 1500) # They requested the full historical analyzed dataframes analyzed_df = rpc._ws_request_analyzed_df(limit) # For every dataframe, send as a separate message for _, message in analyzed_df.items(): - await channel.send({"type": RPCMessageType.ANALYZED_DF, "data": message}) + response = WSAnalyzedDFMessage(data=message) + await channel.send(response.dict(exclude_none=True)) @router.websocket("/message/ws") @@ -78,6 +93,9 @@ async def message_endpoint( rpc: RPC = Depends(get_rpc), channel_manager=Depends(get_channel_manager), ): + """ + Message WebSocket endpoint, facilitates sending RPC messages + """ try: if is_websocket_alive(ws): # TODO: diff --git a/freqtrade/rpc/api_server/ws/schema.py b/freqtrade/rpc/api_server/ws/schema.py new file mode 100644 index 000000000..3221911de --- /dev/null +++ b/freqtrade/rpc/api_server/ws/schema.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pandas import DataFrame +from pydantic import BaseModel, ValidationError + +from freqtrade.constants import PairWithTimeframe +from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType + + +__all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError') + + +class BaseArbitraryModel(BaseModel): + class Config: + arbitrary_types_allowed = True + + +class WSRequestSchema(BaseArbitraryModel): + type: RPCRequestType + data: Optional[Any] = None + + +class WSMessageSchema(BaseArbitraryModel): + type: RPCMessageType + data: Optional[Any] = None + + class Config: + extra = 'allow' + + +# ------------------------------ REQUEST SCHEMAS ---------------------------- + + +class WSSubscribeRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.SUBSCRIBE + data: List[RPCMessageType] + + +class WSWhitelistRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.WHITELIST + data: None = None + + +class WSAnalyzedDFRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.ANALYZED_DF + data: Dict[str, Any] = {"limit": 1500} + + +# ------------------------------ MESSAGE SCHEMAS ---------------------------- + +class WSWhitelistMessage(WSMessageSchema): + type: RPCMessageType = RPCMessageType.WHITELIST + data: List[str] + + +class WSAnalyzedDFMessage(WSMessageSchema): + class AnalyzedDFData(BaseArbitraryModel): + key: PairWithTimeframe + df: DataFrame + la: datetime + + type: RPCMessageType = RPCMessageType.ANALYZED_DF + data: AnalyzedDFData + +# -------------------------------------------------------------------------- + + +if __name__ == "__main__": + message = WSAnalyzedDFMessage( + data={ + "key": ("1", "5m", "spot"), + "df": DataFrame(), + "la": datetime.now() + } + ) + + print(message) diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index c1ad0512e..d1e970826 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -8,14 +8,18 @@ import asyncio import logging import socket from threading import Thread -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List import websockets from freqtrade.data.dataprovider import DataProvider -from freqtrade.enums import RPCMessageType, RPCRequestType +from freqtrade.enums import RPCMessageType from freqtrade.misc import remove_entry_exit_signals from freqtrade.rpc.api_server.ws.channel import WebSocketChannel +from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, + WSAnalyzedDFRequest, WSMessageSchema, + WSRequestSchema, WSSubscribeRequest, + WSWhitelistMessage, WSWhitelistRequest) if TYPE_CHECKING: @@ -67,15 +71,10 @@ class ExternalMessageConsumer: self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF] # Allow setting data for each initial request - self._initial_requests: List[Dict[str, Any]] = [ - { - "type": RPCRequestType.WHITELIST, - "data": None - }, - { - "type": RPCRequestType.ANALYZED_DF, - "data": {"limit": self.initial_candle_limit} - } + self._initial_requests: List[WSRequestSchema] = [ + WSSubscribeRequest(data=self.topics), + WSWhitelistRequest(), + WSAnalyzedDFRequest() ] # Specify which function to use for which RPCMessageType @@ -174,16 +173,10 @@ class ExternalMessageConsumer: logger.info(f"Producer connection success - {channel}") - # Tell the producer we only want these topics - # Should always be the first thing we send - await channel.send( - self.compose_consumer_request(RPCRequestType.SUBSCRIBE, self.topics) - ) - # Now request the initial data from this Producer for request in self._initial_requests: await channel.send( - self.compose_consumer_request(request['type'], request['data']) + request.dict(exclude_none=True) ) # Now receive data, if none is within the time limit, ping @@ -253,74 +246,66 @@ class ExternalMessageConsumer: break - def compose_consumer_request( - self, - type_: RPCRequestType, - data: Optional[Any] = None - ) -> Dict[str, Any]: - """ - Create a request for sending to a producer - - :param type_: The RPCRequestType - :param data: The data to send - :returns: Dict[str, Any] - """ - return {'type': type_, 'data': data} - def handle_producer_message(self, producer: Dict[str, Any], message: Dict[str, Any]): """ Handles external messages from a Producer """ producer_name = producer.get('name', 'default') - # Should we have a default message type? - message_type = message.get('type', RPCMessageType.STATUS) - message_data = message.get('data') + + try: + producer_message = WSMessageSchema.parse_obj(message) + except ValidationError as e: + logger.error(f"Invalid message from {producer_name}: {e}") + return # We shouldn't get empty messages - if message_data is None: + if producer_message.data is None: return - logger.info(f"Received message of type {message_type} from `{producer_name}`") + logger.info(f"Received message of type {producer_message.type} from `{producer_name}`") - message_handler = self._message_handlers.get(message_type) + message_handler = self._message_handlers.get(producer_message.type) if not message_handler: - logger.info(f"Received unhandled message: {message_data}, ignoring...") + logger.info(f"Received unhandled message: {producer_message.data}, ignoring...") return - message_handler(producer_name, message_data) + message_handler(producer_name, producer_message) - def _consume_whitelist_message(self, producer_name: str, message_data: Any): - # We expect List[str] - if not isinstance(message_data, list): + def _consume_whitelist_message(self, producer_name: str, message: Any): + try: + # Validate the message + message = WSWhitelistMessage.parse_obj(message) + except ValidationError: return # Add the pairlist data to the DataProvider - self._dp._set_producer_pairs(message_data, producer_name=producer_name) + self._dp._set_producer_pairs(message.data, producer_name=producer_name) - logger.debug(f"Consumed message from {producer_name} of type RPCMessageType.WHITELIST") + logger.debug(f"Consumed message from {producer_name} of type `RPCMessageType.WHITELIST`") - def _consume_analyzed_df_message(self, producer_name: str, message_data: Any): - # We expect a Dict[str, Any] - if not isinstance(message_data, dict): + def _consume_analyzed_df_message(self, producer_name: str, message: Any): + try: + message = WSAnalyzedDFMessage.parse_obj(message) + except ValidationError: return - key, value = message_data.get('key'), message_data.get('value') + key = message.data.key + df = message.data.df + la = message.data.la - if key and value: - pair, timeframe, candle_type = key - dataframe, last_analyzed = value + pair, timeframe, candle_type = key - # If set, remove the Entry and Exit signals from the Producer - if self._emc_config.get('remove_entry_exit_signals', False): - dataframe = remove_entry_exit_signals(dataframe) + # If set, remove the Entry and Exit signals from the Producer + if self._emc_config.get('remove_entry_exit_signals', False): + df = remove_entry_exit_signals(df) - # Add the dataframe to the dataprovider - self._dp._add_external_df(pair, dataframe, - last_analyzed=last_analyzed, - timeframe=timeframe, - candle_type=candle_type, - producer_name=producer_name) + # Add the dataframe to the dataprovider + self._dp._add_external_df(pair, df, + last_analyzed=la, + timeframe=timeframe, + candle_type=candle_type, + producer_name=producer_name) - logger.debug( - f"Consumed message from {producer_name} of type RPCMessageType.ANALYZED_DF") + logger.debug( + f"Consumed message from {producer_name} of type RPCMessageType.ANALYZED_DF") diff --git a/freqtrade/rpc/rpc.py b/freqtrade/rpc/rpc.py index df90b982e..9821bc001 100644 --- a/freqtrade/rpc/rpc.py +++ b/freqtrade/rpc/rpc.py @@ -1068,8 +1068,12 @@ class RPC: for pair in pairlist: dataframe, last_analyzed = self.__rpc_analysed_dataframe_raw(pair, timeframe, limit) - _data[pair] = {"key": (pair, timeframe, candle_type), - "value": (dataframe, last_analyzed)} + + _data[pair] = { + "key": (pair, timeframe, candle_type), + "df": dataframe, + "la": last_analyzed + } return _data