DataFrame transmission, strategy follower logic

This commit is contained in:
Timothy Pogue
2022-08-21 22:45:36 -06:00
parent 739b68f8fd
commit 6f5478cc02
13 changed files with 332 additions and 79 deletions

View File

@@ -5,6 +5,7 @@ import asyncio
import logging
import secrets
import socket
import traceback
from threading import Event, Thread
from typing import Any, Coroutine, Dict, Union
@@ -17,6 +18,7 @@ from freqtrade.enums import LeaderMessageType, ReplicateModeType, RPCMessageType
from freqtrade.rpc import RPC, RPCHandler
from freqtrade.rpc.replicate.channel import ChannelManager
from freqtrade.rpc.replicate.thread_queue import Queue as ThreadedQueue
from freqtrade.rpc.replicate.types import MessageType
from freqtrade.rpc.replicate.utils import is_websocket_alive
@@ -79,11 +81,11 @@ class ReplicateController(RPCHandler):
self.mode = ReplicateModeType[self.replicate_config.get('mode', 'leader').lower()]
self.leaders_list = self.replicate_config.get('leaders', [])
self.push_throttle_secs = self.replicate_config.get('push_throttle_secs', 1)
self.push_throttle_secs = self.replicate_config.get('push_throttle_secs', 0.1)
self.reply_timeout = self.replicate_config.get('follower_reply_timeout', 10)
self.ping_timeout = self.replicate_config.get('follower_ping_timeout', 2)
self.sleep_time = self.replicate_config.get('follower_sleep_time', 1)
self.sleep_time = self.replicate_config.get('follower_sleep_time', 5)
if self.mode == ReplicateModeType.follower and len(self.leaders_list) == 0:
raise ValueError("You must specify at least 1 leader in follower mode.")
@@ -143,6 +145,8 @@ class ReplicateController(RPCHandler):
except asyncio.CancelledError:
pass
except Exception:
pass
finally:
self._loop.stop()
@@ -170,22 +174,19 @@ class ReplicateController(RPCHandler):
self._thread.join()
def send_msg(self, msg: Dict[str, Any]) -> None:
def send_msg(self, msg: MessageType) -> None:
"""
Support RPC calls
"""
if msg["type"] == RPCMessageType.EMIT_DATA:
self.send_message(
{
"data_type": msg.get("data_type"),
"data": msg.get("data")
}
)
message = msg.get("message")
if message:
self.send_message(message)
else:
logger.error(f"Message is empty! {msg}")
def send_message(self, msg: Dict[str, Any]) -> None:
""" Push message through """
# We should probably do some type of schema validation here
def send_message(self, msg: MessageType) -> None:
""" Broadcast message over all channels if there are any """
if self.channel_manager.has_channels():
self._send_message(msg)
@@ -193,12 +194,11 @@ class ReplicateController(RPCHandler):
logger.debug("No listening followers, skipping...")
pass
def _send_message(self, msg: Dict[Any, Any]):
def _send_message(self, msg: MessageType):
"""
Add data to the internal queue to be broadcasted. This func will block
if the queue is full. This is meant to be called in the main thread.
"""
if self._queue:
queue = self._queue.sync_q
queue.put(msg) # This will block if the queue is full
@@ -226,7 +226,6 @@ class ReplicateController(RPCHandler):
This starts all of the leader coros and registers the endpoint on
the ApiServer
"""
logger.info("Running rpc.replicate in Leader mode")
logger.info("-" * 15)
logger.info(f"API_KEY: {self.secret_api_key}")
@@ -253,16 +252,17 @@ class ReplicateController(RPCHandler):
# Get data from queue
data = await async_queue.get()
logger.info(f"Found data - broadcasting: {data}")
# Broadcast it to everyone
await self.channel_manager.broadcast(data)
# Sleep
await asyncio.sleep(self.push_throttle_secs)
except asyncio.CancelledError:
# Silently stop
pass
except Exception as e:
logger.exception(e)
async def get_api_token(
self,
@@ -285,7 +285,6 @@ class ReplicateController(RPCHandler):
:param path: The endpoint path
"""
if not self.api_server:
raise RuntimeError("The leader needs the ApiServer to be active")
@@ -312,10 +311,13 @@ class ReplicateController(RPCHandler):
# we may not have to send initial data at all. Further testing
# required.
await self.send_initial_data(channel)
# Keep connection open until explicitly closed, and sleep
try:
while not channel.is_closed():
await channel.recv()
request = await channel.recv()
logger.info(f"Follower request - {request}")
except WebSocketDisconnect:
# Handle client disconnects
@@ -332,6 +334,17 @@ class ReplicateController(RPCHandler):
logger.error(f"Failed to serve - {websocket.client}")
await self.channel_manager.on_disconnect(websocket)
async def send_initial_data(self, channel):
logger.info("Sending initial data through channel")
# We first send pairlist data
initial_data = {
"data_type": LeaderMessageType.pairlist,
"data": self.freqtrade.pairlists.whitelist
}
await channel.send(initial_data)
# -------------------------------FOLLOWER LOGIC----------------------------
async def follower_loop(self):
@@ -340,18 +353,27 @@ class ReplicateController(RPCHandler):
This starts all of the follower connection coros
"""
logger.info("Starting rpc.replicate in Follower mode")
try:
results = await self._connect_to_leaders()
except Exception as e:
logger.error("Exception occurred in Follower loop: ")
logger.exception(e)
finally:
for result in results:
if isinstance(result, Exception):
logger.debug(f"Exception in Follower loop: {result}")
responses = await self._connect_to_leaders()
# Eventually add the ability to send requests to the Leader
# await self._send_requests()
for result in responses:
if isinstance(result, Exception):
logger.debug(f"Exception in Follower loop: {result}")
traceback_message = ''.join(traceback.format_tb(result.__traceback__))
logger.error(traceback_message)
async def _handle_leader_message(self, message: MessageType):
"""
Handle message received from a Leader
"""
type = message.get("data_type")
data = message.get("data")
self._rpc._handle_emitted_data(type, data)
async def _connect_to_leaders(self):
"""
@@ -375,7 +397,6 @@ class ReplicateController(RPCHandler):
"""
try:
url, token = leader["url"], leader["token"]
websocket_url = f"{url}?token={token}"
logger.info(f"Attempting to connect to Leader at: {url}")
@@ -384,6 +405,7 @@ class ReplicateController(RPCHandler):
try:
async with websockets.connect(websocket_url) as ws:
channel = await self.channel_manager.on_connect(ws)
logger.info(f"Connection to Leader at {url} successful")
while True:
try:
data = await asyncio.wait_for(
@@ -420,13 +442,3 @@ class ReplicateController(RPCHandler):
except asyncio.CancelledError:
pass
async def _handle_leader_message(self, message: Dict[str, Any]):
type = message.get('data_type')
data = message.get('data')
logger.info(f"Received message from Leader: {type} - {data}")
if type == LeaderMessageType.pairlist:
# Add the data to the ExternalPairlist
self.freqtrade.pairlists._pairlist_handlers[0].add_pairlist_data(data)

View File

@@ -2,7 +2,7 @@ import logging
from typing import Type
from freqtrade.rpc.replicate.proxy import WebSocketProxy
from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocketSerializer
from freqtrade.rpc.replicate.serializer import MsgPackWebSocketSerializer, WebSocketSerializer
from freqtrade.rpc.replicate.types import WebSocketType
@@ -17,7 +17,7 @@ class WebSocketChannel:
def __init__(
self,
websocket: WebSocketType,
serializer_cls: Type[WebSocketSerializer] = JSONWebSocketSerializer
serializer_cls: Type[WebSocketSerializer] = MsgPackWebSocketSerializer
):
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
@@ -34,6 +34,7 @@ class WebSocketChannel:
"""
Send data on the wrapped websocket
"""
# logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}")
await self._wrapped_ws.send(data)
async def recv(self):
@@ -116,6 +117,17 @@ class ChannelManager:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):
"""
Send data directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send data through
:param data: The data to send
"""
# We iterate over the channels to get reference to the websocket object
# so we can disconnect incase of failure
await channel.send(data)
def has_channels(self):
"""
Flag for more than 0 channels

View File

@@ -1,9 +1,16 @@
import json
import logging
from abc import ABC, abstractmethod
import msgpack
import orjson
from freqtrade.rpc.replicate.proxy import WebSocketProxy
logger = logging.getLogger(__name__)
class WebSocketSerializer(ABC):
def __init__(self, websocket: WebSocketProxy):
self._websocket: WebSocketProxy = websocket
@@ -34,9 +41,25 @@ class WebSocketSerializer(ABC):
class JSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
# json expects string not bytes
return json.dumps(data)
def _deserialize(self, data):
# The WebSocketSerializer gives bytes not string
return json.loads(data)
class ORJSONWebSocketSerializer(WebSocketSerializer):
ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
def _serialize(self, data):
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
def _deserialize(self, data):
return orjson.loads(data, option=self.ORJSON_OPTIONS)
class MsgPackWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
return msgpack.packb(data, use_bin_type=True)
def _deserialize(self, data):
return msgpack.unpackb(data, raw=False)

View File

@@ -1,7 +1,8 @@
from typing import TypeVar
from typing import Any, Dict, TypeVar
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
MessageType = Dict[str, Any]

View File

@@ -19,12 +19,12 @@ from freqtrade.configuration.timerange import TimeRange
from freqtrade.constants import CANCEL_REASON, DATETIME_PRINT_FORMAT
from freqtrade.data.history import load_data
from freqtrade.data.metrics import calculate_max_drawdown
from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, SignalDirection, State,
TradingMode)
from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, LeaderMessageType,
SignalDirection, State, TradingMode)
from freqtrade.exceptions import ExchangeError, PricingError
from freqtrade.exchange import timeframe_to_minutes, timeframe_to_msecs
from freqtrade.loggers import bufferHandler
from freqtrade.misc import decimals_per_coin, shorten_date
from freqtrade.misc import decimals_per_coin, json_to_dataframe, shorten_date
from freqtrade.persistence import PairLocks, Trade
from freqtrade.persistence.models import PairLock
from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist
@@ -1089,3 +1089,36 @@ class RPC:
'last_process_loc': last_p.astimezone(tzlocal()).strftime(DATETIME_PRINT_FORMAT),
'last_process_ts': int(last_p.timestamp()),
}
def _handle_emitted_data(self, type, data):
"""
Handles the emitted data from the Leaders
:param type: The data_type of the data
:param data: The data
"""
logger.debug(f"Handling emitted data of type ({type})")
if type == LeaderMessageType.pairlist:
pairlist = data
logger.debug(pairlist)
# Add the pairlist data to the ExternalPairList object
external_pairlist = self._freqtrade.pairlists._pairlist_handlers[0]
external_pairlist.add_pairlist_data(pairlist)
elif type == LeaderMessageType.analyzed_df:
# Convert the dataframe back from json
key, value = data["key"], data["value"]
pair, timeframe, candle_type = key
dataframe = json_to_dataframe(value)
dataprovider = self._freqtrade.dataprovider
logger.debug(f"Received analyzed dataframe for {pair}")
logger.debug(dataframe.tail())
# Add the dataframe to the dataprovider
dataprovider.add_external_df(pair, timeframe, dataframe, candle_type)

View File

@@ -20,6 +20,7 @@ class RPCManager:
def __init__(self, freqtrade) -> None:
""" Initializes all enabled rpc modules """
self.registered_modules: List[RPCHandler] = []
self._freqtrade = freqtrade
self._rpc = RPC(freqtrade)
config = freqtrade.config
# Enable telegram
@@ -82,7 +83,8 @@ class RPCManager:
'status': 'stopping bot'
}
"""
logger.info('Sending rpc message: %s', msg)
if msg.get("type") != RPCMessageType.EMIT_DATA:
logger.info('Sending rpc message: %s', msg)
if 'pair' in msg:
msg.update({
'base_currency': self._rpc._freqtrade.exchange.get_pair_base_currency(msg['pair'])
@@ -141,3 +143,12 @@ class RPCManager:
'type': RPCMessageType.STARTUP,
'status': f'Using Protections: \n{prots}'
})
def emit_data(self, data: Dict[str, Any]):
"""
Send a message via RPC with type RPCMessageType.EMIT_DATA
"""
self.send_msg({
"type": RPCMessageType.EMIT_DATA,
"message": data
})