diff --git a/config_examples/config_full.example.json b/config_examples/config_full.example.json index 8155cb145..5a5096f81 100644 --- a/config_examples/config_full.example.json +++ b/config_examples/config_full.example.json @@ -172,7 +172,24 @@ "jwt_secret_key": "somethingrandom", "CORS_origins": [], "username": "freqtrader", - "password": "SuperSecurePassword" + "password": "SuperSecurePassword", + "ws_token": "secret_ws_t0ken." + }, + "external_message_consumer": { + "enabled": false, + "producers": [ + { + "name": "default", + "host": "127.0.0.2", + "port": 8080, + "ws_token": "secret_ws_t0ken." + } + ], + "wait_timeout": 300, + "ping_timeout": 10, + "sleep_time": 10, + "remove_entry_exit_signals": false, + "message_size_limit": 8 }, "bot_name": "freqtrade", "db_url": "sqlite:///tradesv3.sqlite", diff --git a/docs/configuration.md b/docs/configuration.md index bb4f5ce41..b3dbcd817 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,14 +225,16 @@ Mandatory parameters are marked as **Required**, which means that they are requi | `webhook.webhookexitcancel` | Payload to send on exit order cancel. Only required if `webhook.enabled` is `true`. See the [webhook documentation](webhook-config.md) for more details.
**Datatype:** String | `webhook.webhookexitfill` | Payload to send on exit order filled. Only required if `webhook.enabled` is `true`. See the [webhook documentation](webhook-config.md) for more details.
**Datatype:** String | `webhook.webhookstatus` | Payload to send on status calls. Only required if `webhook.enabled` is `true`. See the [webhook documentation](webhook-config.md) for more details.
**Datatype:** String -| | **Rest API / FreqUI** +| | **Rest API / FreqUI / External Signals** | `api_server.enabled` | Enable usage of API Server. See the [API Server documentation](rest-api.md) for more details.
**Datatype:** Boolean | `api_server.listen_ip_address` | Bind IP address. See the [API Server documentation](rest-api.md) for more details.
**Datatype:** IPv4 | `api_server.listen_port` | Bind Port. See the [API Server documentation](rest-api.md) for more details.
**Datatype:** Integer between 1024 and 65535 | `api_server.verbosity` | Logging verbosity. `info` will print all RPC Calls, while "error" will only display errors.
**Datatype:** Enum, either `info` or `error`. Defaults to `info`. | `api_server.username` | Username for API server. See the [API Server documentation](rest-api.md) for more details.
**Keep it in secret, do not disclose publicly.**
**Datatype:** String | `api_server.password` | Password for API server. See the [API Server documentation](rest-api.md) for more details.
**Keep it in secret, do not disclose publicly.**
**Datatype:** String +| `api_server.ws_token` | API token for the Message WebSocket. See the [API Server documentation](rest-api.md) for more details.
**Keep it in secret, do not disclose publicly.**
**Datatype:** String | `bot_name` | Name of the bot. Passed via API to a client - can be shown to distinguish / name bots.
*Defaults to `freqtrade`*
**Datatype:** String +| `external_message_consumer` | Enable [Producer/Consumer mode](producer-consumer.md) for more details.
**Datatype:** Dict | | **Other** | `initial_state` | Defines the initial application state. If set to stopped, then the bot has to be explicitly started via `/start` RPC command.
*Defaults to `stopped`.*
**Datatype:** Enum, either `stopped` or `running` | `force_entry_enable` | Enables the RPC Commands to force a Trade entry. More information below.
**Datatype:** Boolean diff --git a/docs/producer-consumer.md b/docs/producer-consumer.md new file mode 100644 index 000000000..b69406edf --- /dev/null +++ b/docs/producer-consumer.md @@ -0,0 +1,163 @@ +# Producer / Consumer mode + +freqtrade provides a mechanism whereby an instance (also called `consumer`) may listen to messages from an upstream freqtrade instance (also called `producer`) using the message websocket. Mainly, `analyzed_df` and `whitelist` messages. This allows the reuse of computed indicators (and signals) for pairs in multiple bots without needing to compute them multiple times. + +See [Message Websocket](rest-api.md#message-websocket) in the Rest API docs for setting up the `api_server` configuration for your message websocket (this will be your producer). + +!!! Note + We strongly recommend to set `ws_token` to something random and known only to yourself to avoid unauthorized access to your bot. + +## Configuration + +Enable subscribing to an instance by adding the `external_message_consumer` section to the consumer's config file. + +```json +{ + //... + "external_message_consumer": { + "enabled": true, + "producers": [ + { + "name": "default", // This can be any name you'd like, default is "default" + "host": "127.0.0.1", // The host from your producer's api_server config + "port": 8080, // The port from your producer's api_server config + "ws_token": "sercet_Ws_t0ken" // The ws_token from your producer's api_server config + } + ], + // The following configurations are optional, and usually not required + // "wait_timeout": 300, + // "ping_timeout": 10, + // "sleep_time": 10, + // "remove_entry_exit_signals": false, + // "message_size_limit": 8 + } + //... +} +``` + +| Parameter | Description | +|------------|-------------| +| `enabled` | **Required.** Enable consumer mode. If set to false, all other settings in this section are ignored.
*Defaults to `false`.*
**Datatype:** boolean . +| `producers` | **Required.** List of producers
**Datatype:** Array. +| `producers.name` | **Required.** Name of this producer. This name must be used in calls to `get_producer_pairs()` and `get_producer_df()` if more than one producer is used.
**Datatype:** string +| `producers.host` | **Required.** The hostname or IP address from your producer.
**Datatype:** string +| `producers.port` | **Required.** The port matching the above host.
**Datatype:** string +| `producers.ws_token` | **Required.** `ws_token` as configured on the producer.
**Datatype:** string +| | **Optional settings** +| `wait_timeout` | Timeout until we ping again if no message is received.
*Defaults to `300`.*
**Datatype:** Integer - in seconds. +| `wait_timeout` | Ping timeout
*Defaults to `10`.*
**Datatype:** Integer - in seconds. +| `sleep_time` | Sleep time before retrying to connect.
*Defaults to `10`.*
**Datatype:** Integer - in seconds. +| `remove_entry_exit_signals` | Remove signal columns from the dataframe (set them to 0) on dataframe receipt.
*Defaults to `10`.*
**Datatype:** Integer - in seconds. +| `message_size_limit` | Size limit per message
*Defaults to `8`.*
**Datatype:** Integer - Megabytes. + +Instead of (or as well as) calculating indicators in `populate_indicators()` the follower instance listens on the connection to a producer instance's messages (or multiple producer instances in advanced configurations) and requests the producer's most recently analyzed dataframes for each pair in the active whitelist. + +A consumer instance will then have a full copy of the analyzed dataframes without the need to calculate them itself. + +## Examples + +### Example - Producer Strategy + +A simple strategy with multiple indicators. No special considerations are required in the strategy itself. + +```py +class ProducerStrategy(IStrategy): + #... + def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + """ + Calculate indicators in the standard freqtrade way which can then be broadcast to other instances + """ + dataframe['rsi'] = ta.RSI(dataframe) + bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2) + dataframe['bb_lowerband'] = bollinger['lower'] + dataframe['bb_middleband'] = bollinger['mid'] + dataframe['bb_upperband'] = bollinger['upper'] + dataframe['tema'] = ta.TEMA(dataframe, timeperiod=9) + + return dataframe + + def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + """ + Populates the entry signal for the given dataframe + """ + dataframe.loc[ + ( + (qtpylib.crossed_above(dataframe['rsi'], self.buy_rsi.value)) & + (dataframe['tema'] <= dataframe['bb_middleband']) & + (dataframe['tema'] > dataframe['tema'].shift(1)) & + (dataframe['volume'] > 0) + ), + 'enter_long'] = 1 + + return dataframe +``` + +!!! Tip "FreqAI" + You can use this to setup [FreqAI](freqai.md) on a powerful machine, while you run consumers on simple machines like raspberries, which can interpret the signals generated from the producer in different ways. + + +### Example - Consumer Strategy + +A logically equivalent strategy which calculates no indicators itself, but will have the same analyzed dataframes available to make trading decisions based on the indicators calculated in the producer. In this example the consumer has the same entry criteria, however this is not necessary. The consumer may use different logic to enter/exit trades, and only use the indicators as specified. + +```py +class ConsumerStrategy(IStrategy): + #... + process_only_new_candles = False # required for consumers + + _columns_to_expect = ['rsi_default', 'tema_default', 'bb_middleband_default'] + + def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + """ + Use the websocket api to get pre-populated indicators from another freqtrade instance. + Use `self.dp.get_producer_df(pair)` to get the dataframe + """ + pair = metadata['pair'] + timeframe = self.timeframe + + producer_pairs = self.dp.get_producer_pairs() + # You can specify which producer to get pairs from via: + # self.dp.get_producer_pairs("my_other_producer") + + # This func returns the analyzed dataframe, and when it was analyzed + producer_dataframe, _ = self.dp.get_producer_df(pair) + # You can get other data if the producer makes it available: + # self.dp.get_producer_df( + # pair, + # timeframe="1h", + # candle_type=CandleType.SPOT, + # producer_name="my_other_producer" + # ) + + if not producer_dataframe.empty: + # If you plan on passing the producer's entry/exit signal directly, + # specify ffill=False or it will have unintended results + merged_dataframe = merge_informative_pair(dataframe, producer_dataframe, + timeframe, timeframe, + append_timeframe=False, + suffix="default") + return merged_dataframe + else: + dataframe[self._columns_to_expect] = 0 + + return dataframe + + def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + """ + Populates the entry signal for the given dataframe + """ + # Use the dataframe columns as if we calculated them ourselves + dataframe.loc[ + ( + (qtpylib.crossed_above(dataframe['rsi_default'], self.buy_rsi.value)) & + (dataframe['tema_default'] <= dataframe['bb_middleband_default']) & + (dataframe['tema_default'] > dataframe['tema_default'].shift(1)) & + (dataframe['volume'] > 0) + ), + 'enter_long'] = 1 + + return dataframe +``` + +!!! Tip "Using upstream signals" + By setting `remove_entry_exit_signals=false`, you can also use the producer's signals directly. They should be available as `enter_long_default` (assuming `suffix="default"` was used) - and can be used as either signal directly, or as additional indicator. diff --git a/docs/rest-api.md b/docs/rest-api.md index cc82aadda..c7d762648 100644 --- a/docs/rest-api.md +++ b/docs/rest-api.md @@ -31,7 +31,8 @@ Sample configuration: "jwt_secret_key": "somethingrandom", "CORS_origins": [], "username": "Freqtrader", - "password": "SuperSecret1!" + "password": "SuperSecret1!", + "ws_token": "sercet_Ws_t0ken" }, ``` @@ -66,7 +67,7 @@ secrets.token_hex() !!! Danger "Password selection" Please make sure to select a very strong, unique password to protect your bot from unauthorized access. - Also change `jwt_secret_key` to something random (no need to remember this, but it'll be used to encrypt your session, so it better be something unique!). + Also change `jwt_secret_key` to something random (no need to remember this, but it'll be used to encrypt your session, so it better be something unique!). ### Configuration with docker @@ -93,7 +94,6 @@ Make sure that the following 2 lines are available in your docker-compose file: !!! Danger "Security warning" By using `8080:8080` in the docker port mapping, the API will be available to everyone connecting to the server under the correct port, so others may be able to control your bot. - ## Rest API ### Consuming the API @@ -274,7 +274,7 @@ reload_config Reload configuration. show_config - + Returns part of the configuration, relevant for trading operations. start @@ -322,6 +322,73 @@ whitelist ``` +### Message WebSocket + +The API Server includes a websocket endpoint for subscribing to RPC messages from the freqtrade Bot. +This can be used to consume real-time data from your bot, such as entry/exit fill messages, whitelist changes, populated indicators for pairs, and more. + +This is also used to setup [Producer/Consumer mode](producer-consumer.md) in Freqtrade. + +Assuming your rest API is set to `127.0.0.1` on port `8080`, the endpoint is available at `http://localhost:8080/api/v1/message/ws`. + +To access the websocket endpoint, the `ws_token` is required as a query parameter in the endpoint URL. + +To generate a safe `ws_token` you can run the following code: + +``` python +>>> import secrets +>>> secrets.token_urlsafe(25) +'hZ-y58LXyX_HZ8O1cJzVyN6ePWrLpNQv4Q' +``` + +You would then add that token under `ws_token` in your `api_server` config. Like so: + +``` json +"api_server": { + "enabled": true, + "listen_ip_address": "127.0.0.1", + "listen_port": 8080, + "verbosity": "error", + "enable_openapi": false, + "jwt_secret_key": "somethingrandom", + "CORS_origins": [], + "username": "Freqtrader", + "password": "SuperSecret1!", + "ws_token": "hZ-y58LXyX_HZ8O1cJzVyN6ePWrLpNQv4Q" // <----- +}, +``` + +You can now connect to the endpoint at `http://localhost:8080/api/v1/message/ws?token=hZ-y58LXyX_HZ8O1cJzVyN6ePWrLpNQv4Q`. + +!!! Danger "Reuse of example tokens" + Please do not use the above example token. To make sure you are secure, generate a completely new token. + +#### Using the WebSocket + +Once connected to the WebSocket, the bot will broadcast RPC messages to anyone who is subscribed to them. To subscribe to a list of messages, you must send a JSON request through the WebSocket like the one below. The `data` key must be a list of message type strings. + +``` json +{ + "type": "subscribe", + "data": ["whitelist", "analyzed_df"] // A list of string message types +} +``` + +For a list of message types, please refer to the RPCMessageType enum in `freqtrade/enums/rpcmessagetype.py` + +Now anytime those types of RPC messages are sent in the bot, you will receive them through the WebSocket as long as the connection is active. They typically take the same form as the request: + +``` json +{ + "type": "analyzed_df", + "data": { + "key": ["NEO/BTC", "5m", "spot"], + "df": {}, // The dataframe + "la": "2022-09-08 22:14:41.457786+00:00" + } +} +``` + ### OpenAPI interface To enable the builtin openAPI interface (Swagger UI), specify `"enable_openapi": true` in the api_server configuration. diff --git a/freqtrade/commands/build_config_commands.py b/freqtrade/commands/build_config_commands.py index 01cfa800a..1abd26328 100644 --- a/freqtrade/commands/build_config_commands.py +++ b/freqtrade/commands/build_config_commands.py @@ -211,6 +211,7 @@ def ask_user_config() -> Dict[str, Any]: ) # Force JWT token to be a random string answers['api_server_jwt_key'] = secrets.token_hex() + answers['api_server_ws_token'] = secrets.token_urlsafe(25) return answers diff --git a/freqtrade/constants.py b/freqtrade/constants.py index 1b3edddef..fe17b40bc 100644 --- a/freqtrade/constants.py +++ b/freqtrade/constants.py @@ -243,6 +243,7 @@ CONF_SCHEMA = { 'exchange': {'$ref': '#/definitions/exchange'}, 'edge': {'$ref': '#/definitions/edge'}, 'freqai': {'$ref': '#/definitions/freqai'}, + 'external_message_consumer': {'$ref': '#/definitions/external_message_consumer'}, 'experimental': { 'type': 'object', 'properties': { @@ -404,6 +405,7 @@ CONF_SCHEMA = { }, 'username': {'type': 'string'}, 'password': {'type': 'string'}, + 'ws_token': {'type': ['string', 'array'], 'items': {'type': 'string'}}, 'jwt_secret_key': {'type': 'string'}, 'CORS_origins': {'type': 'array', 'items': {'type': 'string'}}, 'verbosity': {'type': 'string', 'enum': ['error', 'info']}, @@ -488,6 +490,47 @@ CONF_SCHEMA = { }, 'required': ['process_throttle_secs', 'allowed_risk'] }, + 'external_message_consumer': { + 'type': 'object', + 'properties': { + 'enabled': {'type': 'boolean', 'default': False}, + 'producers': { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'host': {'type': 'string'}, + 'port': { + 'type': 'integer', + 'default': 8080, + 'minimum': 0, + 'maximum': 65535 + }, + 'ws_token': {'type': 'string'}, + }, + 'required': ['name', 'host', 'ws_token'] + } + }, + 'wait_timeout': {'type': 'integer', 'minimum': 0}, + 'sleep_time': {'type': 'integer', 'minimum': 0}, + 'ping_timeout': {'type': 'integer', 'minimum': 0}, + 'remove_entry_exit_signals': {'type': 'boolean', 'default': False}, + 'initial_candle_limit': { + 'type': 'integer', + 'minimum': 0, + 'maximum': 1500, + 'default': 1500 + }, + 'message_size_limit': { # In megabytes + 'type': 'integer', + 'minimum': 1, + 'maxmium': 20, + 'default': 8, + } + }, + 'required': ['producers'] + }, "freqai": { "type": "object", "properties": { diff --git a/freqtrade/data/dataprovider.py b/freqtrade/data/dataprovider.py index 43850ddd9..1a0903516 100644 --- a/freqtrade/data/dataprovider.py +++ b/freqtrade/data/dataprovider.py @@ -14,9 +14,10 @@ from pandas import DataFrame from freqtrade.configuration import TimeRange from freqtrade.constants import Config, ListPairsWithTimeframes, PairWithTimeframe from freqtrade.data.history import load_pair_history -from freqtrade.enums import CandleType, RunMode +from freqtrade.enums import CandleType, RPCMessageType, RunMode from freqtrade.exceptions import ExchangeError, OperationalException from freqtrade.exchange import Exchange, timeframe_to_seconds +from freqtrade.rpc import RPCManager from freqtrade.util import PeriodicCache @@ -28,17 +29,33 @@ MAX_DATAFRAME_CANDLES = 1000 class DataProvider: - def __init__(self, config: Config, exchange: Optional[Exchange], pairlists=None) -> None: + def __init__( + self, + config: Config, + exchange: Optional[Exchange], + pairlists=None, + rpc: Optional[RPCManager] = None + ) -> None: self._config = config self._exchange = exchange self._pairlists = pairlists + self.__rpc = rpc self.__cached_pairs: Dict[PairWithTimeframe, Tuple[DataFrame, datetime]] = {} self.__slice_index: Optional[int] = None self.__cached_pairs_backtesting: Dict[PairWithTimeframe, DataFrame] = {} + self.__producer_pairs_df: Dict[str, + Dict[PairWithTimeframe, Tuple[DataFrame, datetime]]] = {} + self.__producer_pairs: Dict[str, List[str]] = {} self._msg_queue: deque = deque() + self._default_candle_type = self._config.get('candle_type_def', CandleType.SPOT) + self._default_timeframe = self._config.get('timeframe', '1h') + self.__msg_cache = PeriodicCache( - maxsize=1000, ttl=timeframe_to_seconds(self._config.get('timeframe', '1h'))) + maxsize=1000, ttl=timeframe_to_seconds(self._default_timeframe)) + + self.producers = self._config.get('external_message_consumer', {}).get('producers', []) + self.external_data_enabled = len(self.producers) > 0 def _set_dataframe_max_index(self, limit_index: int): """ @@ -63,9 +80,110 @@ class DataProvider: :param dataframe: analyzed dataframe :param candle_type: Any of the enum CandleType (must match trading mode!) """ - self.__cached_pairs[(pair, timeframe, candle_type)] = ( + pair_key = (pair, timeframe, candle_type) + self.__cached_pairs[pair_key] = ( dataframe, datetime.now(timezone.utc)) + # For multiple producers we will want to merge the pairlists instead of overwriting + def _set_producer_pairs(self, pairlist: List[str], producer_name: str = "default"): + """ + Set the pairs received to later be used. + + :param pairlist: List of pairs + """ + self.__producer_pairs[producer_name] = pairlist + + def get_producer_pairs(self, producer_name: str = "default") -> List[str]: + """ + Get the pairs cached from the producer + + :returns: List of pairs + """ + return self.__producer_pairs.get(producer_name, []).copy() + + def _emit_df( + self, + pair_key: PairWithTimeframe, + dataframe: DataFrame + ) -> None: + """ + Send this dataframe as an ANALYZED_DF message to RPC + + :param pair_key: PairWithTimeframe tuple + :param data: Tuple containing the DataFrame and the datetime it was cached + """ + if self.__rpc: + self.__rpc.send_msg( + { + 'type': RPCMessageType.ANALYZED_DF, + 'data': { + 'key': pair_key, + 'df': dataframe, + 'la': datetime.now(timezone.utc) + } + } + ) + + def _add_external_df( + self, + pair: str, + dataframe: DataFrame, + last_analyzed: datetime, + timeframe: str, + candle_type: CandleType, + producer_name: str = "default" + ) -> None: + """ + Add the pair data to this class from an external source. + + :param pair: pair to get the data for + :param timeframe: Timeframe to get data for + :param candle_type: Any of the enum CandleType (must match trading mode!) + """ + pair_key = (pair, timeframe, candle_type) + + if producer_name not in self.__producer_pairs_df: + self.__producer_pairs_df[producer_name] = {} + + _last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed + + self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed) + logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.") + + def get_producer_df( + self, + pair: str, + timeframe: Optional[str] = None, + candle_type: Optional[CandleType] = None, + producer_name: str = "default" + ) -> Tuple[DataFrame, datetime]: + """ + Get the pair data from producers. + + :param pair: pair to get the data for + :param timeframe: Timeframe to get data for + :param candle_type: Any of the enum CandleType (must match trading mode!) + :returns: Tuple of the DataFrame and last analyzed timestamp + """ + _timeframe = self._default_timeframe if not timeframe else timeframe + _candle_type = self._default_candle_type if not candle_type else candle_type + + pair_key = (pair, _timeframe, _candle_type) + + # If we have no data from this Producer yet + if producer_name not in self.__producer_pairs_df: + # We don't have this data yet, return empty DataFrame and datetime (01-01-1970) + return (DataFrame(), datetime.fromtimestamp(0, tz=timezone.utc)) + + # If we do have data from that Producer, but no data on this pair_key + if pair_key not in self.__producer_pairs_df[producer_name]: + # We don't have this data yet, return empty DataFrame and datetime (01-01-1970) + return (DataFrame(), datetime.fromtimestamp(0, tz=timezone.utc)) + + # We have it, return this data + df, la = self.__producer_pairs_df[producer_name][pair_key] + return (df.copy(), la) + def add_pairlisthandler(self, pairlists) -> None: """ Allow adding pairlisthandler after initialization diff --git a/freqtrade/enums/__init__.py b/freqtrade/enums/__init__.py index d2f5474fc..146d65f2d 100644 --- a/freqtrade/enums/__init__.py +++ b/freqtrade/enums/__init__.py @@ -6,7 +6,7 @@ from freqtrade.enums.exittype import ExitType from freqtrade.enums.hyperoptstate import HyperoptState from freqtrade.enums.marginmode import MarginMode from freqtrade.enums.ordertypevalue import OrderTypeValues -from freqtrade.enums.rpcmessagetype import RPCMessageType +from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType from freqtrade.enums.runmode import NON_UTIL_MODES, OPTIMIZE_MODES, TRADING_MODES, RunMode from freqtrade.enums.signaltype import SignalDirection, SignalTagType, SignalType from freqtrade.enums.state import State diff --git a/freqtrade/enums/rpcmessagetype.py b/freqtrade/enums/rpcmessagetype.py index 415d8f18c..fae121a09 100644 --- a/freqtrade/enums/rpcmessagetype.py +++ b/freqtrade/enums/rpcmessagetype.py @@ -1,7 +1,7 @@ from enum import Enum -class RPCMessageType(Enum): +class RPCMessageType(str, Enum): STATUS = 'status' WARNING = 'warning' STARTUP = 'startup' @@ -19,8 +19,19 @@ class RPCMessageType(Enum): STRATEGY_MSG = 'strategy_msg' + WHITELIST = 'whitelist' + ANALYZED_DF = 'analyzed_df' + def __repr__(self): return self.value def __str__(self): return self.value + + +# Enum for parsing requests from ws consumers +class RPCRequestType(str, Enum): + SUBSCRIBE = 'subscribe' + + WHITELIST = 'whitelist' + ANALYZED_DF = 'analyzed_df' diff --git a/freqtrade/freqtradebot.py b/freqtrade/freqtradebot.py index eb5705c34..72b88a82f 100644 --- a/freqtrade/freqtradebot.py +++ b/freqtrade/freqtradebot.py @@ -29,6 +29,7 @@ from freqtrade.plugins.pairlistmanager import PairListManager from freqtrade.plugins.protectionmanager import ProtectionManager from freqtrade.resolvers import ExchangeResolver, StrategyResolver from freqtrade.rpc import RPCManager +from freqtrade.rpc.external_message_consumer import ExternalMessageConsumer from freqtrade.strategy.interface import IStrategy from freqtrade.strategy.strategy_wrapper import strategy_safe_wrapper from freqtrade.util import FtPrecise @@ -72,6 +73,8 @@ class FreqtradeBot(LoggingMixin): PairLocks.timeframe = self.config['timeframe'] + self.pairlists = PairListManager(self.exchange, self.config) + # RPC runs in separate threads, can start handling external commands just after # initialization, even before Freqtradebot has a chance to start its throttling, # so anything in the Freqtradebot instance should be ready (initialized), including @@ -79,9 +82,7 @@ class FreqtradeBot(LoggingMixin): # Keep this at the end of this initialization method. self.rpc: RPCManager = RPCManager(self) - self.pairlists = PairListManager(self.exchange, self.config) - - self.dataprovider = DataProvider(self.config, self.exchange, self.pairlists) + self.dataprovider = DataProvider(self.config, self.exchange, self.pairlists, self.rpc) # Attach Dataprovider to strategy instance self.strategy.dp = self.dataprovider @@ -92,6 +93,10 @@ class FreqtradeBot(LoggingMixin): self.edge = Edge(self.config, self.exchange, self.strategy) if \ self.config.get('edge', {}).get('enabled', False) else None + # Init ExternalMessageConsumer if enabled + self.emc = ExternalMessageConsumer(self.config, self.dataprovider) if \ + self.config.get('external_message_consumer', {}).get('enabled', False) else None + self.active_pair_whitelist = self._refresh_active_whitelist() # Set initial bot state from config @@ -151,9 +156,11 @@ class FreqtradeBot(LoggingMixin): finally: self.strategy.ft_bot_cleanup() - self.rpc.cleanup() - Trade.commit() - self.exchange.close() + self.rpc.cleanup() + if self.emc: + self.emc.shutdown() + Trade.commit() + self.exchange.close() def startup(self) -> None: """ @@ -254,6 +261,7 @@ class FreqtradeBot(LoggingMixin): pairs that have open trades. """ # Refresh whitelist + _prev_whitelist = self.pairlists.whitelist self.pairlists.refresh_pairlist() _whitelist = self.pairlists.whitelist @@ -266,6 +274,11 @@ class FreqtradeBot(LoggingMixin): # Extend active-pair whitelist with pairs of open trades # It ensures that candle (OHLCV) data are downloaded for open trades as well _whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist]) + + # Called last to include the included pairs + if _prev_whitelist != _whitelist: + self.rpc.send_msg({'type': RPCMessageType.WHITELIST, 'data': _whitelist}) + return _whitelist def get_free_open_trades(self) -> int: diff --git a/freqtrade/misc.py b/freqtrade/misc.py index c3968e61c..56b3fef0e 100644 --- a/freqtrade/misc.py +++ b/freqtrade/misc.py @@ -10,9 +10,11 @@ from typing import Any, Iterator, List from typing.io import IO from urllib.parse import urlparse +import pandas import rapidjson from freqtrade.constants import DECIMAL_PER_COIN_FALLBACK, DECIMALS_PER_COIN +from freqtrade.enums import SignalTagType, SignalType logger = logging.getLogger(__name__) @@ -249,3 +251,41 @@ def parse_db_uri_for_logging(uri: str): return uri pwd = parsed_db_uri.netloc.split(':')[1].split('@')[0] return parsed_db_uri.geturl().replace(f':{pwd}@', ':*****@') + + +def dataframe_to_json(dataframe: pandas.DataFrame) -> str: + """ + Serialize a DataFrame for transmission over the wire using JSON + :param dataframe: A pandas DataFrame + :returns: A JSON string of the pandas DataFrame + """ + return dataframe.to_json(orient='split') + + +def json_to_dataframe(data: str) -> pandas.DataFrame: + """ + Deserialize JSON into a DataFrame + :param data: A JSON string + :returns: A pandas DataFrame from the JSON string + """ + dataframe = pandas.read_json(data, orient='split') + if 'date' in dataframe.columns: + dataframe['date'] = pandas.to_datetime(dataframe['date'], unit='ms', utc=True) + + return dataframe + + +def remove_entry_exit_signals(dataframe: pandas.DataFrame): + """ + Remove Entry and Exit signals from a DataFrame + + :param dataframe: The DataFrame to remove signals from + """ + dataframe[SignalType.ENTER_LONG.value] = 0 + dataframe[SignalType.EXIT_LONG.value] = 0 + dataframe[SignalType.ENTER_SHORT.value] = 0 + dataframe[SignalType.EXIT_SHORT.value] = 0 + dataframe[SignalTagType.ENTER_TAG.value] = None + dataframe[SignalTagType.EXIT_TAG.value] = None + + return dataframe diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index a39e31b85..ee66fce2b 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -1,8 +1,10 @@ +import logging import secrets from datetime import datetime, timedelta +from typing import Any, Dict, Union import jwt -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status from fastapi.security import OAuth2PasswordBearer from fastapi.security.http import HTTPBasic, HTTPBasicCredentials @@ -10,6 +12,8 @@ from freqtrade.rpc.api_server.api_schemas import AccessAndRefreshToken, AccessTo from freqtrade.rpc.api_server.deps import get_api_config +logger = logging.getLogger(__name__) + ALGORITHM = "HS256" router_login = APIRouter() @@ -25,7 +29,7 @@ httpbasic = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) -def get_user_from_token(token, secret_key: str, token_type: str = "access"): +def get_user_from_token(token, secret_key: str, token_type: str = "access") -> str: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -44,6 +48,45 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"): return username +# This should be reimplemented to better realign with the existing tools provided +# by FastAPI regarding API Tokens +# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py +async def validate_ws_token( + ws: WebSocket, + ws_token: Union[str, None] = Query(default=None, alias="token"), + api_config: Dict[str, Any] = Depends(get_api_config) +): + secret_ws_token = api_config.get('ws_token', None) + secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret') + + # Check if ws_token is/in secret_ws_token + if ws_token and secret_ws_token: + is_valid_ws_token = False + if isinstance(secret_ws_token, str): + is_valid_ws_token = secrets.compare_digest(secret_ws_token, ws_token) + elif isinstance(secret_ws_token, list): + is_valid_ws_token = any([ + secrets.compare_digest(potential, ws_token) + for potential in secret_ws_token + ]) + + if is_valid_ws_token: + return ws_token + + # Check if ws_token is a JWT + try: + user = get_user_from_token(ws_token, secret_jwt_key) + return user + # If the token is a jwt, and it's valid return the user + except HTTPException: + pass + + # No checks passed, deny the connection + logger.debug("Denying websocket request.") + # If it doesn't match, close the websocket connection + await ws.close(code=status.WS_1008_POLICY_VIOLATION) + + def create_token(data: dict, secret_key: str, token_type: str = "access") -> str: to_encode = data.copy() if token_type == "access": diff --git a/freqtrade/rpc/api_server/api_v1.py b/freqtrade/rpc/api_server/api_v1.py index bf21715b7..53f5c16d7 100644 --- a/freqtrade/rpc/api_server/api_v1.py +++ b/freqtrade/rpc/api_server/api_v1.py @@ -38,7 +38,8 @@ logger = logging.getLogger(__name__) # 2.15: Add backtest history endpoints # 2.16: Additional daily metrics # 2.17: Forceentry - leverage, partial force_exit -API_VERSION = 2.17 +# 2.20: Add websocket endpoints +API_VERSION = 2.20 # Public API, requires no auth. router_public = APIRouter() diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py new file mode 100644 index 000000000..f55b2dbd3 --- /dev/null +++ b/freqtrade/rpc/api_server/api_ws.py @@ -0,0 +1,140 @@ +import logging +from typing import Any, Dict + +from fastapi import APIRouter, Depends, WebSocketDisconnect +from fastapi.websockets import WebSocket, WebSocketState +from pydantic import ValidationError + +from freqtrade.enums import RPCMessageType, RPCRequestType +from freqtrade.rpc.api_server.api_auth import validate_ws_token +from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc +from freqtrade.rpc.api_server.ws import WebSocketChannel +from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, + WSRequestSchema, WSWhitelistMessage) +from freqtrade.rpc.rpc import RPC + + +logger = logging.getLogger(__name__) + +# Private router, protected by API Key authentication +router = APIRouter() + + +async def is_websocket_alive(ws: WebSocket) -> bool: + """ + Check if a FastAPI Websocket is still open + """ + if ( + ws.application_state == WebSocketState.CONNECTED and + ws.client_state == WebSocketState.CONNECTED + ): + return True + return False + + +async def _process_consumer_request( + request: Dict[str, Any], + channel: WebSocketChannel, + rpc: RPC +): + """ + Validate and handle a request from a websocket consumer + """ + # Validate the request, makes sure it matches the schema + try: + websocket_request = WSRequestSchema.parse_obj(request) + except ValidationError as e: + logger.error(f"Invalid request from {channel}: {e}") + return + + type, data = websocket_request.type, websocket_request.data + response: WSMessageSchema + + logger.debug(f"Request of type {type} from {channel}") + + # If we have a request of type SUBSCRIBE, set the topics in this channel + if type == RPCRequestType.SUBSCRIBE: + # If the request is empty, do nothing + if not data: + return + + # If all topics passed are a valid RPCMessageType, set subscriptions on channel + if all([any(x.value == topic for x in RPCMessageType) for topic in data]): + channel.set_subscriptions(data) + + # We don't send a response for subscriptions + return + + elif type == RPCRequestType.WHITELIST: + # Get whitelist + whitelist = rpc._ws_request_whitelist() + + # Format response + response = WSWhitelistMessage(data=whitelist) + # Send it back + await channel.send(response.dict(exclude_none=True)) + + elif type == RPCRequestType.ANALYZED_DF: + limit = None + + if data: + # Limit the amount of candles per dataframe to 'limit' or 1500 + limit = max(data.get('limit', 1500), 1500) + + # They requested the full historical analyzed dataframes + analyzed_df = rpc._ws_request_analyzed_df(limit) + + # For every dataframe, send as a separate message + for _, message in analyzed_df.items(): + response = WSAnalyzedDFMessage(data=message) + await channel.send(response.dict(exclude_none=True)) + + +@router.websocket("/message/ws") +async def message_endpoint( + ws: WebSocket, + rpc: RPC = Depends(get_rpc), + channel_manager=Depends(get_channel_manager), + token: str = Depends(validate_ws_token) +): + """ + Message WebSocket endpoint, facilitates sending RPC messages + """ + try: + channel = await channel_manager.on_connect(ws) + + if await is_websocket_alive(ws): + + logger.info(f"Consumer connected - {channel}") + + # Keep connection open until explicitly closed, and process requests + try: + while not channel.is_closed(): + request = await channel.recv() + + # Process the request here + await _process_consumer_request(request, channel, rpc) + + except WebSocketDisconnect: + # Handle client disconnects + logger.info(f"Consumer disconnected - {channel}") + await channel_manager.on_disconnect(ws) + except Exception as e: + logger.info(f"Consumer connection failed - {channel}") + logger.exception(e) + # Handle cases like - + # RuntimeError('Cannot call "send" once a closed message has been sent') + await channel_manager.on_disconnect(ws) + + else: + await ws.close() + + except RuntimeError: + # WebSocket was closed + await channel_manager.on_disconnect(ws) + + except Exception as e: + logger.error(f"Failed to serve - {ws.client}") + # Log tracebacks to keep track of what errors are happening + logger.exception(e) + await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/deps.py b/freqtrade/rpc/api_server/deps.py index 66654c0b1..abd3db036 100644 --- a/freqtrade/rpc/api_server/deps.py +++ b/freqtrade/rpc/api_server/deps.py @@ -41,6 +41,10 @@ def get_exchange(config=Depends(get_config)): return ApiServer._exchange +def get_channel_manager(): + return ApiServer._ws_channel_manager + + def is_webserver_mode(config=Depends(get_config)): if config['runmode'] != RunMode.WEBSERVER: raise RPCException('Bot is not in the correct state') diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 642f25e47..df4324740 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -1,16 +1,21 @@ +import asyncio import logging from ipaddress import IPv4Address +from threading import Thread from typing import Any, Dict import orjson import uvicorn from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware +# Look into alternatives +from janus import Queue as ThreadedQueue from starlette.responses import JSONResponse from freqtrade.constants import Config from freqtrade.exceptions import OperationalException from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer +from freqtrade.rpc.api_server.ws import ChannelManager from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler @@ -44,6 +49,10 @@ class ApiServer(RPCHandler): _config: Config = {} # Exchange - only available in webserver mode. _exchange = None + # websocket message queue stuff + _ws_channel_manager = None + _ws_thread = None + _ws_loop = None def __new__(cls, *args, **kwargs): """ @@ -61,17 +70,21 @@ class ApiServer(RPCHandler): return self._standalone: bool = standalone self._server = None + self._ws_queue = None + self._ws_background_task = None + ApiServer.__initialized = True api_config = self._config['api_server'] + ApiServer._ws_channel_manager = ChannelManager() + self.app = FastAPI(title="Freqtrade API", docs_url='/docs' if api_config.get('enable_openapi', False) else None, redoc_url=None, default_response_class=FTJSONResponse, ) self.configure_app(self.app, self._config) - self.start_api() def add_rpc_handler(self, rpc: RPC): @@ -93,6 +106,19 @@ class ApiServer(RPCHandler): logger.info("Stopping API Server") self._server.cleanup() + if self._ws_thread and self._ws_loop: + logger.info("Stopping API Server background tasks") + + if self._ws_background_task: + # Cancel the queue task + self._ws_background_task.cancel() + + self._ws_thread.join() + + self._ws_thread = None + self._ws_loop = None + self._ws_background_task = None + @classmethod def shutdown(cls): cls.__initialized = False @@ -102,7 +128,9 @@ class ApiServer(RPCHandler): cls._rpc = None def send_msg(self, msg: Dict[str, str]) -> None: - pass + if self._ws_queue: + sync_q = self._ws_queue.sync_q + sync_q.put(msg) def handle_rpc_exception(self, request, exc): logger.exception(f"API Error calling: {exc}") @@ -116,6 +144,7 @@ class ApiServer(RPCHandler): from freqtrade.rpc.api_server.api_backtest import router as api_backtest from freqtrade.rpc.api_server.api_v1 import router as api_v1 from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public + from freqtrade.rpc.api_server.api_ws import router as ws_router from freqtrade.rpc.api_server.web_ui import router_ui app.include_router(api_v1_public, prefix="/api/v1") @@ -126,6 +155,7 @@ class ApiServer(RPCHandler): app.include_router(api_backtest, prefix="/api/v1", dependencies=[Depends(http_basic_or_jwt_token)], ) + app.include_router(ws_router, prefix="/api/v1") app.include_router(router_login, prefix="/api/v1", tags=["auth"]) # UI Router MUST be last! app.include_router(router_ui, prefix='') @@ -140,6 +170,48 @@ class ApiServer(RPCHandler): app.add_exception_handler(RPCException, self.handle_rpc_exception) + def start_message_queue(self): + if self._ws_thread: + return + + # Create a new loop, as it'll be just for the background thread + self._ws_loop = asyncio.new_event_loop() + + # Start the thread + self._ws_thread = Thread(target=self._ws_loop.run_forever) + self._ws_thread.start() + + # Finally, submit the coro to the thread + self._ws_background_task = asyncio.run_coroutine_threadsafe( + self._broadcast_queue_data(), loop=self._ws_loop) + + async def _broadcast_queue_data(self): + # Instantiate the queue in this coroutine so it's attached to our loop + self._ws_queue = ThreadedQueue() + async_queue = self._ws_queue.async_q + + try: + while True: + logger.debug("Getting queue messages...") + # Get data from queue + message = await async_queue.get() + logger.debug(f"Found message of type: {message.get('type')}") + # Broadcast it + await self._ws_channel_manager.broadcast(message) + # Sleep, make this configurable? + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + # For testing, shouldn't happen when stable + except Exception as e: + logger.exception(f"Exception happened in background task: {e}") + + finally: + # Disconnect channels and stop the loop on cancel + await self._ws_channel_manager.disconnect_all() + self._ws_loop.stop() + def start_api(self): """ Start API ... should be run in thread. @@ -177,6 +249,7 @@ class ApiServer(RPCHandler): if self._standalone: self._server.run() else: + self.start_message_queue() self._server.run_in_thread() except Exception: logger.exception("Api server failed to start.") diff --git a/freqtrade/rpc/api_server/ws/__init__.py b/freqtrade/rpc/api_server/ws/__init__.py new file mode 100644 index 000000000..055b20a9d --- /dev/null +++ b/freqtrade/rpc/api_server/ws/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa: F401 +# isort: off +from freqtrade.rpc.api_server.ws.types import WebSocketType +from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy +from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer +from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py new file mode 100644 index 000000000..cffe3092d --- /dev/null +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -0,0 +1,178 @@ +import logging +from threading import RLock +from typing import List, Optional, Type +from uuid import uuid4 + +from fastapi import WebSocket as FastAPIWebSocket + +from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy +from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, + WebSocketSerializer) +from freqtrade.rpc.api_server.ws.types import WebSocketType + + +logger = logging.getLogger(__name__) + + +class WebSocketChannel: + """ + Object to help facilitate managing a websocket connection + """ + + def __init__( + self, + websocket: WebSocketType, + channel_id: Optional[str] = None, + serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer + ): + + self.channel_id = channel_id if channel_id else uuid4().hex[:8] + + # The WebSocket object + self._websocket = WebSocketProxy(websocket) + # The Serializing class for the WebSocket object + self._serializer_cls = serializer_cls + + self._subscriptions: List[str] = [] + + # Internal event to signify a closed websocket + self._closed = False + + # Wrap the WebSocket in the Serializing class + self._wrapped_ws = self._serializer_cls(self._websocket) + + def __repr__(self): + return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" + + @property + def remote_addr(self): + return self._websocket.remote_addr + + async def send(self, data): + """ + Send data on the wrapped websocket + """ + await self._wrapped_ws.send(data) + + async def recv(self): + """ + Receive data on the wrapped websocket + """ + return await self._wrapped_ws.recv() + + async def ping(self): + """ + Ping the websocket + """ + return await self._websocket.ping() + + async def close(self): + """ + Close the WebSocketChannel + """ + + self._closed = True + + def is_closed(self) -> bool: + """ + Closed flag + """ + return self._closed + + def set_subscriptions(self, subscriptions: List[str] = []) -> None: + """ + Set which subscriptions this channel is subscribed to + + :param subscriptions: List of subscriptions, List[str] + """ + self._subscriptions = subscriptions + + def subscribed_to(self, message_type: str) -> bool: + """ + Check if this channel is subscribed to the message_type + + :param message_type: The message type to check + """ + return message_type in self._subscriptions + + +class ChannelManager: + def __init__(self): + self.channels = dict() + self._lock = RLock() # Re-entrant Lock + + async def on_connect(self, websocket: WebSocketType): + """ + Wrap websocket connection into Channel and add to list + + :param websocket: The WebSocket object to attach to the Channel + """ + if isinstance(websocket, FastAPIWebSocket): + try: + await websocket.accept() + except RuntimeError: + # The connection was closed before we could accept it + return + + ws_channel = WebSocketChannel(websocket) + + with self._lock: + self.channels[websocket] = ws_channel + + return ws_channel + + async def on_disconnect(self, websocket: WebSocketType): + """ + Call close on the channel if it's not, and remove from channel list + + :param websocket: The WebSocket objet attached to the Channel + """ + with self._lock: + channel = self.channels.get(websocket) + if channel: + if not channel.is_closed(): + await channel.close() + + del self.channels[websocket] + + async def disconnect_all(self): + """ + Disconnect all Channels + """ + with self._lock: + for websocket, channel in self.channels.items(): + if not channel.is_closed(): + await channel.close() + + self.channels = dict() + + async def broadcast(self, data): + """ + Broadcast data on all Channels + + :param data: The data to send + """ + with self._lock: + message_type = data.get('type') + for websocket, channel in self.channels.items(): + try: + if channel.subscribed_to(message_type): + await channel.send(data) + except RuntimeError: + # Handle cannot send after close cases + await self.on_disconnect(websocket) + + async def send_direct(self, channel, data): + """ + Send data directly through direct_channel only + + :param direct_channel: The WebSocketChannel object to send data through + :param data: The data to send + """ + await channel.send(data) + + def has_channels(self): + """ + Flag for more than 0 channels + """ + return len(self.channels) > 0 diff --git a/freqtrade/rpc/api_server/ws/proxy.py b/freqtrade/rpc/api_server/ws/proxy.py new file mode 100644 index 000000000..2e5a59f05 --- /dev/null +++ b/freqtrade/rpc/api_server/ws/proxy.py @@ -0,0 +1,69 @@ +from typing import Any, Tuple, Union + +from fastapi import WebSocket as FastAPIWebSocket +from websockets.client import WebSocketClientProtocol as WebSocket + +from freqtrade.rpc.api_server.ws.types import WebSocketType + + +class WebSocketProxy: + """ + WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol + under the same API + """ + + def __init__(self, websocket: WebSocketType): + self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket + + @property + def remote_addr(self) -> Tuple[Any, ...]: + if isinstance(self._websocket, WebSocket): + return self._websocket.remote_address + elif isinstance(self._websocket, FastAPIWebSocket): + if self._websocket.client: + client, port = self._websocket.client.host, self._websocket.client.port + return (client, port) + return ("unknown", 0) + + async def send(self, data): + """ + Send data on the wrapped websocket + """ + if hasattr(self._websocket, "send_text"): + await self._websocket.send_text(data) + else: + await self._websocket.send(data) + + async def recv(self): + """ + Receive data on the wrapped websocket + """ + if hasattr(self._websocket, "receive_text"): + return await self._websocket.receive_text() + else: + return await self._websocket.recv() + + async def ping(self): + """ + Ping the websocket, not supported by FastAPI WebSockets + """ + if hasattr(self._websocket, "ping"): + return await self._websocket.ping() + return False + + async def close(self, code: int = 1000): + """ + Close the websocket connection, only supported by FastAPI WebSockets + """ + if hasattr(self._websocket, "close"): + try: + return await self._websocket.close(code) + except RuntimeError: + pass + + async def accept(self): + """ + Accept the WebSocket connection, only support by FastAPI WebSockets + """ + if hasattr(self._websocket, "accept"): + return await self._websocket.accept() diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py new file mode 100644 index 000000000..6c402a100 --- /dev/null +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -0,0 +1,62 @@ +import logging +from abc import ABC, abstractmethod + +import orjson +import rapidjson +from pandas import DataFrame + +from freqtrade.misc import dataframe_to_json, json_to_dataframe +from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy + + +logger = logging.getLogger(__name__) + + +class WebSocketSerializer(ABC): + def __init__(self, websocket: WebSocketProxy): + self._websocket: WebSocketProxy = websocket + + @abstractmethod + def _serialize(self, data): + raise NotImplementedError() + + @abstractmethod + def _deserialize(self, data): + raise NotImplementedError() + + async def send(self, data: bytes): + await self._websocket.send(self._serialize(data)) + + async def recv(self) -> bytes: + data = await self._websocket.recv() + + return self._deserialize(data) + + async def close(self, code: int = 1000): + await self._websocket.close(code) + + +class HybridJSONWebSocketSerializer(WebSocketSerializer): + def _serialize(self, data) -> str: + return str(orjson.dumps(data, default=_json_default), "utf-8") + + def _deserialize(self, data: str): + # RapidJSON expects strings + return rapidjson.loads(data, object_hook=_json_object_hook) + + +# Support serializing pandas DataFrames +def _json_default(z): + if isinstance(z, DataFrame): + return { + '__type__': 'dataframe', + '__value__': dataframe_to_json(z) + } + raise TypeError + + +# Support deserializing JSON to pandas DataFrames +def _json_object_hook(z): + if z.get('__type__') == 'dataframe': + return json_to_dataframe(z.get('__value__')) + return z diff --git a/freqtrade/rpc/api_server/ws/types.py b/freqtrade/rpc/api_server/ws/types.py new file mode 100644 index 000000000..9855f9e06 --- /dev/null +++ b/freqtrade/rpc/api_server/ws/types.py @@ -0,0 +1,8 @@ +from typing import Any, Dict, TypeVar + +from fastapi import WebSocket as FastAPIWebSocket +from websockets.client import WebSocketClientProtocol as WebSocket + + +WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket) +MessageType = Dict[str, Any] diff --git a/freqtrade/rpc/api_server/ws_schemas.py b/freqtrade/rpc/api_server/ws_schemas.py new file mode 100644 index 000000000..255226d84 --- /dev/null +++ b/freqtrade/rpc/api_server/ws_schemas.py @@ -0,0 +1,63 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pandas import DataFrame +from pydantic import BaseModel + +from freqtrade.constants import PairWithTimeframe +from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType + + +class BaseArbitraryModel(BaseModel): + class Config: + arbitrary_types_allowed = True + + +class WSRequestSchema(BaseArbitraryModel): + type: RPCRequestType + data: Optional[Any] = None + + +class WSMessageSchema(BaseArbitraryModel): + type: RPCMessageType + data: Optional[Any] = None + + class Config: + extra = 'allow' + + +# ------------------------------ REQUEST SCHEMAS ---------------------------- + + +class WSSubscribeRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.SUBSCRIBE + data: List[RPCMessageType] + + +class WSWhitelistRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.WHITELIST + data: None = None + + +class WSAnalyzedDFRequest(WSRequestSchema): + type: RPCRequestType = RPCRequestType.ANALYZED_DF + data: Dict[str, Any] = {"limit": 1500} + + +# ------------------------------ MESSAGE SCHEMAS ---------------------------- + +class WSWhitelistMessage(WSMessageSchema): + type: RPCMessageType = RPCMessageType.WHITELIST + data: List[str] + + +class WSAnalyzedDFMessage(WSMessageSchema): + class AnalyzedDFData(BaseArbitraryModel): + key: PairWithTimeframe + df: DataFrame + la: datetime + + type: RPCMessageType = RPCMessageType.ANALYZED_DF + data: AnalyzedDFData + +# -------------------------------------------------------------------------- diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py new file mode 100644 index 000000000..a57fac144 --- /dev/null +++ b/freqtrade/rpc/external_message_consumer.py @@ -0,0 +1,341 @@ +""" +ExternalMessageConsumer module + +Main purpose is to connect to external bot's message websocket to consume data +from it +""" +import asyncio +import logging +import socket +from threading import Thread +from typing import TYPE_CHECKING, Any, Callable, Dict, List + +import websockets +from pydantic import ValidationError + +from freqtrade.data.dataprovider import DataProvider +from freqtrade.enums import RPCMessageType +from freqtrade.exceptions import OperationalException +from freqtrade.misc import remove_entry_exit_signals +from freqtrade.rpc.api_server.ws import WebSocketChannel +from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSAnalyzedDFRequest, + WSMessageSchema, WSRequestSchema, + WSSubscribeRequest, WSWhitelistMessage, + WSWhitelistRequest) + + +if TYPE_CHECKING: + import websockets.connect + import websockets.exceptions + + +logger = logging.getLogger(__name__) + + +class ExternalMessageConsumer: + """ + The main controller class for consuming external messages from + other freqtrade bot's + """ + + def __init__( + self, + config: Dict[str, Any], + dataprovider: DataProvider + ): + self._config = config + self._dp = dataprovider + + self._running = False + self._thread = None + self._loop = None + self._main_task = None + self._sub_tasks = None + + self._emc_config = self._config.get('external_message_consumer', {}) + + self.enabled = self._emc_config.get('enabled', False) + self.producers = self._emc_config.get('producers', []) + + self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds + self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds + self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds + + # The amount of candles per dataframe on the initial request + self.initial_candle_limit = self._emc_config.get('initial_candle_limit', 1500) + + # Message size limit, in megabytes. Default 8mb, Use bitwise operator << 20 to convert + # as the websockets client expects bytes. + self.message_size_limit = (self._emc_config.get('message_size_limit', 8) << 20) + + self.validate_config() + + # Setting these explicitly as they probably shouldn't be changed by a user + # Unless we somehow integrate this with the strategy to allow creating + # callbacks for the messages + self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF] + + # Allow setting data for each initial request + self._initial_requests: List[WSRequestSchema] = [ + WSSubscribeRequest(data=self.topics), + WSWhitelistRequest(), + WSAnalyzedDFRequest() + ] + + # Specify which function to use for which RPCMessageType + self._message_handlers: Dict[str, Callable[[str, WSMessageSchema], None]] = { + RPCMessageType.WHITELIST: self._consume_whitelist_message, + RPCMessageType.ANALYZED_DF: self._consume_analyzed_df_message, + } + + self.start() + + def validate_config(self): + """ + Make sure values are what they are supposed to be + """ + if self.enabled and len(self.producers) < 1: + raise OperationalException("You must specify at least 1 Producer to connect to.") + + if self.enabled and self._config.get('process_only_new_candles', True): + # Warning here or require it? + logger.warning("To receive best performance with external data, " + "please set `process_only_new_candles` to False") + + def start(self): + """ + Start the main internal loop in another thread to run coroutines + """ + if self._thread and self._loop: + return + + logger.info("Starting ExternalMessageConsumer") + + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=self._loop.run_forever) + self._running = True + self._thread.start() + + self._main_task = asyncio.run_coroutine_threadsafe(self._main(), loop=self._loop) + + def shutdown(self): + """ + Shutdown the loop, thread, and tasks + """ + if self._thread and self._loop: + logger.info("Stopping ExternalMessageConsumer") + self._running = False + + if self._sub_tasks: + # Cancel sub tasks + for task in self._sub_tasks: + task.cancel() + + if self._main_task: + # Cancel the main task + self._main_task.cancel() + + self._thread.join() + + self._thread = None + self._loop = None + self._sub_tasks = None + self._main_task = None + + async def _main(self): + """ + The main task coroutine + """ + lock = asyncio.Lock() + + try: + # Create a connection to each producer + self._sub_tasks = [ + self._loop.create_task(self._handle_producer_connection(producer, lock)) + for producer in self.producers + ] + + await asyncio.gather(*self._sub_tasks) + except asyncio.CancelledError: + pass + finally: + # Stop the loop once we are done + self._loop.stop() + + async def _handle_producer_connection(self, producer: Dict[str, Any], lock: asyncio.Lock): + """ + Main connection loop for the consumer + + :param producer: Dictionary containing producer info + :param lock: An asyncio Lock + """ + try: + await self._create_connection(producer, lock) + except asyncio.CancelledError: + # Exit silently + pass + + async def _create_connection(self, producer: Dict[str, Any], lock: asyncio.Lock): + """ + Actually creates and handles the websocket connection, pinging on timeout + and handling connection errors. + + :param producer: Dictionary containing producer info + :param lock: An asyncio Lock + """ + while self._running: + try: + host, port = producer['host'], producer['port'] + token = producer['ws_token'] + name = producer['name'] + ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}" + + # This will raise InvalidURI if the url is bad + async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws: + channel = WebSocketChannel(ws, channel_id=name) + + logger.info(f"Producer connection success - {channel}") + + # Now request the initial data from this Producer + for request in self._initial_requests: + await channel.send( + request.dict(exclude_none=True) + ) + + # Now receive data, if none is within the time limit, ping + await self._receive_messages(channel, producer, lock) + + except (websockets.exceptions.InvalidURI, ValueError) as e: + logger.error(f"{ws_url} is an invalid WebSocket URL - {e}") + break + + except ( + socket.gaierror, + ConnectionRefusedError, + websockets.exceptions.InvalidStatusCode, + websockets.exceptions.InvalidMessage + ) as e: + logger.error(f"Connection Refused - {e} retrying in {self.sleep_time}s") + await asyncio.sleep(self.sleep_time) + + continue + + except websockets.exceptions.ConnectionClosedOK: + # Successfully closed, just keep trying to connect again indefinitely + continue + + except Exception as e: + # An unforseen error has occurred, log and continue + logger.error("Unexpected error has occurred:") + logger.exception(e) + continue + + async def _receive_messages( + self, + channel: WebSocketChannel, + producer: Dict[str, Any], + lock: asyncio.Lock + ): + """ + Loop to handle receiving messages from a Producer + + :param channel: The WebSocketChannel object for the WebSocket + :param producer: Dictionary containing producer info + :param lock: An asyncio Lock + """ + while self._running: + try: + message = await asyncio.wait_for( + channel.recv(), + timeout=self.wait_timeout + ) + + try: + async with lock: + # Handle the message + self.handle_producer_message(producer, message) + except Exception as e: + logger.exception(f"Error handling producer message: {e}") + + except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + # We haven't received data yet. Check the connection and continue. + try: + # ping + ping = await channel.ping() + + await asyncio.wait_for(ping, timeout=self.ping_timeout) + logger.debug(f"Connection to {channel} still alive...") + + continue + except Exception as e: + logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s") + logger.debug(e, exc_info=e) + await asyncio.sleep(self.sleep_time) + + break + + def handle_producer_message(self, producer: Dict[str, Any], message: Dict[str, Any]): + """ + Handles external messages from a Producer + """ + producer_name = producer.get('name', 'default') + + try: + producer_message = WSMessageSchema.parse_obj(message) + except ValidationError as e: + logger.error(f"Invalid message from `{producer_name}`: {e}") + return + + if not producer_message.data: + logger.error(f"Empty message received from `{producer_name}`") + return + + logger.info(f"Received message of type `{producer_message.type}` from `{producer_name}`") + + message_handler = self._message_handlers.get(producer_message.type) + + if not message_handler: + logger.info(f"Received unhandled message: `{producer_message.data}`, ignoring...") + return + + message_handler(producer_name, producer_message) + + def _consume_whitelist_message(self, producer_name: str, message: WSMessageSchema): + try: + # Validate the message + whitelist_message = WSWhitelistMessage.parse_obj(message) + except ValidationError as e: + logger.error(f"Invalid message from `{producer_name}`: {e}") + return + + # Add the pairlist data to the DataProvider + self._dp._set_producer_pairs(whitelist_message.data, producer_name=producer_name) + + logger.debug(f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`") + + def _consume_analyzed_df_message(self, producer_name: str, message: WSMessageSchema): + try: + df_message = WSAnalyzedDFMessage.parse_obj(message) + except ValidationError as e: + logger.error(f"Invalid message from `{producer_name}`: {e}") + return + + key = df_message.data.key + df = df_message.data.df + la = df_message.data.la + + pair, timeframe, candle_type = key + + # If set, remove the Entry and Exit signals from the Producer + if self._emc_config.get('remove_entry_exit_signals', False): + df = remove_entry_exit_signals(df) + + # Add the dataframe to the dataprovider + self._dp._add_external_df(pair, df, + last_analyzed=la, + timeframe=timeframe, + candle_type=candle_type, + producer_name=producer_name) + + logger.debug( + f"Consumed message from `{producer_name}` of type `RPCMessageType.ANALYZED_DF`") diff --git a/freqtrade/rpc/rpc.py b/freqtrade/rpc/rpc.py index 6602cdd35..57fc7f473 100644 --- a/freqtrade/rpc/rpc.py +++ b/freqtrade/rpc/rpc.py @@ -1039,14 +1039,52 @@ class RPC: def _rpc_analysed_dataframe(self, pair: str, timeframe: str, limit: Optional[int]) -> Dict[str, Any]: + """ Analyzed dataframe in Dict form """ + _data, last_analyzed = self.__rpc_analysed_dataframe_raw(pair, timeframe, limit) + return self._convert_dataframe_to_dict(self._freqtrade.config['strategy'], + pair, timeframe, _data, last_analyzed) + + def __rpc_analysed_dataframe_raw(self, pair: str, timeframe: str, + limit: Optional[int]) -> Tuple[DataFrame, datetime]: + """ Get the dataframe and last analyze from the dataprovider """ _data, last_analyzed = self._freqtrade.dataprovider.get_analyzed_dataframe( pair, timeframe) _data = _data.copy() + if limit: _data = _data.iloc[-limit:] - return self._convert_dataframe_to_dict(self._freqtrade.config['strategy'], - pair, timeframe, _data, last_analyzed) + return _data, last_analyzed + + def _ws_all_analysed_dataframes( + self, + pairlist: List[str], + limit: Optional[int] + ) -> Dict[str, Any]: + """ Get the analysed dataframes of each pair in the pairlist """ + timeframe = self._freqtrade.config['timeframe'] + candle_type = self._freqtrade.config.get('candle_type_def', CandleType.SPOT) + _data = {} + + for pair in pairlist: + dataframe, last_analyzed = self.__rpc_analysed_dataframe_raw(pair, timeframe, limit) + + _data[pair] = { + "key": (pair, timeframe, candle_type), + "df": dataframe, + "la": last_analyzed + } + + return _data + + def _ws_request_analyzed_df(self, limit: Optional[int]): + """ Historical Analyzed Dataframes for WebSocket """ + whitelist = self._freqtrade.active_pair_whitelist + return self._ws_all_analysed_dataframes(whitelist, limit) + + def _ws_request_whitelist(self): + """ Whitelist data for WebSocket """ + return self._freqtrade.active_pair_whitelist @staticmethod def _rpc_analysed_history_full(config, pair: str, timeframe: str, diff --git a/freqtrade/rpc/rpc_manager.py b/freqtrade/rpc/rpc_manager.py index 8390e61aa..e286487ff 100644 --- a/freqtrade/rpc/rpc_manager.py +++ b/freqtrade/rpc/rpc_manager.py @@ -67,7 +67,8 @@ class RPCManager: 'status': 'stopping bot' } """ - logger.info('Sending rpc message: %s', msg) + if msg.get('type') is not RPCMessageType.ANALYZED_DF: + logger.info('Sending rpc message: %s', msg) if 'pair' in msg: msg.update({ 'base_currency': self._rpc._freqtrade.exchange.get_pair_base_currency(msg['pair']) diff --git a/freqtrade/strategy/interface.py b/freqtrade/strategy/interface.py index 5e765e85b..8f803045f 100644 --- a/freqtrade/strategy/interface.py +++ b/freqtrade/strategy/interface.py @@ -16,6 +16,7 @@ from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, RunMode, Sign SignalTagType, SignalType, TradingMode) from freqtrade.exceptions import OperationalException, StrategyError from freqtrade.exchange import timeframe_to_minutes, timeframe_to_next_date, timeframe_to_seconds +from freqtrade.misc import remove_entry_exit_signals from freqtrade.persistence import Order, PairLocks, Trade from freqtrade.strategy.hyper import HyperStrategyMixin from freqtrade.strategy.informative_decorator import (InformativeData, PopulateIndicators, @@ -742,20 +743,19 @@ class IStrategy(ABC, HyperStrategyMixin): # always run if process_only_new_candles is set to false if (not self.process_only_new_candles or self._last_candle_seen_per_pair.get(pair, None) != dataframe.iloc[-1]['date']): + # Defs that only make change on new candle data. dataframe = self.analyze_ticker(dataframe, metadata) + self._last_candle_seen_per_pair[pair] = dataframe.iloc[-1]['date'] - self.dp._set_cached_df( - pair, self.timeframe, dataframe, - candle_type=self.config.get('candle_type_def', CandleType.SPOT)) + + candle_type = self.config.get('candle_type_def', CandleType.SPOT) + self.dp._set_cached_df(pair, self.timeframe, dataframe, candle_type=candle_type) + self.dp._emit_df((pair, self.timeframe, candle_type), dataframe) + else: logger.debug("Skipping TA Analysis for already analyzed candle") - dataframe[SignalType.ENTER_LONG.value] = 0 - dataframe[SignalType.EXIT_LONG.value] = 0 - dataframe[SignalType.ENTER_SHORT.value] = 0 - dataframe[SignalType.EXIT_SHORT.value] = 0 - dataframe[SignalTagType.ENTER_TAG.value] = None - dataframe[SignalTagType.EXIT_TAG.value] = None + dataframe = remove_entry_exit_signals(dataframe) logger.debug("Loop Analysis Launched") diff --git a/freqtrade/templates/base_config.json.j2 b/freqtrade/templates/base_config.json.j2 index 681af84c6..299734a50 100644 --- a/freqtrade/templates/base_config.json.j2 +++ b/freqtrade/templates/base_config.json.j2 @@ -67,6 +67,7 @@ "verbosity": "error", "enable_openapi": false, "jwt_secret_key": "{{ api_server_jwt_key }}", + "ws_token": "{{ api_server_ws_token }}", "CORS_origins": [], "username": "{{ api_server_username }}", "password": "{{ api_server_password }}" diff --git a/mkdocs.yml b/mkdocs.yml index 257db7867..fd0280e83 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - Advanced Post-installation Tasks: advanced-setup.md - Advanced Strategy: strategy-advanced.md - Advanced Hyperopt: advanced-hyperopt.md + - Producer/Consumer mode: producer-consumer.md - FreqAI: freqai.md - Edge Positioning: edge.md - Sandbox Testing: sandbox-testing.md diff --git a/requirements.txt b/requirements.txt index e68d0e295..690e33a09 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,3 +50,7 @@ python-dateutil==2.8.2 #Futures schedule==1.1.0 + +#WS Messages +websockets==10.3 +janus==1.0.0 diff --git a/setup.py b/setup.py index 8f04e75f7..2e6e354b0 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,9 @@ setup( 'psutil', 'pyjwt', 'aiofiles', - 'schedule' + 'schedule', + 'websockets', + 'janus' ], extras_require={ 'dev': all_extra, diff --git a/tests/conftest.py b/tests/conftest.py index 4039f9367..51b1b03e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,11 @@ def log_has(line, logs): return any(line == message for message in logs.messages) +def log_has_when(line, logs, when): + """Check if line is found in caplog's messages during a specified stage""" + return any(line == message.message for message in logs.get_records(when)) + + def log_has_re(line, logs): """Check if line matches some caplog's message.""" return any(re.match(line, message) for message in logs.messages) diff --git a/tests/data/test_dataprovider.py b/tests/data/test_dataprovider.py index 49603feac..8500fa06c 100644 --- a/tests/data/test_dataprovider.py +++ b/tests/data/test_dataprovider.py @@ -144,6 +144,77 @@ def test_available_pairs(mocker, default_conf, ohlcv_history): assert dp.available_pairs == [("XRP/BTC", timeframe), ("UNITTEST/BTC", timeframe), ] +def test_producer_pairs(mocker, default_conf, ohlcv_history): + dataprovider = DataProvider(default_conf, None) + + producer = "default" + whitelist = ["XRP/BTC", "ETH/BTC"] + assert len(dataprovider.get_producer_pairs(producer)) == 0 + + dataprovider._set_producer_pairs(whitelist, producer) + assert len(dataprovider.get_producer_pairs(producer)) == 2 + + new_whitelist = ["BTC/USDT"] + dataprovider._set_producer_pairs(new_whitelist, producer) + assert dataprovider.get_producer_pairs(producer) == new_whitelist + + assert dataprovider.get_producer_pairs("bad") == [] + + +def test_get_producer_df(mocker, default_conf, ohlcv_history): + dataprovider = DataProvider(default_conf, None) + + pair = 'BTC/USDT' + timeframe = default_conf['timeframe'] + candle_type = CandleType.SPOT + + empty_la = datetime.fromtimestamp(0, tz=timezone.utc) + now = datetime.now(timezone.utc) + + # no data has been added, any request should return an empty dataframe + dataframe, la = dataprovider.get_producer_df(pair, timeframe, candle_type) + assert dataframe.empty + assert la == empty_la + + # the data is added, should return that added dataframe + dataprovider._add_external_df(pair, ohlcv_history, now, timeframe, candle_type) + dataframe, la = dataprovider.get_producer_df(pair, timeframe, candle_type) + assert len(dataframe) > 0 + assert la > empty_la + + # no data on this producer, should return empty dataframe + dataframe, la = dataprovider.get_producer_df(pair, producer_name='bad') + assert dataframe.empty + assert la == empty_la + + # non existent timeframe, empty dataframe + datframe, la = dataprovider.get_producer_df(pair, timeframe='1h') + assert dataframe.empty + assert la == empty_la + + +def test_emit_df(mocker, default_conf, ohlcv_history): + mocker.patch('freqtrade.rpc.rpc_manager.RPCManager.__init__', MagicMock()) + rpc_mock = mocker.patch('freqtrade.rpc.rpc_manager.RPCManager', MagicMock()) + send_mock = mocker.patch('freqtrade.rpc.rpc_manager.RPCManager.send_msg', MagicMock()) + + dataprovider = DataProvider(default_conf, exchange=None, rpc=rpc_mock) + dataprovider_no_rpc = DataProvider(default_conf, exchange=None) + + pair = "BTC/USDT" + + # No emit yet + assert send_mock.call_count == 0 + + # Rpc is added, we call emit, should call send_msg + dataprovider._emit_df(pair, ohlcv_history) + assert send_mock.call_count == 1 + + # No rpc added, emit called, should not call send_msg + dataprovider_no_rpc._emit_df(pair, ohlcv_history) + assert send_mock.call_count == 1 + + def test_refresh(mocker, default_conf, ohlcv_history): refresh_mock = MagicMock() mocker.patch("freqtrade.exchange.Exchange.refresh_latest_ohlcv", refresh_mock) diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 5dfa77d8b..e007e0a9e 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -3,6 +3,8 @@ Unit test file for rpc/api_server.py """ import json +import logging +import time from datetime import datetime, timedelta, timezone from pathlib import Path from unittest.mock import ANY, MagicMock, PropertyMock @@ -10,7 +12,7 @@ from unittest.mock import ANY, MagicMock, PropertyMock import pandas as pd import pytest import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, WebSocketDisconnect from fastapi.exceptions import HTTPException from fastapi.testclient import TestClient from requests.auth import _basic_auth_str @@ -31,6 +33,7 @@ from tests.conftest import (CURRENT_TEST_STRATEGY, create_mock_trades, get_mock_ BASE_URI = "/api/v1" _TEST_USER = "FreqTrader" _TEST_PASS = "SuperSecurePassword1!" +_TEST_WS_TOKEN = "secret_Ws_t0ken" @pytest.fixture @@ -44,17 +47,21 @@ def botclient(default_conf, mocker): "CORS_origins": ['http://example.com'], "username": _TEST_USER, "password": _TEST_PASS, + "ws_token": _TEST_WS_TOKEN }}) ftbot = get_patched_freqtradebot(mocker, default_conf) rpc = RPC(ftbot) mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api', MagicMock()) + apiserver = None try: apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(rpc) yield ftbot, TestClient(apiserver.app) # Cleanup ... ? finally: + if apiserver: + apiserver.cleanup() ApiServer.shutdown() @@ -154,6 +161,25 @@ def test_api_auth(): get_user_from_token(b'not_a_token', 'secret1234') +def test_api_ws_auth(botclient): + ftbot, client = botclient + def url(token): return f"/api/v1/message/ws?token={token}" + + bad_token = "bad-ws_token" + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect(url(bad_token)) as websocket: + websocket.receive() + + good_token = _TEST_WS_TOKEN + with client.websocket_connect(url(good_token)) as websocket: + pass + + jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret') + jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret) + with client.websocket_connect(url(jwt_token)) as websocket: + pass + + def test_api_unauthorized(botclient): ftbot, client = botclient rc = client.get(f"{BASE_URI}/ping") @@ -261,6 +287,7 @@ def test_api__init__(default_conf, mocker): with pytest.raises(OperationalException, match="RPC Handler already attached."): apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf))) + apiserver.cleanup() ApiServer.shutdown() @@ -388,6 +415,7 @@ def test_api_run(default_conf, mocker, caplog): MagicMock(side_effect=Exception)) apiserver.start_api() assert log_has("Api server failed to start.", caplog) + apiserver.cleanup() ApiServer.shutdown() @@ -410,6 +438,7 @@ def test_api_cleanup(default_conf, mocker, caplog): apiserver.cleanup() assert apiserver._server.cleanup.call_count == 1 assert log_has("Stopping API Server", caplog) + assert log_has("Stopping API Server background tasks", caplog) ApiServer.shutdown() @@ -1663,3 +1692,93 @@ def test_health(botclient): ret = rc.json() assert ret['last_process_ts'] == 0 assert ret['last_process'] == '1970-01-01T00:00:00+00:00' + + +def test_api_ws_subscribe(botclient, mocker): + ftbot, client = botclient + ws_url = f"/api/v1/message/ws?token={_TEST_WS_TOKEN}" + + sub_mock = mocker.patch('freqtrade.rpc.api_server.ws.WebSocketChannel.set_subscriptions') + + with client.websocket_connect(ws_url) as ws: + ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) + + # Check call count is now 1 as we sent a valid subscribe request + assert sub_mock.call_count == 1 + + with client.websocket_connect(ws_url) as ws: + ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) + + # Call count hasn't changed as the subscribe request was invalid + assert sub_mock.call_count == 1 + + +def test_api_ws_requests(botclient, mocker, caplog): + caplog.set_level(logging.DEBUG) + + ftbot, client = botclient + ws_url = f"/api/v1/message/ws?token={_TEST_WS_TOKEN}" + + # Test whitelist request + with client.websocket_connect(ws_url) as ws: + ws.send_json({"type": "whitelist", "data": None}) + response = ws.receive_json() + + assert log_has_re(r"Request of type whitelist from.+", caplog) + assert response['type'] == "whitelist" + + # Test analyzed_df request + with client.websocket_connect(ws_url) as ws: + ws.send_json({"type": "analyzed_df", "data": {}}) + response = ws.receive_json() + + assert log_has_re(r"Request of type analyzed_df from.+", caplog) + assert response['type'] == "analyzed_df" + + caplog.clear() + # Test analyzed_df request with data + with client.websocket_connect(ws_url) as ws: + ws.send_json({"type": "analyzed_df", "data": {"limit": 100}}) + response = ws.receive_json() + + assert log_has_re(r"Request of type analyzed_df from.+", caplog) + assert response['type'] == "analyzed_df" + + +def test_api_ws_send_msg(default_conf, mocker, caplog): + try: + caplog.set_level(logging.DEBUG) + + default_conf.update({"api_server": {"enabled": True, + "listen_ip_address": "127.0.0.1", + "listen_port": 8080, + "CORS_origins": ['http://example.com'], + "username": _TEST_USER, + "password": _TEST_PASS, + "ws_token": _TEST_WS_TOKEN + }}) + mocker.patch('freqtrade.rpc.telegram.Updater') + mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api') + apiserver = ApiServer(default_conf) + apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf))) + apiserver.start_message_queue() + # Give the queue thread time to start + time.sleep(0.2) + + # Test message_queue coro receives the message + test_message = {"type": "status", "data": "test"} + apiserver.send_msg(test_message) + time.sleep(0.1) # Not sure how else to wait for the coro to receive the data + assert log_has("Found message of type: status", caplog) + + # Test if exception logged when error occurs in sending + mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast', + side_effect=Exception) + + apiserver.send_msg(test_message) + time.sleep(0.1) # Not sure how else to wait for the coro to receive the data + assert log_has_re(r"Exception happened in background task.*", caplog) + + finally: + apiserver.cleanup() + ApiServer.shutdown() diff --git a/tests/rpc/test_rpc_emc.py b/tests/rpc/test_rpc_emc.py new file mode 100644 index 000000000..9aca88b4a --- /dev/null +++ b/tests/rpc/test_rpc_emc.py @@ -0,0 +1,460 @@ +""" +Unit test file for rpc/external_message_consumer.py +""" +import asyncio +import functools +import logging +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest +import websockets + +from freqtrade.data.dataprovider import DataProvider +from freqtrade.exceptions import OperationalException +from freqtrade.rpc.external_message_consumer import ExternalMessageConsumer +from tests.conftest import log_has, log_has_re, log_has_when + + +_TEST_WS_TOKEN = "secret_Ws_t0ken" +_TEST_WS_HOST = "127.0.0.1" +_TEST_WS_PORT = 9989 + + +@pytest.fixture +def patched_emc(default_conf, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": "null", + "port": 9891, + "ws_token": _TEST_WS_TOKEN + } + ] + } + }) + dataprovider = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dataprovider) + + try: + yield emc + finally: + emc.shutdown() + + +def test_emc_start(patched_emc, caplog): + # Test if the message was printed + assert log_has_when("Starting ExternalMessageConsumer", caplog, "setup") + # Test if the thread and loop objects were created + assert patched_emc._thread and patched_emc._loop + + # Test we call start again nothing happens + prev_thread = patched_emc._thread + patched_emc.start() + assert prev_thread == patched_emc._thread + + +def test_emc_shutdown(patched_emc, caplog): + patched_emc.shutdown() + + assert log_has("Stopping ExternalMessageConsumer", caplog) + # Test the loop has stopped + assert patched_emc._loop is None + # Test if the thread has stopped + assert patched_emc._thread is None + + caplog.clear() + patched_emc.shutdown() + + # Test func didn't run again as it was called once already + assert not log_has("Stopping ExternalMessageConsumer", caplog) + + +def test_emc_init(patched_emc, default_conf): + # Test the settings were set correctly + assert patched_emc.initial_candle_limit <= 1500 + assert patched_emc.wait_timeout > 0 + assert patched_emc.sleep_time > 0 + + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [] + } + }) + dataprovider = DataProvider(default_conf, None, None, None) + with pytest.raises(OperationalException, + match="You must specify at least 1 Producer to connect to."): + ExternalMessageConsumer(default_conf, dataprovider) + + +# Parametrize this? +def test_emc_handle_producer_message(patched_emc, caplog, ohlcv_history): + test_producer = {"name": "test", "url": "ws://test", "ws_token": "test"} + producer_name = test_producer['name'] + + caplog.set_level(logging.DEBUG) + + # Test handle whitelist message + whitelist_message = {"type": "whitelist", "data": ["BTC/USDT"]} + patched_emc.handle_producer_message(test_producer, whitelist_message) + + assert log_has(f"Received message of type `whitelist` from `{producer_name}`", caplog) + assert log_has( + f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`", caplog) + + # Test handle analyzed_df message + df_message = { + "type": "analyzed_df", + "data": { + "key": ("BTC/USDT", "5m", "spot"), + "df": ohlcv_history, + "la": datetime.now(timezone.utc) + } + } + patched_emc.handle_producer_message(test_producer, df_message) + + assert log_has(f"Received message of type `analyzed_df` from `{producer_name}`", caplog) + assert log_has( + f"Consumed message from `{producer_name}` of type `RPCMessageType.ANALYZED_DF`", caplog) + + # Test unhandled message + unhandled_message = {"type": "status", "data": "RUNNING"} + patched_emc.handle_producer_message(test_producer, unhandled_message) + + assert log_has_re(r"Received unhandled message\: .*", caplog) + + # Test malformed messages + caplog.clear() + malformed_message = {"type": "whitelist", "data": {"pair": "BTC/USDT"}} + patched_emc.handle_producer_message(test_producer, malformed_message) + + assert log_has_re(r"Invalid message .+", caplog) + + malformed_message = { + "type": "analyzed_df", + "data": { + "key": "BTC/USDT", + "df": ohlcv_history, + "la": datetime.now(timezone.utc) + } + } + patched_emc.handle_producer_message(test_producer, malformed_message) + + assert log_has(f"Received message of type `analyzed_df` from `{producer_name}`", caplog) + assert log_has_re(r"Invalid message .+", caplog) + + caplog.clear() + malformed_message = {"some": "stuff"} + patched_emc.handle_producer_message(test_producer, malformed_message) + + assert log_has_re(r"Invalid message .+", caplog) + + caplog.clear() + malformed_message = {"type": "whitelist", "data": None} + patched_emc.handle_producer_message(test_producer, malformed_message) + + assert log_has_re(r"Empty message .+", caplog) + + +async def test_emc_create_connection_success(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 60, + "ping_timeout": 60, + "sleep_timeout": 60 + } + }) + + mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', + MagicMock()) + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + test_producer = default_conf['external_message_consumer']['producers'][0] + lock = asyncio.Lock() + + emc._running = True + + async def eat(websocket): + emc._running = False + + try: + async with websockets.serve(eat, _TEST_WS_HOST, _TEST_WS_PORT): + await emc._create_connection(test_producer, lock) + + assert log_has_re(r"Producer connection success.+", caplog) + finally: + emc.shutdown() + + +# async def test_emc_create_connection_invalid(default_conf, caplog, mocker): +# default_conf.update({ +# "external_message_consumer": { +# "enabled": True, +# "producers": [ +# { +# "name": "default", +# "host": _TEST_WS_HOST, +# "port": _TEST_WS_PORT, +# "ws_token": _TEST_WS_TOKEN +# } +# ], +# "wait_timeout": 60, +# "ping_timeout": 60, +# "sleep_timeout": 60 +# } +# }) +# +# mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', +# MagicMock()) +# +# test_producer = default_conf['external_message_consumer']['producers'][0] +# lock = asyncio.Lock() +# +# dp = DataProvider(default_conf, None, None, None) +# emc = ExternalMessageConsumer(default_conf, dp) +# +# try: +# # Test invalid URL +# test_producer['url'] = "tcp://null:8080/api/v1/message/ws" +# emc._running = True +# await emc._create_connection(test_producer, lock) +# emc._running = False +# +# assert log_has_re(r".+is an invalid WebSocket URL.+", caplog) +# finally: +# emc.shutdown() + + +async def test_emc_create_connection_error(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 60, + "ping_timeout": 60, + "sleep_timeout": 60 + } + }) + + # Test unexpected error + mocker.patch('websockets.connect', side_effect=RuntimeError) + + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + try: + await asyncio.sleep(0.01) + assert log_has("Unexpected error has occurred:", caplog) + finally: + emc.shutdown() + + +async def test_emc_receive_messages_valid(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 1, + "ping_timeout": 60, + "sleep_time": 60 + } + }) + + mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', + MagicMock()) + + lock = asyncio.Lock() + test_producer = default_conf['external_message_consumer']['producers'][0] + + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + loop = asyncio.get_event_loop() + def change_running(emc): emc._running = not emc._running + + class TestChannel: + async def recv(self, *args, **kwargs): + return {"type": "whitelist", "data": ["BTC/USDT"]} + + async def ping(self, *args, **kwargs): + return asyncio.Future() + + try: + change_running(emc) + loop.call_soon(functools.partial(change_running, emc=emc)) + await emc._receive_messages(TestChannel(), test_producer, lock) + + assert log_has_re(r"Received message of type `whitelist`.+", caplog) + finally: + emc.shutdown() + + +async def test_emc_receive_messages_invalid(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 1, + "ping_timeout": 60, + "sleep_time": 60 + } + }) + + mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', + MagicMock()) + + lock = asyncio.Lock() + test_producer = default_conf['external_message_consumer']['producers'][0] + + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + loop = asyncio.get_event_loop() + def change_running(emc): emc._running = not emc._running + + class TestChannel: + async def recv(self, *args, **kwargs): + return {"type": ["BTC/USDT"]} + + async def ping(self, *args, **kwargs): + return asyncio.Future() + + try: + change_running(emc) + loop.call_soon(functools.partial(change_running, emc=emc)) + await emc._receive_messages(TestChannel(), test_producer, lock) + + assert log_has_re(r"Invalid message from.+", caplog) + finally: + emc.shutdown() + + +async def test_emc_receive_messages_timeout(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 1, + "ping_timeout": 1, + "sleep_time": 1 + } + }) + + mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', + MagicMock()) + + lock = asyncio.Lock() + test_producer = default_conf['external_message_consumer']['producers'][0] + + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + loop = asyncio.get_event_loop() + def change_running(emc): emc._running = not emc._running + + class TestChannel: + async def recv(self, *args, **kwargs): + await asyncio.sleep(10) + + async def ping(self, *args, **kwargs): + return asyncio.Future() + + try: + change_running(emc) + loop.call_soon(functools.partial(change_running, emc=emc)) + await emc._receive_messages(TestChannel(), test_producer, lock) + + assert log_has_re(r"Ping error.+", caplog) + finally: + emc.shutdown() + + +async def test_emc_receive_messages_handle_error(default_conf, caplog, mocker): + default_conf.update({ + "external_message_consumer": { + "enabled": True, + "producers": [ + { + "name": "default", + "host": _TEST_WS_HOST, + "port": _TEST_WS_PORT, + "ws_token": _TEST_WS_TOKEN + } + ], + "wait_timeout": 1, + "ping_timeout": 1, + "sleep_time": 1 + } + }) + + mocker.patch('freqtrade.rpc.external_message_consumer.ExternalMessageConsumer.start', + MagicMock()) + + lock = asyncio.Lock() + test_producer = default_conf['external_message_consumer']['producers'][0] + + dp = DataProvider(default_conf, None, None, None) + emc = ExternalMessageConsumer(default_conf, dp) + + emc.handle_producer_message = MagicMock(side_effect=Exception) + + loop = asyncio.get_event_loop() + def change_running(emc): emc._running = not emc._running + + class TestChannel: + async def recv(self, *args, **kwargs): + return {"type": "whitelist", "data": ["BTC/USDT"]} + + async def ping(self, *args, **kwargs): + return asyncio.Future() + + try: + change_running(emc) + loop.call_soon(functools.partial(change_running, emc=emc)) + await emc._receive_messages(TestChannel(), test_producer, lock) + + assert log_has_re(r"Error handling producer message.+", caplog) + finally: + emc.shutdown() diff --git a/tests/test_freqtradebot.py b/tests/test_freqtradebot.py index 7851a73f8..5fe4d4011 100644 --- a/tests/test_freqtradebot.py +++ b/tests/test_freqtradebot.py @@ -1319,9 +1319,9 @@ def test_create_stoploss_order_invalid_order( assert create_order_mock.call_args[1]['amount'] == trade.amount # Rpc is sending first buy, then sell - assert rpc_mock.call_count == 2 - assert rpc_mock.call_args_list[1][0][0]['sell_reason'] == ExitType.EMERGENCY_EXIT.value - assert rpc_mock.call_args_list[1][0][0]['order_type'] == 'market' + assert rpc_mock.call_count == 3 + assert rpc_mock.call_args_list[2][0][0]['sell_reason'] == ExitType.EMERGENCY_EXIT.value + assert rpc_mock.call_args_list[2][0][0]['order_type'] == 'market' @pytest.mark.parametrize("is_short", [False, True]) @@ -2439,7 +2439,7 @@ def test_manage_open_orders_entry_usercustom( # Trade should be closed since the function returns true freqtrade.manage_open_orders() assert cancel_order_wr_mock.call_count == 1 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() nb_trades = len(trades) assert nb_trades == 0 @@ -2478,7 +2478,7 @@ def test_manage_open_orders_entry( # check it does cancel buy orders over the time limit freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() nb_trades = len(trades) assert nb_trades == 0 @@ -2608,7 +2608,7 @@ def test_check_handle_cancelled_buy( # check it does cancel buy orders over the time limit freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 0 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() assert len(trades) == 0 assert log_has_re( @@ -2639,7 +2639,7 @@ def test_manage_open_orders_buy_exception( # check it does cancel buy orders over the time limit freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 0 - assert rpc_mock.call_count == 0 + assert rpc_mock.call_count == 1 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() nb_trades = len(trades) assert nb_trades == 1 @@ -2686,7 +2686,7 @@ def test_manage_open_orders_exit_usercustom( # Return false - No impact freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 0 - assert rpc_mock.call_count == 0 + assert rpc_mock.call_count == 1 assert open_trade_usdt.is_open is False assert freqtrade.strategy.check_exit_timeout.call_count == 1 assert freqtrade.strategy.check_entry_timeout.call_count == 0 @@ -2696,7 +2696,7 @@ def test_manage_open_orders_exit_usercustom( # Return Error - No impact freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 0 - assert rpc_mock.call_count == 0 + assert rpc_mock.call_count == 1 assert open_trade_usdt.is_open is False assert freqtrade.strategy.check_exit_timeout.call_count == 1 assert freqtrade.strategy.check_entry_timeout.call_count == 0 @@ -2706,7 +2706,7 @@ def test_manage_open_orders_exit_usercustom( freqtrade.strategy.check_entry_timeout = MagicMock(return_value=True) freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 assert open_trade_usdt.is_open is True assert freqtrade.strategy.check_exit_timeout.call_count == 1 assert freqtrade.strategy.check_entry_timeout.call_count == 0 @@ -2766,7 +2766,7 @@ def test_manage_open_orders_exit( # check it does cancel sell orders over the time limit freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 assert open_trade_usdt.is_open is True # Custom user sell-timeout is never called assert freqtrade.strategy.check_exit_timeout.call_count == 0 @@ -2805,7 +2805,7 @@ def test_check_handle_cancelled_exit( # check it does cancel sell orders over the time limit freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 0 - assert rpc_mock.call_count == 1 + assert rpc_mock.call_count == 2 assert open_trade_usdt.is_open is True exit_name = 'Buy' if is_short else 'Sell' assert log_has_re(f"{exit_name} order cancelled on exchange for Trade.*", caplog) @@ -2843,7 +2843,7 @@ def test_manage_open_orders_partial( # note this is for a partially-complete buy order freqtrade.manage_open_orders() assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 2 + assert rpc_mock.call_count == 3 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() assert len(trades) == 1 assert trades[0].amount == 23.0 @@ -2890,7 +2890,7 @@ def test_manage_open_orders_partial_fee( assert log_has_re(r"Applying fee on amount for Trade.*", caplog) assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 2 + assert rpc_mock.call_count == 3 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() assert len(trades) == 1 # Verify that trade has been updated @@ -2940,7 +2940,7 @@ def test_manage_open_orders_partial_except( assert log_has_re(r"Could not update trade amount: .*", caplog) assert cancel_order_mock.call_count == 1 - assert rpc_mock.call_count == 2 + assert rpc_mock.call_count == 3 trades = Trade.query.filter(Trade.open_order_id.is_(open_trade.open_order_id)).all() assert len(trades) == 1 # Verify that trade has been updated @@ -3155,7 +3155,7 @@ def test_handle_cancel_exit_limit(mocker, default_conf_usdt, fee) -> None: reason = CANCEL_REASON['TIMEOUT'] assert freqtrade.handle_cancel_exit(trade, order, reason) assert cancel_order_mock.call_count == 1 - assert send_msg_mock.call_count == 1 + assert send_msg_mock.call_count == 2 assert trade.close_rate is None assert trade.exit_reason is None @@ -3592,7 +3592,7 @@ def test_execute_trade_exit_with_stoploss_on_exchange( trade.is_short = is_short assert trade assert cancel_order.call_count == 1 - assert rpc_mock.call_count == 3 + assert rpc_mock.call_count == 4 @pytest.mark.parametrize("is_short", [False, True]) @@ -3662,11 +3662,11 @@ def test_may_execute_trade_exit_after_stoploss_on_exchange_hit( assert trade.stoploss_order_id is None assert trade.is_open is False assert trade.exit_reason == ExitType.STOPLOSS_ON_EXCHANGE.value - assert rpc_mock.call_count == 3 - assert rpc_mock.call_args_list[0][0][0]['type'] == RPCMessageType.ENTRY - assert rpc_mock.call_args_list[0][0][0]['amount'] > 20 - assert rpc_mock.call_args_list[1][0][0]['type'] == RPCMessageType.ENTRY_FILL - assert rpc_mock.call_args_list[2][0][0]['type'] == RPCMessageType.EXIT_FILL + assert rpc_mock.call_count == 4 + assert rpc_mock.call_args_list[1][0][0]['type'] == RPCMessageType.ENTRY + assert rpc_mock.call_args_list[1][0][0]['amount'] > 20 + assert rpc_mock.call_args_list[2][0][0]['type'] == RPCMessageType.ENTRY_FILL + assert rpc_mock.call_args_list[3][0][0]['type'] == RPCMessageType.EXIT_FILL @pytest.mark.parametrize( diff --git a/tests/test_misc.py b/tests/test_misc.py index 4b52079bf..2da45bad9 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -7,10 +7,11 @@ from unittest.mock import MagicMock import pytest -from freqtrade.misc import (decimals_per_coin, deep_merge_dicts, file_dump_json, file_load_json, - format_ms_time, pair_to_filename, parse_db_uri_for_logging, plural, - render_template, render_template_with_fallback, round_coin_value, - safe_value_fallback, safe_value_fallback2, shorten_date) +from freqtrade.misc import (dataframe_to_json, decimals_per_coin, deep_merge_dicts, file_dump_json, + file_load_json, format_ms_time, json_to_dataframe, pair_to_filename, + parse_db_uri_for_logging, plural, render_template, + render_template_with_fallback, round_coin_value, safe_value_fallback, + safe_value_fallback2, shorten_date) def test_decimals_per_coin(): @@ -219,3 +220,14 @@ def test_deep_merge_dicts(): res2['first']['rows']['test'] = 'asdf' assert deep_merge_dicts(a, deepcopy(b), allow_null_overrides=False) == res2 + + +def test_dataframe_json(ohlcv_history): + from pandas.testing import assert_frame_equal + json = dataframe_to_json(ohlcv_history) + dataframe = json_to_dataframe(json) + + assert list(ohlcv_history.columns) == list(dataframe.columns) + assert len(ohlcv_history) == len(dataframe) + + assert_frame_equal(ohlcv_history, dataframe)