minor improvements and pairlist data transmission
This commit is contained in:
		| @@ -7,5 +7,4 @@ class ReplicateModeType(str, Enum): | |||||||
|  |  | ||||||
|  |  | ||||||
| class LeaderMessageType(str, Enum): | class LeaderMessageType(str, Enum): | ||||||
|     Pairlist = "pairlist" |     whitelist = "whitelist" | ||||||
|     Dataframe = "dataframe" |  | ||||||
|   | |||||||
| @@ -75,6 +75,8 @@ class FreqtradeBot(LoggingMixin): | |||||||
|  |  | ||||||
|         PairLocks.timeframe = self.config['timeframe'] |         PairLocks.timeframe = self.config['timeframe'] | ||||||
|  |  | ||||||
|  |         self.replicate_controller = None | ||||||
|  |  | ||||||
|         # RPC runs in separate threads, can start handling external commands just after |         # RPC runs in separate threads, can start handling external commands just after | ||||||
|         # initialization, even before Freqtradebot has a chance to start its throttling, |         # initialization, even before Freqtradebot has a chance to start its throttling, | ||||||
|         # so anything in the Freqtradebot instance should be ready (initialized), including |         # so anything in the Freqtradebot instance should be ready (initialized), including | ||||||
| @@ -264,6 +266,17 @@ class FreqtradeBot(LoggingMixin): | |||||||
|             # Extend active-pair whitelist with pairs of open trades |             # Extend active-pair whitelist with pairs of open trades | ||||||
|             # It ensures that candle (OHLCV) data are downloaded for open trades as well |             # It ensures that candle (OHLCV) data are downloaded for open trades as well | ||||||
|             _whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist]) |             _whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist]) | ||||||
|  |  | ||||||
|  |         # If replicate leader, broadcast whitelist data | ||||||
|  |         if self.replicate_controller: | ||||||
|  |             if self.replicate_controller.is_leader(): | ||||||
|  |                 self.replicate_controller.send_message( | ||||||
|  |                     { | ||||||
|  |                         "data_type": "whitelist", | ||||||
|  |                         "data": _whitelist | ||||||
|  |                     } | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|         return _whitelist |         return _whitelist | ||||||
|  |  | ||||||
|     def get_free_open_trades(self) -> int: |     def get_free_open_trades(self) -> int: | ||||||
|   | |||||||
							
								
								
									
										59
									
								
								freqtrade/plugins/pairlist/ExternalPairList.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								freqtrade/plugins/pairlist/ExternalPairList.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | """ | ||||||
|  | External Pair List provider | ||||||
|  |  | ||||||
|  | Provides pair list from Leader data | ||||||
|  | """ | ||||||
|  | import logging | ||||||
|  | from typing import Any, Dict, List | ||||||
|  |  | ||||||
|  | from freqtrade.plugins.pairlist.IPairList import IPairList | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ExternalPairList(IPairList): | ||||||
|  |  | ||||||
|  |     def __init__(self, exchange, pairlistmanager, | ||||||
|  |                  config: Dict[str, Any], pairlistconfig: Dict[str, Any], | ||||||
|  |                  pairlist_pos: int) -> None: | ||||||
|  |         super().__init__(exchange, pairlistmanager, config, pairlistconfig, pairlist_pos) | ||||||
|  |  | ||||||
|  |         self._num_assets = self._pairlistconfig.get('num_assets') | ||||||
|  |         self._allow_inactive = self._pairlistconfig.get('allow_inactive', False) | ||||||
|  |  | ||||||
|  |         self._leader_pairs: List[str] = [] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def needstickers(self) -> bool: | ||||||
|  |         """ | ||||||
|  |         Boolean property defining if tickers are necessary. | ||||||
|  |         If no Pairlist requires tickers, an empty Dict is passed | ||||||
|  |         as tickers argument to filter_pairlist | ||||||
|  |         """ | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def short_desc(self) -> str: | ||||||
|  |         """ | ||||||
|  |         Short whitelist method description - used for startup-messages | ||||||
|  |         -> Please overwrite in subclasses | ||||||
|  |         """ | ||||||
|  |         return f"{self.name}" | ||||||
|  |  | ||||||
|  |     def gen_pairlist(self, tickers: Dict) -> List[str]: | ||||||
|  |         """ | ||||||
|  |         Generate the pairlist | ||||||
|  |         :param tickers: Tickers (from exchange.get_tickers()). May be cached. | ||||||
|  |         :return: List of pairs | ||||||
|  |         """ | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def filter_pairlist(self, pairlist: List[str], tickers: Dict) -> List[str]: | ||||||
|  |         """ | ||||||
|  |         Filters and sorts pairlist and returns the whitelist again. | ||||||
|  |         Called on each bot iteration - please use internal caching if necessary | ||||||
|  |         :param pairlist: pairlist to filter or sort | ||||||
|  |         :param tickers: Tickers (from exchange.get_tickers()). May be cached. | ||||||
|  |         :return: new whitelist | ||||||
|  |         """ | ||||||
|  |         pass | ||||||
| @@ -5,7 +5,7 @@ import asyncio | |||||||
| import logging | import logging | ||||||
| import secrets | import secrets | ||||||
| import socket | import socket | ||||||
| from threading import Thread | from threading import Event, Thread | ||||||
| from typing import Any, Coroutine, Dict, Union | from typing import Any, Coroutine, Dict, Union | ||||||
|  |  | ||||||
| import websockets | import websockets | ||||||
| @@ -50,6 +50,9 @@ class ReplicateController(RPCHandler): | |||||||
|         self._thread = None |         self._thread = None | ||||||
|         self._queue = None |         self._queue = None | ||||||
|  |  | ||||||
|  |         self._stop_event = Event() | ||||||
|  |         self._follower_tasks = None | ||||||
|  |  | ||||||
|         self.channel_manager = ChannelManager() |         self.channel_manager = ChannelManager() | ||||||
|  |  | ||||||
|         self.replicate_config = config.get('replicate', {}) |         self.replicate_config = config.get('replicate', {}) | ||||||
| @@ -93,10 +96,7 @@ class ReplicateController(RPCHandler): | |||||||
|  |  | ||||||
|         self.start_threaded_loop() |         self.start_threaded_loop() | ||||||
|  |  | ||||||
|         if self.mode == ReplicateModeType.follower: |         self.start() | ||||||
|             self.start_follower_mode() |  | ||||||
|         elif self.mode == ReplicateModeType.leader: |  | ||||||
|             self.start_leader_mode() |  | ||||||
|  |  | ||||||
|     def start_threaded_loop(self): |     def start_threaded_loop(self): | ||||||
|         """ |         """ | ||||||
| @@ -129,6 +129,29 @@ class ReplicateController(RPCHandler): | |||||||
|             logger.error(f"Error running coroutine - {str(e)}") |             logger.error(f"Error running coroutine - {str(e)}") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |     async def main_loop(self): | ||||||
|  |         """ | ||||||
|  |         Main loop coro | ||||||
|  |  | ||||||
|  |         Start the loop based on what mode we're in | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             if self.mode == ReplicateModeType.leader: | ||||||
|  |                 await self.leader_loop() | ||||||
|  |             elif self.mode == ReplicateModeType.follower: | ||||||
|  |                 await self.follower_loop() | ||||||
|  |  | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             pass | ||||||
|  |         finally: | ||||||
|  |             self._loop.stop() | ||||||
|  |  | ||||||
|  |     def start(self): | ||||||
|  |         """ | ||||||
|  |         Start the controller main loop | ||||||
|  |         """ | ||||||
|  |         self.submit_coroutine(self.main_loop()) | ||||||
|  |  | ||||||
|     def cleanup(self) -> None: |     def cleanup(self) -> None: | ||||||
|         """ |         """ | ||||||
|         Cleanup pending module resources. |         Cleanup pending module resources. | ||||||
| @@ -144,27 +167,62 @@ class ReplicateController(RPCHandler): | |||||||
|                     task.cancel() |                     task.cancel() | ||||||
|  |  | ||||||
|                 self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all) |                 self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all) | ||||||
|                 # This must be called threadsafe, otherwise would not work |  | ||||||
|                 self._loop.call_soon_threadsafe(self._loop.stop) |  | ||||||
|  |  | ||||||
|             self._thread.join() |             self._thread.join() | ||||||
|  |  | ||||||
|     def send_msg(self, msg: Dict[str, Any]) -> None: |     def send_msg(self, msg: Dict[str, Any]) -> None: | ||||||
|         """ Push message through """ |         """ | ||||||
|  |         Support RPC calls | ||||||
|  |         """ | ||||||
|         if msg["type"] == RPCMessageType.EMIT_DATA: |         if msg["type"] == RPCMessageType.EMIT_DATA: | ||||||
|             self._send_message( |             self.send_message( | ||||||
|                 { |                 { | ||||||
|                     "type": msg["data_type"], |                     "data_type": msg.get("data_type"), | ||||||
|                     "content": msg["data"] |                     "data": msg.get("data") | ||||||
|                 } |                 } | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def send_message(self, msg: Dict[str, Any]) -> None: | ||||||
|  |         """ Push message through """ | ||||||
|  |  | ||||||
|  |         if self.channel_manager.has_channels(): | ||||||
|  |             self._send_message(msg) | ||||||
|  |         else: | ||||||
|  |             logger.debug("No listening followers, skipping...") | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     def _send_message(self, msg: Dict[Any, Any]): | ||||||
|  |         """ | ||||||
|  |         Add data to the internal queue to be broadcasted. This func will block | ||||||
|  |         if the queue is full. This is meant to be called in the main thread. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if self._queue: | ||||||
|  |             queue = self._queue.sync_q | ||||||
|  |             queue.put(msg) | ||||||
|  |         else: | ||||||
|  |             logger.warning("Can not send data, leader loop has not started yet!") | ||||||
|  |  | ||||||
|  |     def is_leader(self): | ||||||
|  |         """ | ||||||
|  |         Leader flag | ||||||
|  |         """ | ||||||
|  |         return self.enabled() and self.mode == ReplicateModeType.leader | ||||||
|  |  | ||||||
|  |     def enabled(self): | ||||||
|  |         """ | ||||||
|  |         Enabled flag | ||||||
|  |         """ | ||||||
|  |         return self.replicate_config.get('enabled', False) | ||||||
|  |  | ||||||
|     # ----------------------- LEADER LOGIC ------------------------------ |     # ----------------------- LEADER LOGIC ------------------------------ | ||||||
|  |  | ||||||
|     def start_leader_mode(self): |     async def leader_loop(self): | ||||||
|         """ |         """ | ||||||
|         Register the endpoint and start the leader loop |         Main leader coroutine | ||||||
|  |  | ||||||
|  |         This starts all of the leader coros and registers the endpoint on | ||||||
|  |         the ApiServer | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         logger.info("Running rpc.replicate in Leader mode") |         logger.info("Running rpc.replicate in Leader mode") | ||||||
| @@ -173,30 +231,13 @@ class ReplicateController(RPCHandler): | |||||||
|         logger.info("-" * 15) |         logger.info("-" * 15) | ||||||
|  |  | ||||||
|         self.register_leader_endpoint() |         self.register_leader_endpoint() | ||||||
|         self.submit_coroutine(self.leader_loop()) |  | ||||||
|  |  | ||||||
|     async def leader_loop(self): |  | ||||||
|         """ |  | ||||||
|         Main leader coroutine |  | ||||||
|         At the moment this just broadcasts data that's in the queue to the followers |  | ||||||
|         """ |  | ||||||
|         try: |         try: | ||||||
|             await self._broadcast_queue_data() |             await self._broadcast_queue_data() | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error("Exception occurred in leader loop: ") |             logger.error("Exception occurred in leader loop: ") | ||||||
|             logger.exception(e) |             logger.exception(e) | ||||||
|  |  | ||||||
|     def _send_message(self, data: Dict[Any, Any]): |  | ||||||
|         """ |  | ||||||
|         Add data to the internal queue to be broadcasted. This func will block |  | ||||||
|         if the queue is full. This is meant to be called in the main thread. |  | ||||||
|         """ |  | ||||||
|  |  | ||||||
|         if self._queue: |  | ||||||
|             self._queue.put(data) |  | ||||||
|         else: |  | ||||||
|             logger.warning("Can not send data, leader loop has not started yet!") |  | ||||||
|  |  | ||||||
|     async def _broadcast_queue_data(self): |     async def _broadcast_queue_data(self): | ||||||
|         """ |         """ | ||||||
|         Loop over queue data and broadcast it |         Loop over queue data and broadcast it | ||||||
| @@ -210,6 +251,8 @@ class ReplicateController(RPCHandler): | |||||||
|                 # Get data from queue |                 # Get data from queue | ||||||
|                 data = await async_queue.get() |                 data = await async_queue.get() | ||||||
|  |  | ||||||
|  |                 logger.info(f"Found data - broadcasting: {data}") | ||||||
|  |  | ||||||
|                 # Broadcast it to everyone |                 # Broadcast it to everyone | ||||||
|                 await self.channel_manager.broadcast(data) |                 await self.channel_manager.broadcast(data) | ||||||
|  |  | ||||||
| @@ -263,6 +306,9 @@ class ReplicateController(RPCHandler): | |||||||
|                 channel = await self.channel_manager.on_connect(websocket) |                 channel = await self.channel_manager.on_connect(websocket) | ||||||
|  |  | ||||||
|                 # Send initial data here |                 # Send initial data here | ||||||
|  |                 # Data is being broadcasted right away as soon as startup, | ||||||
|  |                 # we may not have to send initial data at all. Further testing | ||||||
|  |                 # required. | ||||||
|  |  | ||||||
|                 # Keep connection open until explicitly closed, and sleep |                 # Keep connection open until explicitly closed, and sleep | ||||||
|                 try: |                 try: | ||||||
| @@ -286,20 +332,15 @@ class ReplicateController(RPCHandler): | |||||||
|  |  | ||||||
|     # -------------------------------FOLLOWER LOGIC---------------------------- |     # -------------------------------FOLLOWER LOGIC---------------------------- | ||||||
|  |  | ||||||
|     def start_follower_mode(self): |  | ||||||
|         """ |  | ||||||
|         Start the ReplicateController in Follower mode |  | ||||||
|         """ |  | ||||||
|         logger.info("Starting rpc.replicate in Follower mode") |  | ||||||
|  |  | ||||||
|         self.submit_coroutine(self.follower_loop()) |  | ||||||
|  |  | ||||||
|     async def follower_loop(self): |     async def follower_loop(self): | ||||||
|         """ |         """ | ||||||
|         Main follower coroutine |         Main follower coroutine | ||||||
|  |  | ||||||
|         This starts all of the leader connection coros |         This starts all of the follower connection coros | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  |         logger.info("Starting rpc.replicate in Follower mode") | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             await self._connect_to_leaders() |             await self._connect_to_leaders() | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
| @@ -307,21 +348,26 @@ class ReplicateController(RPCHandler): | |||||||
|             logger.exception(e) |             logger.exception(e) | ||||||
|  |  | ||||||
|     async def _connect_to_leaders(self): |     async def _connect_to_leaders(self): | ||||||
|  |         """ | ||||||
|  |         For each leader in `self.leaders_list` create a connection and | ||||||
|  |         listen for data. | ||||||
|  |         """ | ||||||
|         rpc_lock = asyncio.Lock() |         rpc_lock = asyncio.Lock() | ||||||
|  |  | ||||||
|         logger.info("Starting connections to Leaders...") |         logger.info("Starting connections to Leaders...") | ||||||
|         await asyncio.wait( |  | ||||||
|             [ |         self.follower_tasks = [ | ||||||
|                 self._handle_leader_connection(leader, rpc_lock) |             self._loop.create_task(self._handle_leader_connection(leader, rpc_lock)) | ||||||
|             for leader in self.leaders_list |             for leader in self.leaders_list | ||||||
|         ] |         ] | ||||||
|         ) |         return await asyncio.gather(*self.follower_tasks, return_exceptions=True) | ||||||
|  |  | ||||||
|     async def _handle_leader_connection(self, leader, lock): |     async def _handle_leader_connection(self, leader, lock): | ||||||
|         """ |         """ | ||||||
|         Given a leader, connect and wait on data. If connection is lost, |         Given a leader, connect and wait on data. If connection is lost, | ||||||
|         it will attempt to reconnect. |         it will attempt to reconnect. | ||||||
|         """ |         """ | ||||||
|  |         try: | ||||||
|             url, token = leader["url"], leader["token"] |             url, token = leader["url"], leader["token"] | ||||||
|  |  | ||||||
|             websocket_url = f"{url}?token={token}" |             websocket_url = f"{url}?token={token}" | ||||||
| @@ -339,20 +385,22 @@ class ReplicateController(RPCHandler): | |||||||
|                                     timeout=self.reply_timeout |                                     timeout=self.reply_timeout | ||||||
|                                 ) |                                 ) | ||||||
|                             except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): |                             except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): | ||||||
|                             # We haven't received data yet. Just check the connection and continue. |                                 # We haven't received data yet. Check the connection and continue. | ||||||
|                                 try: |                                 try: | ||||||
|                                     # ping |                                     # ping | ||||||
|                                     ping = await channel.ping() |                                     ping = await channel.ping() | ||||||
|                                     await asyncio.wait_for(ping, timeout=self.ping_timeout) |                                     await asyncio.wait_for(ping, timeout=self.ping_timeout) | ||||||
|                                 logger.info(f"Connection to {url} still alive...") |                                     logger.debug(f"Connection to {url} still alive...") | ||||||
|                                     continue |                                     continue | ||||||
|                                 except Exception: |                                 except Exception: | ||||||
|                                 logger.info(f"Ping error {url} - retrying in {self.sleep_time}s") |                                     logger.info( | ||||||
|  |                                         f"Ping error {url} - retrying in {self.sleep_time}s") | ||||||
|                                     asyncio.sleep(self.sleep_time) |                                     asyncio.sleep(self.sleep_time) | ||||||
|                                     break |                                     break | ||||||
|  |  | ||||||
|                         with lock: |                             async with lock: | ||||||
|                             # Should we have a lock here? |                                 # Acquire lock so only 1 coro handling at a time | ||||||
|  |                                 # as we might call the RPC module in the main thread | ||||||
|                                 await self._handle_leader_message(data) |                                 await self._handle_leader_message(data) | ||||||
|  |  | ||||||
|                 except socket.gaierror: |                 except socket.gaierror: | ||||||
| @@ -364,22 +412,12 @@ class ReplicateController(RPCHandler): | |||||||
|                     await asyncio.sleep(self.sleep_time) |                     await asyncio.sleep(self.sleep_time) | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|     async def _handle_leader_message(self, message): |     async def _handle_leader_message(self, message): | ||||||
|         type = message.get("type") |         type = message.get('data_type') | ||||||
|  |         data = message.get('data') | ||||||
|  |  | ||||||
|         message_type_handlers = { |         if type == LeaderMessageType.whitelist: | ||||||
|             LeaderMessageType.Pairlist.value: self._handle_pairlist_message, |             logger.info(f"Received whitelist from Leader: {data}") | ||||||
|             LeaderMessageType.Dataframe.value: self._handle_dataframe_message |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         handler = message_type_handlers.get(type, self._handle_default_message) |  | ||||||
|         return await handler(message) |  | ||||||
|  |  | ||||||
|     async def _handle_default_message(self, message): |  | ||||||
|         logger.info(f"Default message handled: {message}") |  | ||||||
|  |  | ||||||
|     async def _handle_pairlist_message(self, message): |  | ||||||
|         logger.info(f"Pairlist message handled: {message}") |  | ||||||
|  |  | ||||||
|     async def _handle_dataframe_message(self, message): |  | ||||||
|         logger.info(f"Dataframe message handled: {message}") |  | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | import logging | ||||||
| from typing import Type | from typing import Type | ||||||
|  |  | ||||||
| from freqtrade.rpc.replicate.proxy import WebSocketProxy | from freqtrade.rpc.replicate.proxy import WebSocketProxy | ||||||
| @@ -5,6 +6,9 @@ from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocke | |||||||
| from freqtrade.rpc.replicate.types import WebSocketType | from freqtrade.rpc.replicate.types import WebSocketType | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebSocketChannel: | class WebSocketChannel: | ||||||
|     """ |     """ | ||||||
|     Object to help facilitate managing a websocket connection |     Object to help facilitate managing a websocket connection | ||||||
| @@ -85,9 +89,12 @@ class ChannelManager: | |||||||
|         """ |         """ | ||||||
|         if websocket in self.channels.keys(): |         if websocket in self.channels.keys(): | ||||||
|             channel = self.channels[websocket] |             channel = self.channels[websocket] | ||||||
|  |  | ||||||
|  |             logger.debug(f"Disconnecting channel - {channel}") | ||||||
|  |  | ||||||
|             if not channel.is_closed(): |             if not channel.is_closed(): | ||||||
|                 await channel.close() |                 await channel.close() | ||||||
|             del channel |             del self.channels[websocket] | ||||||
|  |  | ||||||
|     async def disconnect_all(self): |     async def disconnect_all(self): | ||||||
|         """ |         """ | ||||||
| @@ -102,5 +109,15 @@ class ChannelManager: | |||||||
|  |  | ||||||
|         :param data: The data to send |         :param data: The data to send | ||||||
|         """ |         """ | ||||||
|         for channel in self.channels.values(): |         for websocket, channel in self.channels.items(): | ||||||
|  |             try: | ||||||
|                 await channel.send(data) |                 await channel.send(data) | ||||||
|  |             except RuntimeError: | ||||||
|  |                 # Handle cannot send after close cases | ||||||
|  |                 await self.on_disconnect(websocket) | ||||||
|  |  | ||||||
|  |     def has_channels(self): | ||||||
|  |         """ | ||||||
|  |         Flag for more than 0 channels | ||||||
|  |         """ | ||||||
|  |         return len(self.channels) > 0 | ||||||
|   | |||||||
| @@ -1,11 +1,9 @@ | |||||||
| from typing import TYPE_CHECKING, Union | from typing import Union | ||||||
|  |  | ||||||
| from fastapi import WebSocket as FastAPIWebSocket | from fastapi import WebSocket as FastAPIWebSocket | ||||||
| from websockets import WebSocketClientProtocol as WebSocket | from websockets import WebSocketClientProtocol as WebSocket | ||||||
|  |  | ||||||
|  | from freqtrade.rpc.replicate.types import WebSocketType | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from freqtrade.rpc.replicate.types import WebSocketType |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebSocketProxy: | class WebSocketProxy: | ||||||
| @@ -21,6 +19,9 @@ class WebSocketProxy: | |||||||
|         """ |         """ | ||||||
|         Send data on the wrapped websocket |         Send data on the wrapped websocket | ||||||
|         """ |         """ | ||||||
|  |         if isinstance(data, str): | ||||||
|  |             data = data.encode() | ||||||
|  |  | ||||||
|         if hasattr(self._websocket, "send_bytes"): |         if hasattr(self._websocket, "send_bytes"): | ||||||
|             await self._websocket.send_bytes(data) |             await self._websocket.send_bytes(data) | ||||||
|         else: |         else: | ||||||
|   | |||||||
| @@ -33,10 +33,10 @@ class WebSocketSerializer(ABC): | |||||||
|  |  | ||||||
|  |  | ||||||
| class JSONWebSocketSerializer(WebSocketSerializer): | class JSONWebSocketSerializer(WebSocketSerializer): | ||||||
|     def _serialize(self, data: bytes) -> bytes: |     def _serialize(self, data): | ||||||
|         # json expects string not bytes |         # json expects string not bytes | ||||||
|         return json.dumps(data.decode()).encode() |         return json.dumps(data) | ||||||
|  |  | ||||||
|     def _deserialize(self, data: bytes) -> bytes: |     def _deserialize(self, data): | ||||||
|         # The WebSocketSerializer gives bytes not string |         # The WebSocketSerializer gives bytes not string | ||||||
|         return json.loads(data).encode() |         return json.loads(data) | ||||||
|   | |||||||
| @@ -3,7 +3,5 @@ from typing import TypeVar | |||||||
| from fastapi import WebSocket as FastAPIWebSocket | from fastapi import WebSocket as FastAPIWebSocket | ||||||
| from websockets import WebSocketClientProtocol as WebSocket | from websockets import WebSocketClientProtocol as WebSocket | ||||||
|  |  | ||||||
| from freqtrade.rpc.replicate.channel import WebSocketProxy |  | ||||||
|  |  | ||||||
|  | WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket) | ||||||
| WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket, WebSocketProxy) |  | ||||||
|   | |||||||
| @@ -59,6 +59,9 @@ class RPCManager: | |||||||
|                 replicate_rpc = ReplicateController(self._rpc, config, apiserver) |                 replicate_rpc = ReplicateController(self._rpc, config, apiserver) | ||||||
|                 self.registered_modules.append(replicate_rpc) |                 self.registered_modules.append(replicate_rpc) | ||||||
|  |  | ||||||
|  |                 # Attach the controller to FreqTrade | ||||||
|  |                 freqtrade.replicate_controller = replicate_rpc | ||||||
|  |  | ||||||
|             apiserver.start_api() |             apiserver.start_api() | ||||||
|  |  | ||||||
|     def cleanup(self) -> None: |     def cleanup(self) -> None: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user