From f9b3b0ef77b34403ef11a70eac45106a67898152 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Sat, 8 Oct 2022 18:20:07 -0600 Subject: [PATCH] refactor broadcasting to queue per client, only send most recent candles --- freqtrade/data/dataprovider.py | 35 +++++++++++++++++--- freqtrade/rpc/api_server/api_ws.py | 37 ++++++++++++++++++---- freqtrade/rpc/api_server/webserver.py | 1 + freqtrade/rpc/api_server/ws/channel.py | 27 ++++++++-------- freqtrade/rpc/external_message_consumer.py | 10 ++++-- 5 files changed, 83 insertions(+), 27 deletions(-) diff --git a/freqtrade/data/dataprovider.py b/freqtrade/data/dataprovider.py index 4d7296ee7..8d5d7d01f 100644 --- a/freqtrade/data/dataprovider.py +++ b/freqtrade/data/dataprovider.py @@ -9,7 +9,7 @@ from collections import deque from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple -from pandas import DataFrame +from pandas import DataFrame, concat from freqtrade.configuration import TimeRange from freqtrade.constants import Config, ListPairsWithTimeframes, PairWithTimeframe @@ -118,13 +118,13 @@ class DataProvider: 'type': RPCMessageType.ANALYZED_DF, 'data': { 'key': pair_key, - 'df': dataframe, + 'df': dataframe.tail(1), 'la': datetime.now(timezone.utc) } } ) - def _add_external_df( + def _add_producer_df( self, pair: str, dataframe: DataFrame, @@ -147,7 +147,16 @@ class DataProvider: _last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed - self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed) + if pair_key not in self.__producer_pairs_df[producer_name]: + # This is the first message, set the dataframe in that pair key + self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed) + else: + # These are new candles, append them to the dataframe + existing_df, _ = self.__producer_pairs_df[producer_name][pair_key] + existing_df = self._append_candle_to_dataframe(existing_df, dataframe) + + self.__producer_pairs_df[producer_name][pair_key] = (existing_df, _last_analyzed) + logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.") def get_producer_df( @@ -184,6 +193,24 @@ class DataProvider: df, la = self.__producer_pairs_df[producer_name][pair_key] return (df.copy(), la) + def _append_candle_to_dataframe(self, existing: DataFrame, new: DataFrame): + """ + Append the `new` dataframe to the `existing` dataframe + + :param existing: The full dataframe you want appended to + :param new: The new dataframe containing the data you want appended + :returns:The dataframe with the new data in it + """ + if existing.iloc[-1]['date'] != new.iloc[-1]['date']: + existing = concat([existing, new]) + + # Only keep the last 1000 candles in memory + # TODO: Do this better + if len(existing) > 1000: + existing = existing[-1000:] + + return existing + def add_pairlisthandler(self, pairlists) -> None: """ Allow adding pairlisthandler after initialization diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index f55b2dbd3..50a008bf6 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,6 +1,8 @@ +import asyncio import logging from typing import Any, Dict +import websockets from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi.websockets import WebSocket, WebSocketState from pydantic import ValidationError @@ -90,6 +92,20 @@ async def _process_consumer_request( await channel.send(response.dict(exclude_none=True)) +async def relay(channel, queue): + """ + Relay messages in the queue to the channel + """ + while True: + message = await queue.get() + try: + await channel.send(message) + queue.task_done() + except RuntimeError: + # What do we do here? + return + + @router.websocket("/message/ws") async def message_endpoint( ws: WebSocket, @@ -100,12 +116,13 @@ async def message_endpoint( """ Message WebSocket endpoint, facilitates sending RPC messages """ + relay_task = None try: - channel = await channel_manager.on_connect(ws) - + channel, queue = await channel_manager.on_connect(ws) if await is_websocket_alive(ws): logger.info(f"Consumer connected - {channel}") + relay_task = asyncio.create_task(relay(channel, queue)) # Keep connection open until explicitly closed, and process requests try: @@ -115,26 +132,32 @@ async def message_endpoint( # Process the request here await _process_consumer_request(request, channel, rpc) - except WebSocketDisconnect: + except ( + WebSocketDisconnect, + websockets.exceptions.ConnectionClosed + ): # Handle client disconnects logger.info(f"Consumer disconnected - {channel}") - await channel_manager.on_disconnect(ws) except Exception as e: logger.info(f"Consumer connection failed - {channel}") logger.exception(e) # Handle cases like - # RuntimeError('Cannot call "send" once a closed message has been sent') + finally: + relay_task.cancel() await channel_manager.on_disconnect(ws) else: await ws.close() except RuntimeError: - # WebSocket was closed - await channel_manager.on_disconnect(ws) - + # We don't want to log these + pass except Exception as e: logger.error(f"Failed to serve - {ws.client}") # Log tracebacks to keep track of what errors are happening logger.exception(e) + finally: await channel_manager.on_disconnect(ws) + if relay_task: + relay_task.cancel() diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 53af91477..2f597f95a 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -245,6 +245,7 @@ class ApiServer(RPCHandler): use_colors=False, log_config=None, access_log=True if verbosity != 'error' else False, + ws_ping_interval=None ) try: self._server = UvicornServer(uvconfig) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 69a32e266..fb001fdf1 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,6 +1,7 @@ +import asyncio import logging from threading import RLock -from typing import List, Optional, Type +from typing import Any, Dict, List, Optional, Type from uuid import uuid4 from fastapi import WebSocket as FastAPIWebSocket @@ -52,7 +53,7 @@ class WebSocketChannel: """ Send data on the wrapped websocket """ - await self._wrapped_ws.send(data) + return await self._wrapped_ws.send(data) async def recv(self): """ @@ -115,11 +116,12 @@ class ChannelManager: return ws_channel = WebSocketChannel(websocket) + ws_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() with self._lock: - self.channels[websocket] = ws_channel + self.channels[websocket] = (ws_channel, ws_queue) - return ws_channel + return ws_channel, ws_queue async def on_disconnect(self, websocket: WebSocketType): """ @@ -128,7 +130,7 @@ class ChannelManager: :param websocket: The WebSocket objet attached to the Channel """ with self._lock: - channel = self.channels.get(websocket) + channel, _ = self.channels.get(websocket, (None, None)) if channel: if not channel.is_closed(): await channel.close() @@ -140,7 +142,7 @@ class ChannelManager: Disconnect all Channels """ with self._lock: - for websocket, channel in self.channels.copy().items(): + for websocket, (channel, _) in self.channels.copy().items(): if not channel.is_closed(): await channel.close() @@ -154,13 +156,12 @@ class ChannelManager: """ with self._lock: message_type = data.get('type') - for websocket, channel in self.channels.copy().items(): - try: - if channel.subscribed_to(message_type): - await channel.send(data) - except RuntimeError: - # Handle cannot send after close cases - await self.on_disconnect(websocket) + for websocket, (channel, queue) in self.channels.copy().items(): + if channel.subscribed_to(message_type): + if not queue.full(): + queue.put_nowait(data) + else: + await self.on_disconnect(websocket) async def send_direct(self, channel, data): """ diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index f5ba4b490..fa870595c 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -62,7 +62,7 @@ class ExternalMessageConsumer: self.enabled = self._emc_config.get('enabled', False) self.producers: List[Producer] = self._emc_config.get('producers', []) - self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds + self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds @@ -182,7 +182,11 @@ class ExternalMessageConsumer: ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}" # This will raise InvalidURI if the url is bad - async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws: + async with websockets.connect( + ws_url, + max_size=self.message_size_limit, + ping_interval=None + ) as ws: channel = WebSocketChannel(ws, channel_id=name) logger.info(f"Producer connection success - {channel}") @@ -325,7 +329,7 @@ class ExternalMessageConsumer: df = remove_entry_exit_signals(df) # Add the dataframe to the dataprovider - self._dp._add_external_df(pair, df, + self._dp._add_producer_df(pair, df, last_analyzed=la, timeframe=timeframe, candle_type=candle_type,