"""
This module manages replicate mode communication
"""
import asyncio
import logging
import secrets
import socket
from threading import Event, Thread
from typing import Any, Coroutine, Dict, Union
from urllib.parse import quote

import websockets
from fastapi import Depends
from fastapi import WebSocket as FastAPIWebSocket
from fastapi import WebSocketDisconnect, status

from freqtrade.enums import ExternalSignalModeType, LeaderMessageType, RPCMessageType
from freqtrade.rpc import RPC, RPCHandler
from freqtrade.rpc.external_signal.channel import ChannelManager
from freqtrade.rpc.external_signal.thread_queue import Queue as ThreadedQueue
from freqtrade.rpc.external_signal.types import MessageType
from freqtrade.rpc.external_signal.utils import is_websocket_alive


logger = logging.getLogger(__name__)


class ExternalSignalController(RPCHandler):
    """ This class handles all websocket communication """

    def __init__(
        self,
        rpc: RPC,
        config: Dict[str, Any],
        api_server: Union[Any, None] = None
    ) -> None:
        """
        Init the ExternalSignalController class, and init the super class RPCHandler
        :param rpc: instance of RPC Helper class
        :param config: Configuration object
        :param api_server: The ApiServer object
        :return: None
        :raises RuntimeError: if no api_server is given
        :raises ValueError: in follower mode when no leaders are configured
        """
        super().__init__(rpc, config)

        self.freqtrade = rpc._freqtrade
        self.api_server = api_server

        if not self.api_server:
            raise RuntimeError("The API server must be enabled for external signals to work")

        self._loop = None
        self._running = False
        self._thread = None
        self._queue = None

        self._main_task = None
        self._sub_tasks = None

        self.channel_manager = ChannelManager()
        self.external_signal_config = config.get('external_signal', {})

        # What the config should look like
        # "external_signal": {
        #     "enabled": true,
        #     "mode": "follower",
        #     "leaders": [
        #         {
        #             "url": "ws://localhost:8080/signals/ws",
        #             "api_token": "test"
        #         }
        #     ]
        # }

        # "external_signal": {
        #     "enabled": true,
        #     "mode": "leader",
        #     "api_token": "test"
        # }

        self.mode = ExternalSignalModeType[
            self.external_signal_config.get('mode', 'leader').lower()
        ]

        self.leaders_list = self.external_signal_config.get('leaders', [])
        self.push_throttle_secs = self.external_signal_config.get('push_throttle_secs', 0.1)

        self.reply_timeout = self.external_signal_config.get('follower_reply_timeout', 10)
        self.ping_timeout = self.external_signal_config.get('follower_ping_timeout', 2)
        self.sleep_time = self.external_signal_config.get('follower_sleep_time', 5)

        if self.mode == ExternalSignalModeType.follower and len(self.leaders_list) == 0:
            raise ValueError("You must specify at least 1 leader in follower mode.")

        # This is only used by the leader, the followers use the tokens specified
        # in each of the leaders
        # If you do not specify an API key in the config, one will be randomly
        # generated and logged on startup
        default_api_key = secrets.token_urlsafe(16)
        self.secret_api_key = self.external_signal_config.get('api_token', default_api_key)

        self.start_threaded_loop()
        self.start()

    def is_leader(self):
        """
        Leader flag: True when external signals are enabled and mode is leader
        """
        return self.enabled() and self.mode == ExternalSignalModeType.leader

    def enabled(self):
        """
        Enabled flag, read from external_signal.enabled (defaults to False)
        """
        return self.external_signal_config.get('enabled', False)

    def start_threaded_loop(self):
        """
        Start the main internal loop in another thread to run coroutines.

        Blocks until the loop is actually running, so coroutines can be
        submitted immediately after this returns.

        :raises RuntimeError: if a loop thread was already started
        """
        # Check first so a second call neither leaks a fresh event loop nor
        # overwrites the reference to the loop that is already running.
        if self._thread:
            raise RuntimeError("A loop is already running")

        self._loop = asyncio.new_event_loop()

        # Callbacks queued before run_forever() only execute once the loop is
        # running, so when this event is set is_running() is already True.
        loop_started = Event()
        self._loop.call_soon_threadsafe(loop_started.set)

        self._thread = Thread(target=self._loop.run_forever)
        self._thread.start()
        loop_started.wait()
        self._running = True

    def submit_coroutine(self, coroutine: Coroutine):
        """
        Submit a coroutine to the threaded loop

        :return: a concurrent.futures.Future for the coroutine's result
        :raises RuntimeError: after shutdown or if the loop is not running
        """
        if not self._running:
            raise RuntimeError("Cannot schedule new futures after shutdown")

        if not self._loop or not self._loop.is_running():
            raise RuntimeError("Loop must be started before any function can"
                               " be submitted")

        return asyncio.run_coroutine_threadsafe(coroutine, self._loop)

    def start(self):
        """
        Start the controller main loop
        """
        self._main_task = self.submit_coroutine(self.main())

    async def shutdown(self):
        """
        Shutdown all tasks and close up
        """
        logger.info("Stopping rpc.externalsignalcontroller")

        # Flip running flag
        self._running = False

        # Cancel sub tasks. They may not exist yet if we were cancelled before
        # the mode coroutine had a chance to create them.
        for task in self._sub_tasks or []:
            task.cancel()

        # Then disconnect all channels
        await self.channel_manager.disconnect_all()

    def cleanup(self) -> None:
        """
        Cleanup pending module resources.

        Cancels the main task (which triggers shutdown() and then stops the
        loop) and waits for the loop thread to finish.
        """
        if self._thread:
            if self._loop.is_running():
                if self._main_task:
                    self._main_task.cancel()
                else:
                    # Nothing would stop the loop otherwise, and join() would hang
                    self._loop.call_soon_threadsafe(self._loop.stop)
            self._thread.join()

    async def main(self):
        """
        Main coro

        Start the loop based on what mode we're in
        """
        try:
            if self.mode == ExternalSignalModeType.leader:
                logger.info("Starting rpc.externalsignalcontroller in Leader mode")

                await self.run_leader_mode()
            elif self.mode == ExternalSignalModeType.follower:
                logger.info("Starting rpc.externalsignalcontroller in Follower mode")

                await self.run_follower_mode()

        except asyncio.CancelledError:
            # We're cancelled
            await self.shutdown()
        except Exception as e:
            # Log the error
            logger.error(f"Exception occurred in main task: {e}")
            logger.exception(e)
        finally:
            # This coroutine is the last thing to be ended, so it should stop the loop
            self._loop.stop()

    def log_api_token(self):
        """
        Log the API token
        """
        logger.info("-" * 15)
        logger.info(f"API_KEY: {self.secret_api_key}")
        logger.info("-" * 15)

    def send_msg(self, msg: MessageType) -> None:
        """
        Support RPC calls: forward EMIT_DATA messages to the followers
        """
        if msg["type"] == RPCMessageType.EMIT_DATA:
            message = msg.get("message")
            if message:
                self.send_message(message)
            else:
                logger.error(f"Message is empty! {msg}")

    def send_message(self, msg: MessageType) -> None:
        """
        Broadcast message over all channels if there are any
        """
        if self.channel_manager.has_channels():
            self._send_message(msg)
        else:
            logger.debug("No listening followers, skipping...")

    def _send_message(self, msg: MessageType):
        """
        Add data to the internal queue to be broadcasted. This func will block
        if the queue is full. This is meant to be called in the main thread.
        """
        if self._queue:
            queue = self._queue.sync_q
            queue.put(msg)  # This will block if the queue is full
        else:
            logger.warning("Can not send data, leader loop has not started yet!")

    async def send_initial_data(self, channel):
        """
        Send the current pairlist whitelist through a freshly connected channel
        """
        logger.info("Sending initial data through channel")

        # We first send pairlist data
        # We should move this to a func in the RPC object
        initial_data = {
            "data_type": LeaderMessageType.pairlist,
            "data": self.freqtrade.pairlists.whitelist
        }

        await channel.send(initial_data)

    async def _handle_leader_message(self, message: MessageType):
        """
        Handle message received from a Leader
        """
        data_type = message.get("data_type")
        data = message.get("data")

        self._rpc._handle_emitted_data(data_type, data)

    # ----------------------------------------------------------------------

    async def run_leader_mode(self):
        """
        Main leader coroutine

        This starts all of the leader coros and registers the endpoint on
        the ApiServer
        """
        self.register_leader_endpoint()
        self.log_api_token()

        self._sub_tasks = [
            self._loop.create_task(self._broadcast_queue_data())
        ]

        return await asyncio.gather(*self._sub_tasks)

    async def run_follower_mode(self):
        """
        Main follower coroutine

        This starts all of the follower connection coros
        """
        rpc_lock = asyncio.Lock()

        self._sub_tasks = [
            self._loop.create_task(self._handle_leader_connection(leader, rpc_lock))
            for leader in self.leaders_list
        ]

        return await asyncio.gather(*self._sub_tasks)

    async def _broadcast_queue_data(self):
        """
        Loop over queue data and broadcast it
        """
        # Instantiate the queue in this coroutine so it's attached to our loop
        self._queue = ThreadedQueue()
        async_queue = self._queue.async_q

        try:
            while self._running:
                # Get data from queue
                data = await async_queue.get()

                # Broadcast it to everyone
                await self.channel_manager.broadcast(data)

                # Sleep
                await asyncio.sleep(self.push_throttle_secs)
        except asyncio.CancelledError:
            # Silently stop
            pass

    async def get_api_token(
        self,
        websocket: FastAPIWebSocket,
        token: Union[str, None] = None
    ):
        """
        Extract the API key from query param. Must match the
        set secret_api_key or the websocket connection will be closed.

        :return: the token when accepted, None when the connection was rejected
        """
        if (
            isinstance(token, str)
            and isinstance(self.secret_api_key, str)
            # Constant-time comparison so the secret cannot be probed via timing
            and secrets.compare_digest(
                token.encode("utf-8"), self.secret_api_key.encode("utf-8")
            )
        ):
            return token

        logger.info("Denying websocket request...")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None

    def register_leader_endpoint(self, path: str = "/signals/ws"):
        """
        Attach and start the main leader loop to the ApiServer

        :param path: The endpoint path
        """
        if not self.api_server:
            raise RuntimeError("The leader needs the ApiServer to be active")

        # The endpoint function for running the main leader loop
        @self.api_server.app.websocket(path)
        async def leader_endpoint(
            websocket: FastAPIWebSocket,
            api_key: str = Depends(self.get_api_token)
        ):
            # get_api_token already closed the socket when the token was rejected
            if api_key is None:
                return
            await self.leader_endpoint_loop(websocket)

    async def leader_endpoint_loop(self, websocket: FastAPIWebSocket):
        """
        The WebSocket endpoint served by the ApiServer. This handles connections,
        and adding them to the channel manager.
        """
        try:
            if is_websocket_alive(websocket):
                logger.info(f"Follower connected - {websocket.client}")
                channel = await self.channel_manager.on_connect(websocket)

                # Send initial data here
                # Data is being broadcasted right away as soon as startup,
                # we may not have to send initial data at all. Further testing
                # required.
                await self.send_initial_data(channel)

                # Keep connection open until explicitly closed, and sleep
                try:
                    while not channel.is_closed():
                        request = await channel.recv()
                        logger.info(f"Follower request - {request}")

                except WebSocketDisconnect:
                    # Handle client disconnects
                    logger.info(f"Follower disconnected - {websocket.client}")
                    await self.channel_manager.on_disconnect(websocket)
                except Exception as e:
                    logger.info(f"Follower connection failed - {websocket.client}")
                    logger.exception(e)
                    # Handle cases like -
                    # RuntimeError('Cannot call "send" once a closed message has been sent')
                    await self.channel_manager.on_disconnect(websocket)

        except Exception as e:
            logger.error(f"Failed to serve - {websocket.client}")
            logger.exception(e)
            await self.channel_manager.on_disconnect(websocket)

    async def _handle_leader_connection(self, leader, lock):
        """
        Given a leader, connect and wait on data. If connection is lost,
        it will attempt to reconnect.

        :param leader: Leader config entry with "url" and "api_token" keys
        :param lock: asyncio.Lock shared by all leader connections so only one
                     message is handed to the RPC module at a time
        """
        try:
            url, token = leader["url"], leader["api_token"]
            # Percent-encode the token so characters like '&', '#' or spaces
            # cannot corrupt the query string
            websocket_url = f"{url}?token={quote(str(token), safe='')}"

            logger.info(f"Attempting to connect to Leader at: {url}")
            while True:
                try:
                    async with websockets.connect(websocket_url) as ws:
                        channel = await self.channel_manager.on_connect(ws)
                        logger.info(f"Connection to Leader at {url} successful")

                        await self._receive_leader_data(channel, url, lock)

                        # The connection is dead; drop the stale channel before
                        # reconnecting so channels don't accumulate
                        await self.channel_manager.on_disconnect(ws)

                except socket.gaierror:
                    logger.info(f"Socket error - retrying connection in {self.sleep_time}s")
                    await asyncio.sleep(self.sleep_time)
                    continue
                except ConnectionRefusedError:
                    logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
                    await asyncio.sleep(self.sleep_time)
                    continue
                except OSError as e:
                    # e.g. asyncio raises a plain OSError("Multiple exceptions: ...")
                    # when every resolved address of the host refuses the connection
                    logger.info(f"Connection error ({e}) - retrying connection "
                                f"in {self.sleep_time}s")
                    await asyncio.sleep(self.sleep_time)
                    continue

        except asyncio.CancelledError:
            pass

    async def _receive_leader_data(self, channel, url, lock):
        """
        Receive and handle leader messages until the connection is found dead.

        When nothing arrives within reply_timeout (or the connection is closed),
        the leader is pinged. If the ping fails, this sleeps sleep_time seconds
        and returns so the caller can reconnect.
        """
        while True:
            try:
                data = await asyncio.wait_for(
                    channel.recv(),
                    timeout=self.reply_timeout
                )
            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                # We haven't received data yet. Check the connection and continue.
                try:
                    # ping
                    ping = await channel.ping()
                    await asyncio.wait_for(ping, timeout=self.ping_timeout)
                    logger.debug(f"Connection to {url} still alive...")
                    continue
                except Exception:
                    logger.info(
                        f"Ping error {url} - retrying in {self.sleep_time}s")
                    # Must be awaited: a bare asyncio.sleep() call only builds a
                    # coroutine object and the reconnect would spin without delay
                    await asyncio.sleep(self.sleep_time)
                    return

            async with lock:
                # Acquire lock so only 1 coro handling at a time
                # as we might call the RPC module in the main thread
                await self._handle_leader_message(data)