From 9f6bba40af1a407f190a89f5c0c8b4e3f528ba46 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Thu, 18 Aug 2022 10:39:20 -0600 Subject: [PATCH] initial concept for replicate, basic leader and follower logic --- .gitignore | 2 + freqtrade/__init__.py | 2 +- freqtrade/constants.py | 28 + freqtrade/enums/__init__.py | 1 + freqtrade/enums/replicate.py | 11 + freqtrade/enums/rpcmessagetype.py | 2 + freqtrade/rpc/api_server/webserver.py | 8 +- freqtrade/rpc/replicate/__init__.py | 385 ++++++++++++++ freqtrade/rpc/replicate/channel.py | 106 ++++ freqtrade/rpc/replicate/proxy.py | 60 +++ freqtrade/rpc/replicate/serializer.py | 42 ++ freqtrade/rpc/replicate/thread_queue.py | 650 ++++++++++++++++++++++++ freqtrade/rpc/replicate/types.py | 9 + freqtrade/rpc/replicate/utils.py | 10 + freqtrade/rpc/rpc_manager.py | 13 + requirements-replicate.txt | 5 + 16 files changed, 1330 insertions(+), 4 deletions(-) create mode 100644 freqtrade/enums/replicate.py create mode 100644 freqtrade/rpc/replicate/__init__.py create mode 100644 freqtrade/rpc/replicate/channel.py create mode 100644 freqtrade/rpc/replicate/proxy.py create mode 100644 freqtrade/rpc/replicate/serializer.py create mode 100644 freqtrade/rpc/replicate/thread_queue.py create mode 100644 freqtrade/rpc/replicate/types.py create mode 100644 freqtrade/rpc/replicate/utils.py create mode 100644 requirements-replicate.txt diff --git a/.gitignore b/.gitignore index e400c01f5..df2121990 100644 --- a/.gitignore +++ b/.gitignore @@ -113,3 +113,5 @@ target/ !config_examples/config_full.example.json !config_examples/config_kraken.example.json !config_examples/config_freqai.example.json + +*-config.json diff --git a/freqtrade/__init__.py b/freqtrade/__init__.py index 2572c03f1..9e022b2d9 100644 --- a/freqtrade/__init__.py +++ b/freqtrade/__init__.py @@ -1,5 +1,5 @@ """ Freqtrade bot """ -__version__ = '2022.8.dev' +__version__ = '2022.8.1+pubsub' # Metadata 1.2 mandates PEP 440 version, but 'develop' is not if 'dev' in __version__: try: diff --git a/freqtrade/constants.py b/freqtrade/constants.py index ddbc84fa9..416b4646f 100644 --- a/freqtrade/constants.py +++ b/freqtrade/constants.py @@ -60,6 +60,8 @@ USERPATH_FREQAIMODELS = 'freqaimodels' TELEGRAM_SETTING_OPTIONS = ['on', 'off', 'silent'] WEBHOOK_FORMAT_OPTIONS = ['form', 'json', 'raw'] +FOLLOWER_MODE_OPTIONS = ['follower', 'leader'] + ENV_VAR_PREFIX = 'FREQTRADE__' NON_OPEN_EXCHANGE_STATES = ('cancelled', 'canceled', 'closed', 'expired') @@ -242,6 +244,7 @@ CONF_SCHEMA = { 'exchange': {'$ref': '#/definitions/exchange'}, 'edge': {'$ref': '#/definitions/edge'}, 'freqai': {'$ref': '#/definitions/freqai'}, + 'replicate': {'$ref': '#/definitions/replicate'}, 'experimental': { 'type': 'object', 'properties': { @@ -483,6 +486,31 @@ CONF_SCHEMA = { }, 'required': ['process_throttle_secs', 'allowed_risk'] }, + 'replicate': { + 'type': 'object', + 'properties': { + 'enabled': {'type': 'boolean', 'default': False}, + 'mode': { + 'type': 'string', + 'enum': FOLLOWER_MODE_OPTIONS + }, + 'api_key': {'type': 'string', 'default': ''}, + 'leaders': { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'url': {'type': 'string', 'default': ''}, + 'token': {'type': 'string', 'default': ''}, + } + } + }, + 'follower_reply_timeout': {'type': 'integer'}, + 'follower_sleep_time': {'type': 'integer'}, + 'follower_ping_timeout': {'type': 'integer'}, + }, + 'required': ['mode'] + }, "freqai": { "type": "object", "properties": { diff --git a/freqtrade/enums/__init__.py b/freqtrade/enums/__init__.py index e50ebc4a4..e1057208a 100644 --- a/freqtrade/enums/__init__.py +++ b/freqtrade/enums/__init__.py @@ -5,6 +5,7 @@ from freqtrade.enums.exitchecktuple import ExitCheckTuple from freqtrade.enums.exittype import ExitType from freqtrade.enums.marginmode import MarginMode from freqtrade.enums.ordertypevalue import OrderTypeValues +from freqtrade.enums.replicate import LeaderMessageType, ReplicateModeType from freqtrade.enums.rpcmessagetype import RPCMessageType from freqtrade.enums.runmode import NON_UTIL_MODES, OPTIMIZE_MODES, TRADING_MODES, RunMode from freqtrade.enums.signaltype import SignalDirection, SignalTagType, SignalType diff --git a/freqtrade/enums/replicate.py b/freqtrade/enums/replicate.py new file mode 100644 index 000000000..d55d45b45 --- /dev/null +++ b/freqtrade/enums/replicate.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class ReplicateModeType(str, Enum): + leader = "leader" + follower = "follower" + + +class LeaderMessageType(str, Enum): + Pairlist = "pairlist" + Dataframe = "dataframe" diff --git a/freqtrade/enums/rpcmessagetype.py b/freqtrade/enums/rpcmessagetype.py index 415d8f18c..d5b3ce89c 100644 --- a/freqtrade/enums/rpcmessagetype.py +++ b/freqtrade/enums/rpcmessagetype.py @@ -19,6 +19,8 @@ class RPCMessageType(Enum): STRATEGY_MSG = 'strategy_msg' + EMIT_DATA = 'emit_data' + def __repr__(self): return self.value diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 0da129583..c98fb9fd4 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -54,7 +54,11 @@ class ApiServer(RPCHandler): ApiServer.__initialized = False return ApiServer.__instance - def __init__(self, config: Dict[str, Any], standalone: bool = False) -> None: + def __init__( + self, + config: Dict[str, Any], + standalone: bool = False, + ) -> None: ApiServer._config = config if self.__initialized and (standalone or self._standalone): return @@ -71,8 +75,6 @@ class ApiServer(RPCHandler): ) self.configure_app(self.app, self._config) - self.start_api() - def add_rpc_handler(self, rpc: RPC): """ Attach rpc handler diff --git a/freqtrade/rpc/replicate/__init__.py b/freqtrade/rpc/replicate/__init__.py new file mode 100644 index 000000000..d725a4a90 --- /dev/null +++ b/freqtrade/rpc/replicate/__init__.py @@ -0,0 +1,385 @@ +""" +This module manages replicate mode communication +""" +import asyncio +import logging +import secrets +import socket +from threading import Thread +from typing import Any, Coroutine, Dict, Union + +import websockets +from fastapi import Depends +from fastapi import WebSocket as FastAPIWebSocket +from fastapi import WebSocketDisconnect, status + +from freqtrade.enums import LeaderMessageType, ReplicateModeType, RPCMessageType +from freqtrade.rpc import RPC, RPCHandler +from freqtrade.rpc.replicate.channel import ChannelManager +from freqtrade.rpc.replicate.thread_queue import Queue as ThreadedQueue +from freqtrade.rpc.replicate.utils import is_websocket_alive + + +logger = logging.getLogger(__name__) + + +class ReplicateController(RPCHandler): + """ This class handles all websocket communication """ + + def __init__( + self, + rpc: RPC, + config: Dict[str, Any], + api_server: Union[Any, None] = None + ) -> None: + """ + Init the ReplicateRPC class, and init the super class RPCHandler + :param rpc: instance of RPC Helper class + :param config: Configuration object + :return: None + """ + super().__init__(rpc, config) + + self.api_server = api_server + + if not self.api_server: + raise RuntimeError("The API server must be enabled for replicate to work") + + self._loop = None + self._running = False + self._thread = None + self._queue = None + + self.channel_manager = ChannelManager() + + self.replicate_config = config.get('replicate', {}) + + # What the config should look like + # "replicate": { + # "enabled": true, + # "mode": "follower", + # "leaders": [ + # { + # "url": "ws://localhost:8080/replicate/ws", + # "token": "test" + # } + # ] + # } + + # "replicate": { + # "enabled": true, + # "mode": "leader", + # "api_key": "test" + # } + + self.mode = ReplicateModeType[self.replicate_config.get('mode', 'leader').lower()] + + self.leaders_list = self.replicate_config.get('leaders', []) + self.push_throttle_secs = self.replicate_config.get('push_throttle_secs', 1) + + self.reply_timeout = self.replicate_config.get('follower_reply_timeout', 10) + self.ping_timeout = self.replicate_config.get('follower_ping_timeout', 2) + self.sleep_time = self.replicate_config.get('follower_sleep_time', 1) + + if self.mode == ReplicateModeType.follower and len(self.leaders_list) == 0: + raise ValueError("You must specify at least 1 leader in follower mode.") + + # This is only used by the leader, the followers use the tokens specified + # in each of the leaders + # If you do not specify an API key in the config, one will be randomly + # generated and logged on startup + default_api_key = secrets.token_urlsafe(16) + self.secret_api_key = self.replicate_config.get('api_key', default_api_key) + + self.start_threaded_loop() + + if self.mode == ReplicateModeType.follower: + self.start_follower_mode() + elif self.mode == ReplicateModeType.leader: + self.start_leader_mode() + + def start_threaded_loop(self): + """ + Start the main internal loop in another thread to run coroutines + """ + self._loop = asyncio.new_event_loop() + + if not self._thread: + self._thread = Thread(target=self._loop.run_forever) + self._thread.start() + self._running = True + else: + raise RuntimeError("A loop is already running") + + def submit_coroutine(self, coroutine: Coroutine): + """ + Submit a coroutine to the threaded loop + """ + if not self._running: + raise RuntimeError("Cannot schedule new futures after shutdown") + + if not self._loop or not self._loop.is_running(): + raise RuntimeError("Loop must be started before any function can" + " be submitted") + + logger.debug(f"Running coroutine {repr(coroutine)} in loop") + try: + return asyncio.run_coroutine_threadsafe(coroutine, self._loop) + except Exception as e: + logger.error(f"Error running coroutine - {str(e)}") + return None + + def cleanup(self) -> None: + """ + Cleanup pending module resources. + """ + if self._thread: + if self._loop.is_running(): + + self._running = False + + # Tell all coroutines submitted to the loop they're cancelled + pending = asyncio.all_tasks(loop=self._loop) + for task in pending: + task.cancel() + + self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all) + # This must be called threadsafe, otherwise would not work + self._loop.call_soon_threadsafe(self._loop.stop) + + self._thread.join() + + def send_msg(self, msg: Dict[str, Any]) -> None: + """ Push message through """ + + if msg["type"] == RPCMessageType.EMIT_DATA: + self._send_message( + { + "type": msg["data_type"], + "content": msg["data"] + } + ) + + # ----------------------- LEADER LOGIC ------------------------------ + + def start_leader_mode(self): + """ + Register the endpoint and start the leader loop + """ + + logger.info("Running rpc.replicate in Leader mode") + logger.info("-" * 15) + logger.info(f"API_KEY: {self.secret_api_key}") + logger.info("-" * 15) + + self.register_leader_endpoint() + self.submit_coroutine(self.leader_loop()) + + async def leader_loop(self): + """ + Main leader coroutine + At the moment this just broadcasts data that's in the queue to the followers + """ + try: + await self._broadcast_queue_data() + except Exception as e: + logger.error("Exception occurred in leader loop: ") + logger.exception(e) + + def _send_message(self, data: Dict[Any, Any]): + """ + Add data to the internal queue to be broadcasted. This func will block + if the queue is full. This is meant to be called in the main thread. + """ + + if self._queue: + self._queue.put(data) + else: + logger.warning("Can not send data, leader loop has not started yet!") + + async def _broadcast_queue_data(self): + """ + Loop over queue data and broadcast it + """ + # Instantiate the queue in this coroutine so it's attached to our loop + self._queue = ThreadedQueue() + async_queue = self._queue.async_q + + try: + while self._running: + # Get data from queue + data = await async_queue.get() + + # Broadcast it to everyone + await self.channel_manager.broadcast(data) + + # Sleep + await asyncio.sleep(self.push_throttle_secs) + except asyncio.CancelledError: + # Silently stop + pass + + async def get_api_token( + self, + websocket: FastAPIWebSocket, + token: Union[str, None] = None + ): + """ + Extract the API key from query param. Must match the + set secret_api_key or the websocket connection will be closed. + """ + if token == self.secret_api_key: + return token + else: + logger.info("Denying websocket request...") + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + + def register_leader_endpoint(self, path: str = "/replicate/ws"): + """ + Attach and start the main leader loop to the ApiServer + + :param path: The endpoint path + """ + + if not self.api_server: + raise RuntimeError("The leader needs the ApiServer to be active") + + # The endpoint function for running the main leader loop + @self.api_server.app.websocket(path) + async def leader_endpoint( + websocket: FastAPIWebSocket, + api_key: str = Depends(self.get_api_token) + ): + await self.leader_endpoint_loop(websocket) + + async def leader_endpoint_loop(self, websocket: FastAPIWebSocket): + """ + The WebSocket endpoint served by the ApiServer. This handles connections, + and adding them to the channel manager. + """ + try: + if is_websocket_alive(websocket): + logger.info(f"Follower connected - {websocket.client}") + channel = await self.channel_manager.on_connect(websocket) + + # Send initial data here + + # Keep connection open until explicitly closed, and sleep + try: + while not channel.is_closed(): + await channel.recv() + + except WebSocketDisconnect: + # Handle client disconnects + logger.info(f"Follower disconnected - {websocket.client}") + await self.channel_manager.on_disconnect(websocket) + except Exception as e: + logger.info(f"Follower connection failed - {websocket.client}") + logger.exception(e) + # Handle cases like - + # RuntimeError('Cannot call "send" once a closed message has been sent') + await self.channel_manager.on_disconnect(websocket) + + except Exception: + logger.error(f"Failed to serve - {websocket.client}") + await self.channel_manager.on_disconnect(websocket) + + # -------------------------------FOLLOWER LOGIC---------------------------- + + def start_follower_mode(self): + """ + Start the ReplicateController in Follower mode + """ + logger.info("Starting rpc.replicate in Follower mode") + + self.submit_coroutine(self.follower_loop()) + + async def follower_loop(self): + """ + Main follower coroutine + + This starts all of the leader connection coros + """ + try: + await self._connect_to_leaders() + except Exception as e: + logger.error("Exception occurred in follower loop: ") + logger.exception(e) + + async def _connect_to_leaders(self): + rpc_lock = asyncio.Lock() + + logger.info("Starting connections to Leaders...") + await asyncio.wait( + [ + self._handle_leader_connection(leader, rpc_lock) + for leader in self.leaders_list + ] + ) + + async def _handle_leader_connection(self, leader, lock): + """ + Given a leader, connect and wait on data. If connection is lost, + it will attempt to reconnect. + """ + url, token = leader["url"], leader["token"] + + websocket_url = f"{url}?token={token}" + + logger.info(f"Attempting to connect to leader at: {url}") + # TODO: limit the amount of connection retries + while True: + try: + async with websockets.connect(websocket_url) as ws: + channel = await self.channel_manager.on_connect(ws) + while True: + try: + data = await asyncio.wait_for( + channel.recv(), + timeout=self.reply_timeout + ) + except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + # We haven't received data yet. Just check the connection and continue. + try: + # ping + ping = await channel.ping() + await asyncio.wait_for(ping, timeout=self.ping_timeout) + logger.info(f"Connection to {url} still alive...") + continue + except Exception: + logger.info(f"Ping error {url} - retrying in {self.sleep_time}s") + asyncio.sleep(self.sleep_time) + break + + with lock: + # Should we have a lock here? + await self._handle_leader_message(data) + + except socket.gaierror: + logger.info(f"Socket error - retrying connection in {self.sleep_time}s") + await asyncio.sleep(self.sleep_time) + continue + except ConnectionRefusedError: + logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s") + await asyncio.sleep(self.sleep_time) + continue + + async def _handle_leader_message(self, message): + type = message.get("type") + + message_type_handlers = { + LeaderMessageType.Pairlist.value: self._handle_pairlist_message, + LeaderMessageType.Dataframe.value: self._handle_dataframe_message + } + + handler = message_type_handlers.get(type, self._handle_default_message) + return await handler(message) + + async def _handle_default_message(self, message): + logger.info(f"Default message handled: {message}") + + async def _handle_pairlist_message(self, message): + logger.info(f"Pairlist message handled: {message}") + + async def _handle_dataframe_message(self, message): + logger.info(f"Dataframe message handled: {message}") diff --git a/freqtrade/rpc/replicate/channel.py b/freqtrade/rpc/replicate/channel.py new file mode 100644 index 000000000..9950742da --- /dev/null +++ b/freqtrade/rpc/replicate/channel.py @@ -0,0 +1,106 @@ +from typing import Type + +from freqtrade.rpc.replicate.proxy import WebSocketProxy +from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocketSerializer +from freqtrade.rpc.replicate.types import WebSocketType + + +class WebSocketChannel: + """ + Object to help facilitate managing a websocket connection + """ + + def __init__( + self, + websocket: WebSocketType, + serializer_cls: Type[WebSocketSerializer] = JSONWebSocketSerializer + ): + # The WebSocket object + self._websocket = WebSocketProxy(websocket) + # The Serializing class for the WebSocket object + self._serializer_cls = serializer_cls + + # Internal event to signify a closed websocket + self._closed = False + + # Wrap the WebSocket in the Serializing class + self._wrapped_ws = self._serializer_cls(self._websocket) + + async def send(self, data): + """ + Send data on the wrapped websocket + """ + await self._wrapped_ws.send(data) + + async def recv(self): + """ + Receive data on the wrapped websocket + """ + return await self._wrapped_ws.recv() + + async def ping(self): + """ + Ping the websocket + """ + return await self._websocket.ping() + + async def close(self): + """ + Close the WebSocketChannel + """ + + self._closed = True + + def is_closed(self): + return self._closed + + +class ChannelManager: + def __init__(self): + self.channels = dict() + + async def on_connect(self, websocket: WebSocketType): + """ + Wrap websocket connection into Channel and add to list + + :param websocket: The WebSocket object to attach to the Channel + """ + if hasattr(websocket, "accept"): + try: + await websocket.accept() + except RuntimeError: + # The connection was closed before we could accept it + return + + ws_channel = WebSocketChannel(websocket) + self.channels[websocket] = ws_channel + + return ws_channel + + async def on_disconnect(self, websocket: WebSocketType): + """ + Call close on the channel if it's not, and remove from channel list + + :param websocket: The WebSocket objet attached to the Channel + """ + if websocket in self.channels.keys(): + channel = self.channels[websocket] + if not channel.is_closed(): + await channel.close() + del channel + + async def disconnect_all(self): + """ + Disconnect all Channels + """ + for websocket in self.channels.keys(): + await self.on_disconnect(websocket) + + async def broadcast(self, data): + """ + Broadcast data on all Channels + + :param data: The data to send + """ + for channel in self.channels.values(): + await channel.send(data) diff --git a/freqtrade/rpc/replicate/proxy.py b/freqtrade/rpc/replicate/proxy.py new file mode 100644 index 000000000..b2173670b --- /dev/null +++ b/freqtrade/rpc/replicate/proxy.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING, Union + +from fastapi import WebSocket as FastAPIWebSocket +from websockets import WebSocketClientProtocol as WebSocket + + +if TYPE_CHECKING: + from freqtrade.rpc.replicate.types import WebSocketType + + +class WebSocketProxy: + """ + WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol + under the same API + """ + + def __init__(self, websocket: WebSocketType): + self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket + + async def send(self, data): + """ + Send data on the wrapped websocket + """ + if hasattr(self._websocket, "send_bytes"): + await self._websocket.send_bytes(data) + else: + await self._websocket.send(data) + + async def recv(self): + """ + Receive data on the wrapped websocket + """ + if hasattr(self._websocket, "receive_bytes"): + return await self._websocket.receive_bytes() + else: + return await self._websocket.recv() + + async def ping(self): + """ + Ping the websocket, not supported by FastAPI WebSockets + """ + if hasattr(self._websocket, "ping"): + return await self._websocket.ping() + return False + + async def close(self, code: int = 1000): + """ + Close the websocket connection, only supported by FastAPI WebSockets + """ + if hasattr(self._websocket, "close"): + return await self._websocket.close(code) + pass + + async def accept(self): + """ + Accept the WebSocket connection, only support by FastAPI WebSockets + """ + if hasattr(self._websocket, "accept"): + return await self._websocket.accept() + pass diff --git a/freqtrade/rpc/replicate/serializer.py b/freqtrade/rpc/replicate/serializer.py new file mode 100644 index 000000000..ae5e57b95 --- /dev/null +++ b/freqtrade/rpc/replicate/serializer.py @@ -0,0 +1,42 @@ +import json +from abc import ABC, abstractmethod + +from freqtrade.rpc.replicate.proxy import WebSocketProxy + + +class WebSocketSerializer(ABC): + def __init__(self, websocket: WebSocketProxy): + self._websocket: WebSocketProxy = websocket + + @abstractmethod + def _serialize(self, data): + raise NotImplementedError() + + @abstractmethod + def _deserialize(self, data): + raise NotImplementedError() + + async def send(self, data: bytes): + await self._websocket.send(self._serialize(data)) + + async def recv(self) -> bytes: + data = await self._websocket.recv() + + return self._deserialize(data) + + async def close(self, code: int = 1000): + await self._websocket.close(code) + +# Going to explore using MsgPack as the serialization, +# as that might be the best method for sending pandas +# dataframes over the wire + + +class JSONWebSocketSerializer(WebSocketSerializer): + def _serialize(self, data: bytes) -> bytes: + # json expects string not bytes + return json.dumps(data.decode()).encode() + + def _deserialize(self, data: bytes) -> bytes: + # The WebSocketSerializer gives bytes not string + return json.loads(data).encode() diff --git a/freqtrade/rpc/replicate/thread_queue.py b/freqtrade/rpc/replicate/thread_queue.py new file mode 100644 index 000000000..88321321b --- /dev/null +++ b/freqtrade/rpc/replicate/thread_queue.py @@ -0,0 +1,650 @@ +import asyncio +import sys +import threading +from asyncio import QueueEmpty as AsyncQueueEmpty +from asyncio import QueueFull as AsyncQueueFull +from collections import deque +from heapq import heappop, heappush +from queue import Empty as SyncQueueEmpty +from queue import Full as SyncQueueFull +from typing import Any, Callable, Deque, Generic, List, Optional, Set, TypeVar + +from typing_extensions import Protocol + + +__version__ = "1.0.0" +__all__ = ( + "Queue", + "PriorityQueue", + "LifoQueue", + "SyncQueue", + "AsyncQueue", + "BaseQueue", +) + + +T = TypeVar("T") +OptFloat = Optional[float] + + +class BaseQueue(Protocol[T]): + @property + def maxsize(self) -> int: + ... + + @property + def closed(self) -> bool: + ... + + def task_done(self) -> None: + ... + + def qsize(self) -> int: + ... + + @property + def unfinished_tasks(self) -> int: + ... + + def empty(self) -> bool: + ... + + def full(self) -> bool: + ... + + def put_nowait(self, item: T) -> None: + ... + + def get_nowait(self) -> T: + ... + + +class SyncQueue(BaseQueue[T], Protocol[T]): + @property + def maxsize(self) -> int: + ... + + @property + def closed(self) -> bool: + ... + + def task_done(self) -> None: + ... + + def qsize(self) -> int: + ... + + @property + def unfinished_tasks(self) -> int: + ... + + def empty(self) -> bool: + ... + + def full(self) -> bool: + ... + + def put_nowait(self, item: T) -> None: + ... + + def get_nowait(self) -> T: + ... + + def put(self, item: T, block: bool = True, timeout: OptFloat = None) -> None: + ... + + def get(self, block: bool = True, timeout: OptFloat = None) -> T: + ... + + def join(self) -> None: + ... + + +class AsyncQueue(BaseQueue[T], Protocol[T]): + async def put(self, item: T) -> None: + ... + + async def get(self) -> T: + ... + + async def join(self) -> None: + ... + + +class Queue(Generic[T]): + def __init__(self, maxsize: int = 0) -> None: + self._loop = asyncio.get_running_loop() + self._maxsize = maxsize + + self._init(maxsize) + + self._unfinished_tasks = 0 + + self._sync_mutex = threading.Lock() + self._sync_not_empty = threading.Condition(self._sync_mutex) + self._sync_not_full = threading.Condition(self._sync_mutex) + self._all_tasks_done = threading.Condition(self._sync_mutex) + + self._async_mutex = asyncio.Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug, see #358: + getattr(self._async_mutex, "_get_loop", lambda: None)() + self._async_not_empty = asyncio.Condition(self._async_mutex) + self._async_not_full = asyncio.Condition(self._async_mutex) + self._finished = asyncio.Event() + self._finished.set() + + self._closing = False + self._pending = set() # type: Set[asyncio.Future[Any]] + + def checked_call_soon_threadsafe( + callback: Callable[..., None], *args: Any + ) -> None: + try: + self._loop.call_soon_threadsafe(callback, *args) + except RuntimeError: + # swallowing agreed in #2 + pass + + self._call_soon_threadsafe = checked_call_soon_threadsafe + + def checked_call_soon(callback: Callable[..., None], *args: Any) -> None: + if not self._loop.is_closed(): + self._loop.call_soon(callback, *args) + + self._call_soon = checked_call_soon + + self._sync_queue = _SyncQueueProxy(self) + self._async_queue = _AsyncQueueProxy(self) + + def close(self) -> None: + with self._sync_mutex: + self._closing = True + for fut in self._pending: + fut.cancel() + self._finished.set() # unblocks all async_q.join() + self._all_tasks_done.notify_all() # unblocks all sync_q.join() + + async def wait_closed(self) -> None: + # should be called from loop after close(). + # Nobody should put/get at this point, + # so lock acquiring is not required + if not self._closing: + raise RuntimeError("Waiting for non-closed queue") + # give execution chances for the task-done callbacks + # of async tasks created inside + # _notify_async_not_empty, _notify_async_not_full + # methods. + await asyncio.sleep(0) + if not self._pending: + return + await asyncio.wait(self._pending) + + @property + def closed(self) -> bool: + return self._closing and not self._pending + + @property + def maxsize(self) -> int: + return self._maxsize + + @property + def sync_q(self) -> "_SyncQueueProxy[T]": + return self._sync_queue + + @property + def async_q(self) -> "_AsyncQueueProxy[T]": + return self._async_queue + + # Override these methods to implement other queue organizations + # (e.g. stack or priority queue). + # These will only be called with appropriate locks held + + def _init(self, maxsize: int) -> None: + self._queue = deque() # type: Deque[T] + + def _qsize(self) -> int: + return len(self._queue) + + # Put a new item in the queue + def _put(self, item: T) -> None: + self._queue.append(item) + + # Get an item from the queue + def _get(self) -> T: + return self._queue.popleft() + + def _put_internal(self, item: T) -> None: + self._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def _notify_sync_not_empty(self) -> None: + def f() -> None: + with self._sync_mutex: + self._sync_not_empty.notify() + + self._loop.run_in_executor(None, f) + + def _notify_sync_not_full(self) -> None: + def f() -> None: + with self._sync_mutex: + self._sync_not_full.notify() + + fut = asyncio.ensure_future(self._loop.run_in_executor(None, f)) + fut.add_done_callback(self._pending.discard) + self._pending.add(fut) + + def _notify_async_not_empty(self, *, threadsafe: bool) -> None: + async def f() -> None: + async with self._async_mutex: + self._async_not_empty.notify() + + def task_maker() -> None: + task = self._loop.create_task(f()) + task.add_done_callback(self._pending.discard) + self._pending.add(task) + + if threadsafe: + self._call_soon_threadsafe(task_maker) + else: + self._call_soon(task_maker) + + def _notify_async_not_full(self, *, threadsafe: bool) -> None: + async def f() -> None: + async with self._async_mutex: + self._async_not_full.notify() + + def task_maker() -> None: + task = self._loop.create_task(f()) + task.add_done_callback(self._pending.discard) + self._pending.add(task) + + if threadsafe: + self._call_soon_threadsafe(task_maker) + else: + self._call_soon(task_maker) + + def _check_closing(self) -> None: + if self._closing: + raise RuntimeError("Operation on the closed queue is forbidden") + + +class _SyncQueueProxy(SyncQueue[T]): + """Create a queue object with a given maximum size. + + If maxsize is <= 0, the queue size is infinite. + """ + + def __init__(self, parent: Queue[T]): + self._parent = parent + + @property + def maxsize(self) -> int: + return self._parent._maxsize + + @property + def closed(self) -> bool: + return self._parent.closed + + def task_done(self) -> None: + """Indicate that a formerly enqueued task is complete. + + Used by Queue consumer threads. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items + have been processed (meaning that a task_done() call was received + for every item that had been put() into the queue). + + Raises a ValueError if called more times than there were items + placed in the queue. + """ + self._parent._check_closing() + with self._parent._all_tasks_done: + unfinished = self._parent._unfinished_tasks - 1 + if unfinished <= 0: + if unfinished < 0: + raise ValueError("task_done() called too many times") + self._parent._all_tasks_done.notify_all() + self._parent._loop.call_soon_threadsafe(self._parent._finished.set) + self._parent._unfinished_tasks = unfinished + + def join(self) -> None: + """Blocks until all items in the Queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate the item was retrieved and all work on it is complete. + + When the count of unfinished tasks drops to zero, join() unblocks. + """ + self._parent._check_closing() + with self._parent._all_tasks_done: + while self._parent._unfinished_tasks: + self._parent._all_tasks_done.wait() + self._parent._check_closing() + + def qsize(self) -> int: + """Return the approximate size of the queue (not reliable!).""" + return self._parent._qsize() + + @property + def unfinished_tasks(self) -> int: + """Return the number of unfinished tasks.""" + return self._parent._unfinished_tasks + + def empty(self) -> bool: + """Return True if the queue is empty, False otherwise (not reliable!). + + This method is likely to be removed at some point. Use qsize() == 0 + as a direct substitute, but be aware that either approach risks a race + condition where a queue can grow before the result of empty() or + qsize() can be used. + + To create code that needs to wait for all queued tasks to be + completed, the preferred technique is to use the join() method. + """ + return not self._parent._qsize() + + def full(self) -> bool: + """Return True if the queue is full, False otherwise (not reliable!). + + This method is likely to be removed at some point. Use qsize() >= n + as a direct substitute, but be aware that either approach risks a race + condition where a queue can shrink before the result of full() or + qsize() can be used. + """ + return 0 < self._parent._maxsize <= self._parent._qsize() + + def put(self, item: T, block: bool = True, timeout: OptFloat = None) -> None: + """Put an item into the queue. + + If optional args 'block' is true and 'timeout' is None (the default), + block if necessary until a free slot is available. If 'timeout' is + a non-negative number, it blocks at most 'timeout' seconds and raises + the Full exception if no free slot was available within that time. + Otherwise ('block' is false), put an item on the queue if a free slot + is immediately available, else raise the Full exception ('timeout' + is ignored in that case). + """ + self._parent._check_closing() + with self._parent._sync_not_full: + if self._parent._maxsize > 0: + if not block: + if self._parent._qsize() >= self._parent._maxsize: + raise SyncQueueFull + elif timeout is None: + while self._parent._qsize() >= self._parent._maxsize: + self._parent._sync_not_full.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + time = self._parent._loop.time + endtime = time() + timeout + while self._parent._qsize() >= self._parent._maxsize: + remaining = endtime - time() + if remaining <= 0.0: + raise SyncQueueFull + self._parent._sync_not_full.wait(remaining) + self._parent._put_internal(item) + self._parent._sync_not_empty.notify() + self._parent._notify_async_not_empty(threadsafe=True) + + def get(self, block: bool = True, timeout: OptFloat = None) -> T: + """Remove and return an item from the queue. + + If optional args 'block' is true and 'timeout' is None (the default), + block if necessary until an item is available. If 'timeout' is + a non-negative number, it blocks at most 'timeout' seconds and raises + the Empty exception if no item was available within that time. + Otherwise ('block' is false), return an item if one is immediately + available, else raise the Empty exception ('timeout' is ignored + in that case). + """ + self._parent._check_closing() + with self._parent._sync_not_empty: + if not block: + if not self._parent._qsize(): + raise SyncQueueEmpty + elif timeout is None: + while not self._parent._qsize(): + self._parent._sync_not_empty.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + time = self._parent._loop.time + endtime = time() + timeout + while not self._parent._qsize(): + remaining = endtime - time() + if remaining <= 0.0: + raise SyncQueueEmpty + self._parent._sync_not_empty.wait(remaining) + item = self._parent._get() + self._parent._sync_not_full.notify() + self._parent._notify_async_not_full(threadsafe=True) + return item + + def put_nowait(self, item: T) -> None: + """Put an item into the queue without blocking. + + Only enqueue the item if a free slot is immediately available. + Otherwise raise the Full exception. + """ + return self.put(item, block=False) + + def get_nowait(self) -> T: + """Remove and return an item from the queue without blocking. + + Only get an item if one is immediately available. Otherwise + raise the Empty exception. + """ + return self.get(block=False) + + +class _AsyncQueueProxy(AsyncQueue[T]): + """Create a queue object with a given maximum size. + + If maxsize is <= 0, the queue size is infinite. + """ + + def __init__(self, parent: Queue[T]): + self._parent = parent + + @property + def closed(self) -> bool: + return self._parent.closed + + def qsize(self) -> int: + """Number of items in the queue.""" + return self._parent._qsize() + + @property + def unfinished_tasks(self) -> int: + """Return the number of unfinished tasks.""" + return self._parent._unfinished_tasks + + @property + def maxsize(self) -> int: + """Number of items allowed in the queue.""" + return self._parent._maxsize + + def empty(self) -> bool: + """Return True if the queue is empty, False otherwise.""" + return self.qsize() == 0 + + def full(self) -> bool: + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._parent._maxsize <= 0: + return False + else: + return self.qsize() >= self._parent._maxsize + + async def put(self, item: T) -> None: + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + This method is a coroutine. + """ + self._parent._check_closing() + async with self._parent._async_not_full: + self._parent._sync_mutex.acquire() + locked = True + try: + if self._parent._maxsize > 0: + do_wait = True + while do_wait: + do_wait = self._parent._qsize() >= self._parent._maxsize + if do_wait: + locked = False + self._parent._sync_mutex.release() + await self._parent._async_not_full.wait() + self._parent._sync_mutex.acquire() + locked = True + + self._parent._put_internal(item) + self._parent._async_not_empty.notify() + self._parent._notify_sync_not_empty() + finally: + if locked: + self._parent._sync_mutex.release() + + def put_nowait(self, item: T) -> None: + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + """ + self._parent._check_closing() + with self._parent._sync_mutex: + if self._parent._maxsize > 0: + if self._parent._qsize() >= self._parent._maxsize: + raise AsyncQueueFull + + self._parent._put_internal(item) + self._parent._notify_async_not_empty(threadsafe=False) + self._parent._notify_sync_not_empty() + + async def get(self) -> T: + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + + This method is a coroutine. + """ + self._parent._check_closing() + async with self._parent._async_not_empty: + self._parent._sync_mutex.acquire() + locked = True + try: + do_wait = True + while do_wait: + do_wait = self._parent._qsize() == 0 + + if do_wait: + locked = False + self._parent._sync_mutex.release() + await self._parent._async_not_empty.wait() + self._parent._sync_mutex.acquire() + locked = True + + item = self._parent._get() + self._parent._async_not_full.notify() + self._parent._notify_sync_not_full() + return item + finally: + if locked: + self._parent._sync_mutex.release() + + def get_nowait(self) -> T: + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + """ + self._parent._check_closing() + with self._parent._sync_mutex: + if self._parent._qsize() == 0: + raise AsyncQueueEmpty + + item = self._parent._get() + self._parent._notify_async_not_full(threadsafe=False) + self._parent._notify_sync_not_full() + return item + + def task_done(self) -> None: + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + self._parent._check_closing() + with self._parent._all_tasks_done: + if self._parent._unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._parent._unfinished_tasks -= 1 + if self._parent._unfinished_tasks == 0: + self._parent._finished.set() + self._parent._all_tasks_done.notify_all() + + async def join(self) -> None: + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + while True: + with self._parent._sync_mutex: + self._parent._check_closing() + if self._parent._unfinished_tasks == 0: + break + await self._parent._finished.wait() + + +class PriorityQueue(Queue[T]): + """Variant of Queue that retrieves open entries in priority order + (lowest first). + + Entries are typically tuples of the form: (priority number, data). + + """ + + def _init(self, maxsize: int) -> None: + self._heap_queue = [] # type: List[T] + + def _qsize(self) -> int: + return len(self._heap_queue) + + def _put(self, item: T) -> None: + heappush(self._heap_queue, item) + + def _get(self) -> T: + return heappop(self._heap_queue) + + +class LifoQueue(Queue[T]): + """Variant of Queue that retrieves most recently added entries first.""" + + def _qsize(self) -> int: + return len(self._queue) + + def _put(self, item: T) -> None: + self._queue.append(item) + + def _get(self) -> T: + return self._queue.pop() diff --git a/freqtrade/rpc/replicate/types.py b/freqtrade/rpc/replicate/types.py new file mode 100644 index 000000000..5d8c158bd --- /dev/null +++ b/freqtrade/rpc/replicate/types.py @@ -0,0 +1,9 @@ +from typing import TypeVar + +from fastapi import WebSocket as FastAPIWebSocket +from websockets import WebSocketClientProtocol as WebSocket + +from freqtrade.rpc.replicate.channel import WebSocketProxy + + +WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket, WebSocketProxy) diff --git a/freqtrade/rpc/replicate/utils.py b/freqtrade/rpc/replicate/utils.py new file mode 100644 index 000000000..7b703810e --- /dev/null +++ b/freqtrade/rpc/replicate/utils.py @@ -0,0 +1,10 @@ +from starlette.websockets import WebSocket, WebSocketState + + +async def is_websocket_alive(ws: WebSocket) -> bool: + if ( + ws.application_state == WebSocketState.CONNECTED and + ws.client_state == WebSocketState.CONNECTED + ): + return True + return False diff --git a/freqtrade/rpc/rpc_manager.py b/freqtrade/rpc/rpc_manager.py index 3ccf23228..140431586 100644 --- a/freqtrade/rpc/rpc_manager.py +++ b/freqtrade/rpc/rpc_manager.py @@ -44,10 +44,23 @@ class RPCManager: if config.get('api_server', {}).get('enabled', False): logger.info('Enabling rpc.api_server') from freqtrade.rpc.api_server import ApiServer + + # Pass replicate_rpc as param or defer starting api_server + # until we register the replicate rpc enpoint? apiserver = ApiServer(config) apiserver.add_rpc_handler(self._rpc) self.registered_modules.append(apiserver) + # Enable Replicate mode + # For this to be enabled, the API server must also be enabled + if config.get('replicate', {}).get('enabled', False): + logger.info('Enabling rpc.replicate') + from freqtrade.rpc.replicate import ReplicateController + replicate_rpc = ReplicateController(self._rpc, config, apiserver) + self.registered_modules.append(replicate_rpc) + + apiserver.start_api() + def cleanup(self) -> None: """ Stops all enabled rpc modules """ logger.info('Cleaning up rpc modules ...') diff --git a/requirements-replicate.txt b/requirements-replicate.txt new file mode 100644 index 000000000..7ee351d9d --- /dev/null +++ b/requirements-replicate.txt @@ -0,0 +1,5 @@ +# Include all requirements to run the bot. +-r requirements.txt + +# Required for follower +websockets