DataFrame transmission, strategy follower logic
This commit is contained in:
@@ -5,6 +5,7 @@ import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import socket
|
||||
import traceback
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Coroutine, Dict, Union
|
||||
|
||||
@@ -17,6 +18,7 @@ from freqtrade.enums import LeaderMessageType, ReplicateModeType, RPCMessageType
|
||||
from freqtrade.rpc import RPC, RPCHandler
|
||||
from freqtrade.rpc.replicate.channel import ChannelManager
|
||||
from freqtrade.rpc.replicate.thread_queue import Queue as ThreadedQueue
|
||||
from freqtrade.rpc.replicate.types import MessageType
|
||||
from freqtrade.rpc.replicate.utils import is_websocket_alive
|
||||
|
||||
|
||||
@@ -79,11 +81,11 @@ class ReplicateController(RPCHandler):
|
||||
self.mode = ReplicateModeType[self.replicate_config.get('mode', 'leader').lower()]
|
||||
|
||||
self.leaders_list = self.replicate_config.get('leaders', [])
|
||||
self.push_throttle_secs = self.replicate_config.get('push_throttle_secs', 1)
|
||||
self.push_throttle_secs = self.replicate_config.get('push_throttle_secs', 0.1)
|
||||
|
||||
self.reply_timeout = self.replicate_config.get('follower_reply_timeout', 10)
|
||||
self.ping_timeout = self.replicate_config.get('follower_ping_timeout', 2)
|
||||
self.sleep_time = self.replicate_config.get('follower_sleep_time', 1)
|
||||
self.sleep_time = self.replicate_config.get('follower_sleep_time', 5)
|
||||
|
||||
if self.mode == ReplicateModeType.follower and len(self.leaders_list) == 0:
|
||||
raise ValueError("You must specify at least 1 leader in follower mode.")
|
||||
@@ -143,6 +145,8 @@ class ReplicateController(RPCHandler):
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._loop.stop()
|
||||
|
||||
@@ -170,22 +174,19 @@ class ReplicateController(RPCHandler):
|
||||
|
||||
self._thread.join()
|
||||
|
||||
def send_msg(self, msg: Dict[str, Any]) -> None:
|
||||
def send_msg(self, msg: MessageType) -> None:
|
||||
"""
|
||||
Support RPC calls
|
||||
"""
|
||||
if msg["type"] == RPCMessageType.EMIT_DATA:
|
||||
self.send_message(
|
||||
{
|
||||
"data_type": msg.get("data_type"),
|
||||
"data": msg.get("data")
|
||||
}
|
||||
)
|
||||
message = msg.get("message")
|
||||
if message:
|
||||
self.send_message(message)
|
||||
else:
|
||||
logger.error(f"Message is empty! {msg}")
|
||||
|
||||
def send_message(self, msg: Dict[str, Any]) -> None:
|
||||
""" Push message through """
|
||||
|
||||
# We should probably do some type of schema validation here
|
||||
def send_message(self, msg: MessageType) -> None:
|
||||
""" Broadcast message over all channels if there are any """
|
||||
|
||||
if self.channel_manager.has_channels():
|
||||
self._send_message(msg)
|
||||
@@ -193,12 +194,11 @@ class ReplicateController(RPCHandler):
|
||||
logger.debug("No listening followers, skipping...")
|
||||
pass
|
||||
|
||||
def _send_message(self, msg: Dict[Any, Any]):
|
||||
def _send_message(self, msg: MessageType):
|
||||
"""
|
||||
Add data to the internal queue to be broadcasted. This func will block
|
||||
if the queue is full. This is meant to be called in the main thread.
|
||||
"""
|
||||
|
||||
if self._queue:
|
||||
queue = self._queue.sync_q
|
||||
queue.put(msg) # This will block if the queue is full
|
||||
@@ -226,7 +226,6 @@ class ReplicateController(RPCHandler):
|
||||
This starts all of the leader coros and registers the endpoint on
|
||||
the ApiServer
|
||||
"""
|
||||
|
||||
logger.info("Running rpc.replicate in Leader mode")
|
||||
logger.info("-" * 15)
|
||||
logger.info(f"API_KEY: {self.secret_api_key}")
|
||||
@@ -253,16 +252,17 @@ class ReplicateController(RPCHandler):
|
||||
# Get data from queue
|
||||
data = await async_queue.get()
|
||||
|
||||
logger.info(f"Found data - broadcasting: {data}")
|
||||
|
||||
# Broadcast it to everyone
|
||||
await self.channel_manager.broadcast(data)
|
||||
|
||||
# Sleep
|
||||
await asyncio.sleep(self.push_throttle_secs)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Silently stop
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def get_api_token(
|
||||
self,
|
||||
@@ -285,7 +285,6 @@ class ReplicateController(RPCHandler):
|
||||
|
||||
:param path: The endpoint path
|
||||
"""
|
||||
|
||||
if not self.api_server:
|
||||
raise RuntimeError("The leader needs the ApiServer to be active")
|
||||
|
||||
@@ -312,10 +311,13 @@ class ReplicateController(RPCHandler):
|
||||
# we may not have to send initial data at all. Further testing
|
||||
# required.
|
||||
|
||||
await self.send_initial_data(channel)
|
||||
|
||||
# Keep connection open until explicitly closed, and sleep
|
||||
try:
|
||||
while not channel.is_closed():
|
||||
await channel.recv()
|
||||
request = await channel.recv()
|
||||
logger.info(f"Follower request - {request}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
# Handle client disconnects
|
||||
@@ -332,6 +334,17 @@ class ReplicateController(RPCHandler):
|
||||
logger.error(f"Failed to serve - {websocket.client}")
|
||||
await self.channel_manager.on_disconnect(websocket)
|
||||
|
||||
async def send_initial_data(self, channel):
|
||||
logger.info("Sending initial data through channel")
|
||||
|
||||
# We first send pairlist data
|
||||
initial_data = {
|
||||
"data_type": LeaderMessageType.pairlist,
|
||||
"data": self.freqtrade.pairlists.whitelist
|
||||
}
|
||||
|
||||
await channel.send(initial_data)
|
||||
|
||||
# -------------------------------FOLLOWER LOGIC----------------------------
|
||||
|
||||
async def follower_loop(self):
|
||||
@@ -340,18 +353,27 @@ class ReplicateController(RPCHandler):
|
||||
|
||||
This starts all of the follower connection coros
|
||||
"""
|
||||
|
||||
logger.info("Starting rpc.replicate in Follower mode")
|
||||
|
||||
try:
|
||||
results = await self._connect_to_leaders()
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred in Follower loop: ")
|
||||
logger.exception(e)
|
||||
finally:
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(f"Exception in Follower loop: {result}")
|
||||
responses = await self._connect_to_leaders()
|
||||
|
||||
# Eventually add the ability to send requests to the Leader
|
||||
# await self._send_requests()
|
||||
|
||||
for result in responses:
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(f"Exception in Follower loop: {result}")
|
||||
traceback_message = ''.join(traceback.format_tb(result.__traceback__))
|
||||
logger.error(traceback_message)
|
||||
|
||||
async def _handle_leader_message(self, message: MessageType):
|
||||
"""
|
||||
Handle message received from a Leader
|
||||
"""
|
||||
type = message.get("data_type")
|
||||
data = message.get("data")
|
||||
|
||||
self._rpc._handle_emitted_data(type, data)
|
||||
|
||||
async def _connect_to_leaders(self):
|
||||
"""
|
||||
@@ -375,7 +397,6 @@ class ReplicateController(RPCHandler):
|
||||
"""
|
||||
try:
|
||||
url, token = leader["url"], leader["token"]
|
||||
|
||||
websocket_url = f"{url}?token={token}"
|
||||
|
||||
logger.info(f"Attempting to connect to Leader at: {url}")
|
||||
@@ -384,6 +405,7 @@ class ReplicateController(RPCHandler):
|
||||
try:
|
||||
async with websockets.connect(websocket_url) as ws:
|
||||
channel = await self.channel_manager.on_connect(ws)
|
||||
logger.info(f"Connection to Leader at {url} successful")
|
||||
while True:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
@@ -420,13 +442,3 @@ class ReplicateController(RPCHandler):
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _handle_leader_message(self, message: Dict[str, Any]):
|
||||
type = message.get('data_type')
|
||||
data = message.get('data')
|
||||
|
||||
logger.info(f"Received message from Leader: {type} - {data}")
|
||||
|
||||
if type == LeaderMessageType.pairlist:
|
||||
# Add the data to the ExternalPairlist
|
||||
self.freqtrade.pairlists._pairlist_handlers[0].add_pairlist_data(data)
|
||||
|
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Type
|
||||
|
||||
from freqtrade.rpc.replicate.proxy import WebSocketProxy
|
||||
from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocketSerializer
|
||||
from freqtrade.rpc.replicate.serializer import MsgPackWebSocketSerializer, WebSocketSerializer
|
||||
from freqtrade.rpc.replicate.types import WebSocketType
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class WebSocketChannel:
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocketType,
|
||||
serializer_cls: Type[WebSocketSerializer] = JSONWebSocketSerializer
|
||||
serializer_cls: Type[WebSocketSerializer] = MsgPackWebSocketSerializer
|
||||
):
|
||||
# The WebSocket object
|
||||
self._websocket = WebSocketProxy(websocket)
|
||||
@@ -34,6 +34,7 @@ class WebSocketChannel:
|
||||
"""
|
||||
Send data on the wrapped websocket
|
||||
"""
|
||||
# logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}")
|
||||
await self._wrapped_ws.send(data)
|
||||
|
||||
async def recv(self):
|
||||
@@ -116,6 +117,17 @@ class ChannelManager:
|
||||
# Handle cannot send after close cases
|
||||
await self.on_disconnect(websocket)
|
||||
|
||||
async def send_direct(self, channel, data):
|
||||
"""
|
||||
Send data directly through direct_channel only
|
||||
|
||||
:param direct_channel: The WebSocketChannel object to send data through
|
||||
:param data: The data to send
|
||||
"""
|
||||
# We iterate over the channels to get reference to the websocket object
|
||||
# so we can disconnect incase of failure
|
||||
await channel.send(data)
|
||||
|
||||
def has_channels(self):
|
||||
"""
|
||||
Flag for more than 0 channels
|
||||
|
@@ -1,9 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import msgpack
|
||||
import orjson
|
||||
|
||||
from freqtrade.rpc.replicate.proxy import WebSocketProxy
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketSerializer(ABC):
|
||||
def __init__(self, websocket: WebSocketProxy):
|
||||
self._websocket: WebSocketProxy = websocket
|
||||
@@ -34,9 +41,25 @@ class WebSocketSerializer(ABC):
|
||||
|
||||
class JSONWebSocketSerializer(WebSocketSerializer):
|
||||
def _serialize(self, data):
|
||||
# json expects string not bytes
|
||||
return json.dumps(data)
|
||||
|
||||
def _deserialize(self, data):
|
||||
# The WebSocketSerializer gives bytes not string
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
class ORJSONWebSocketSerializer(WebSocketSerializer):
|
||||
ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
|
||||
|
||||
def _serialize(self, data):
|
||||
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
|
||||
|
||||
def _deserialize(self, data):
|
||||
return orjson.loads(data, option=self.ORJSON_OPTIONS)
|
||||
|
||||
|
||||
class MsgPackWebSocketSerializer(WebSocketSerializer):
|
||||
def _serialize(self, data):
|
||||
return msgpack.packb(data, use_bin_type=True)
|
||||
|
||||
def _deserialize(self, data):
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from typing import TypeVar
|
||||
from typing import Any, Dict, TypeVar
|
||||
|
||||
from fastapi import WebSocket as FastAPIWebSocket
|
||||
from websockets import WebSocketClientProtocol as WebSocket
|
||||
|
||||
|
||||
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
|
||||
MessageType = Dict[str, Any]
|
||||
|
@@ -19,12 +19,12 @@ from freqtrade.configuration.timerange import TimeRange
|
||||
from freqtrade.constants import CANCEL_REASON, DATETIME_PRINT_FORMAT
|
||||
from freqtrade.data.history import load_data
|
||||
from freqtrade.data.metrics import calculate_max_drawdown
|
||||
from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, SignalDirection, State,
|
||||
TradingMode)
|
||||
from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, LeaderMessageType,
|
||||
SignalDirection, State, TradingMode)
|
||||
from freqtrade.exceptions import ExchangeError, PricingError
|
||||
from freqtrade.exchange import timeframe_to_minutes, timeframe_to_msecs
|
||||
from freqtrade.loggers import bufferHandler
|
||||
from freqtrade.misc import decimals_per_coin, shorten_date
|
||||
from freqtrade.misc import decimals_per_coin, json_to_dataframe, shorten_date
|
||||
from freqtrade.persistence import PairLocks, Trade
|
||||
from freqtrade.persistence.models import PairLock
|
||||
from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist
|
||||
@@ -1089,3 +1089,36 @@ class RPC:
|
||||
'last_process_loc': last_p.astimezone(tzlocal()).strftime(DATETIME_PRINT_FORMAT),
|
||||
'last_process_ts': int(last_p.timestamp()),
|
||||
}
|
||||
|
||||
def _handle_emitted_data(self, type, data):
|
||||
"""
|
||||
Handles the emitted data from the Leaders
|
||||
|
||||
:param type: The data_type of the data
|
||||
:param data: The data
|
||||
"""
|
||||
logger.debug(f"Handling emitted data of type ({type})")
|
||||
|
||||
if type == LeaderMessageType.pairlist:
|
||||
pairlist = data
|
||||
|
||||
logger.debug(pairlist)
|
||||
|
||||
# Add the pairlist data to the ExternalPairList object
|
||||
external_pairlist = self._freqtrade.pairlists._pairlist_handlers[0]
|
||||
external_pairlist.add_pairlist_data(pairlist)
|
||||
|
||||
elif type == LeaderMessageType.analyzed_df:
|
||||
# Convert the dataframe back from json
|
||||
key, value = data["key"], data["value"]
|
||||
|
||||
pair, timeframe, candle_type = key
|
||||
dataframe = json_to_dataframe(value)
|
||||
|
||||
dataprovider = self._freqtrade.dataprovider
|
||||
|
||||
logger.debug(f"Received analyzed dataframe for {pair}")
|
||||
logger.debug(dataframe.tail())
|
||||
|
||||
# Add the dataframe to the dataprovider
|
||||
dataprovider.add_external_df(pair, timeframe, dataframe, candle_type)
|
||||
|
@@ -20,6 +20,7 @@ class RPCManager:
|
||||
def __init__(self, freqtrade) -> None:
|
||||
""" Initializes all enabled rpc modules """
|
||||
self.registered_modules: List[RPCHandler] = []
|
||||
self._freqtrade = freqtrade
|
||||
self._rpc = RPC(freqtrade)
|
||||
config = freqtrade.config
|
||||
# Enable telegram
|
||||
@@ -82,7 +83,8 @@ class RPCManager:
|
||||
'status': 'stopping bot'
|
||||
}
|
||||
"""
|
||||
logger.info('Sending rpc message: %s', msg)
|
||||
if msg.get("type") != RPCMessageType.EMIT_DATA:
|
||||
logger.info('Sending rpc message: %s', msg)
|
||||
if 'pair' in msg:
|
||||
msg.update({
|
||||
'base_currency': self._rpc._freqtrade.exchange.get_pair_base_currency(msg['pair'])
|
||||
@@ -141,3 +143,12 @@ class RPCManager:
|
||||
'type': RPCMessageType.STARTUP,
|
||||
'status': f'Using Protections: \n{prots}'
|
||||
})
|
||||
|
||||
def emit_data(self, data: Dict[str, Any]):
|
||||
"""
|
||||
Send a message via RPC with type RPCMessageType.EMIT_DATA
|
||||
"""
|
||||
self.send_msg({
|
||||
"type": RPCMessageType.EMIT_DATA,
|
||||
"message": data
|
||||
})
|
||||
|
Reference in New Issue
Block a user