minor improvements and pairlist data transmission

This commit is contained in:
Timothy Pogue 2022-08-19 00:06:19 -06:00
parent 9f6bba40af
commit 6834db11f3
9 changed files with 243 additions and 115 deletions

View File

@ -7,5 +7,4 @@ class ReplicateModeType(str, Enum):
class LeaderMessageType(str, Enum):
Pairlist = "pairlist"
Dataframe = "dataframe"
whitelist = "whitelist"

View File

@ -75,6 +75,8 @@ class FreqtradeBot(LoggingMixin):
PairLocks.timeframe = self.config['timeframe']
self.replicate_controller = None
# RPC runs in separate threads, can start handling external commands just after
# initialization, even before Freqtradebot has a chance to start its throttling,
# so anything in the Freqtradebot instance should be ready (initialized), including
@ -264,6 +266,17 @@ class FreqtradeBot(LoggingMixin):
# Extend active-pair whitelist with pairs of open trades
# It ensures that candle (OHLCV) data are downloaded for open trades as well
_whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist])
# If replicate leader, broadcast whitelist data
if self.replicate_controller:
if self.replicate_controller.is_leader():
self.replicate_controller.send_message(
{
"data_type": "whitelist",
"data": _whitelist
}
)
return _whitelist
def get_free_open_trades(self) -> int:

View File

@ -0,0 +1,59 @@
"""
External Pair List provider
Provides pair list from Leader data
"""
import logging
from typing import Any, Dict, List
from freqtrade.plugins.pairlist.IPairList import IPairList
logger = logging.getLogger(__name__)
class ExternalPairList(IPairList):
def __init__(self, exchange, pairlistmanager,
config: Dict[str, Any], pairlistconfig: Dict[str, Any],
pairlist_pos: int) -> None:
super().__init__(exchange, pairlistmanager, config, pairlistconfig, pairlist_pos)
self._num_assets = self._pairlistconfig.get('num_assets')
self._allow_inactive = self._pairlistconfig.get('allow_inactive', False)
self._leader_pairs: List[str] = []
@property
def needstickers(self) -> bool:
"""
Boolean property defining if tickers are necessary.
If no Pairlist requires tickers, an empty Dict is passed
as tickers argument to filter_pairlist
"""
return False
def short_desc(self) -> str:
"""
Short whitelist method description - used for startup-messages
-> Please overwrite in subclasses
"""
return f"{self.name}"
def gen_pairlist(self, tickers: Dict) -> List[str]:
"""
Generate the pairlist
:param tickers: Tickers (from exchange.get_tickers()). May be cached.
:return: List of pairs
"""
pass
def filter_pairlist(self, pairlist: List[str], tickers: Dict) -> List[str]:
"""
Filters and sorts pairlist and returns the whitelist again.
Called on each bot iteration - please use internal caching if necessary
:param pairlist: pairlist to filter or sort
:param tickers: Tickers (from exchange.get_tickers()). May be cached.
:return: new whitelist
"""
pass

View File

@ -5,7 +5,7 @@ import asyncio
import logging
import secrets
import socket
from threading import Thread
from threading import Event, Thread
from typing import Any, Coroutine, Dict, Union
import websockets
@ -50,6 +50,9 @@ class ReplicateController(RPCHandler):
self._thread = None
self._queue = None
self._stop_event = Event()
self._follower_tasks = None
self.channel_manager = ChannelManager()
self.replicate_config = config.get('replicate', {})
@ -93,10 +96,7 @@ class ReplicateController(RPCHandler):
self.start_threaded_loop()
if self.mode == ReplicateModeType.follower:
self.start_follower_mode()
elif self.mode == ReplicateModeType.leader:
self.start_leader_mode()
self.start()
def start_threaded_loop(self):
"""
@ -129,6 +129,29 @@ class ReplicateController(RPCHandler):
logger.error(f"Error running coroutine - {str(e)}")
return None
async def main_loop(self):
"""
Main loop coro
Start the loop based on what mode we're in
"""
try:
if self.mode == ReplicateModeType.leader:
await self.leader_loop()
elif self.mode == ReplicateModeType.follower:
await self.follower_loop()
except asyncio.CancelledError:
pass
finally:
self._loop.stop()
def start(self):
"""
Start the controller main loop
"""
self.submit_coroutine(self.main_loop())
def cleanup(self) -> None:
"""
Cleanup pending module resources.
@ -144,27 +167,62 @@ class ReplicateController(RPCHandler):
task.cancel()
self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all)
# This must be called threadsafe, otherwise would not work
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def send_msg(self, msg: Dict[str, Any]) -> None:
""" Push message through """
"""
Support RPC calls
"""
if msg["type"] == RPCMessageType.EMIT_DATA:
self._send_message(
self.send_message(
{
"type": msg["data_type"],
"content": msg["data"]
"data_type": msg.get("data_type"),
"data": msg.get("data")
}
)
def send_message(self, msg: Dict[str, Any]) -> None:
""" Push message through """
if self.channel_manager.has_channels():
self._send_message(msg)
else:
logger.debug("No listening followers, skipping...")
pass
def _send_message(self, msg: Dict[Any, Any]):
"""
Add data to the internal queue to be broadcasted. This func will block
if the queue is full. This is meant to be called in the main thread.
"""
if self._queue:
queue = self._queue.sync_q
queue.put(msg)
else:
logger.warning("Can not send data, leader loop has not started yet!")
def is_leader(self):
"""
Leader flag
"""
return self.enabled() and self.mode == ReplicateModeType.leader
def enabled(self):
"""
Enabled flag
"""
return self.replicate_config.get('enabled', False)
# ----------------------- LEADER LOGIC ------------------------------
def start_leader_mode(self):
async def leader_loop(self):
"""
Register the endpoint and start the leader loop
Main leader coroutine
This starts all of the leader coros and registers the endpoint on
the ApiServer
"""
logger.info("Running rpc.replicate in Leader mode")
@ -173,30 +231,13 @@ class ReplicateController(RPCHandler):
logger.info("-" * 15)
self.register_leader_endpoint()
self.submit_coroutine(self.leader_loop())
async def leader_loop(self):
"""
Main leader coroutine
At the moment this just broadcasts data that's in the queue to the followers
"""
try:
await self._broadcast_queue_data()
except Exception as e:
logger.error("Exception occurred in leader loop: ")
logger.exception(e)
def _send_message(self, data: Dict[Any, Any]):
"""
Add data to the internal queue to be broadcasted. This func will block
if the queue is full. This is meant to be called in the main thread.
"""
if self._queue:
self._queue.put(data)
else:
logger.warning("Can not send data, leader loop has not started yet!")
async def _broadcast_queue_data(self):
"""
Loop over queue data and broadcast it
@ -210,6 +251,8 @@ class ReplicateController(RPCHandler):
# Get data from queue
data = await async_queue.get()
logger.info(f"Found data - broadcasting: {data}")
# Broadcast it to everyone
await self.channel_manager.broadcast(data)
@ -263,6 +306,9 @@ class ReplicateController(RPCHandler):
channel = await self.channel_manager.on_connect(websocket)
# Send initial data here
# Data is being broadcasted right away as soon as startup,
# we may not have to send initial data at all. Further testing
# required.
# Keep connection open until explicitly closed, and sleep
try:
@ -286,20 +332,15 @@ class ReplicateController(RPCHandler):
# -------------------------------FOLLOWER LOGIC----------------------------
def start_follower_mode(self):
"""
Start the ReplicateController in Follower mode
"""
logger.info("Starting rpc.replicate in Follower mode")
self.submit_coroutine(self.follower_loop())
async def follower_loop(self):
"""
Main follower coroutine
This starts all of the leader connection coros
This starts all of the follower connection coros
"""
logger.info("Starting rpc.replicate in Follower mode")
try:
await self._connect_to_leaders()
except Exception as e:
@ -307,21 +348,26 @@ class ReplicateController(RPCHandler):
logger.exception(e)
async def _connect_to_leaders(self):
"""
For each leader in `self.leaders_list` create a connection and
listen for data.
"""
rpc_lock = asyncio.Lock()
logger.info("Starting connections to Leaders...")
await asyncio.wait(
[
self._handle_leader_connection(leader, rpc_lock)
self.follower_tasks = [
self._loop.create_task(self._handle_leader_connection(leader, rpc_lock))
for leader in self.leaders_list
]
)
return await asyncio.gather(*self.follower_tasks, return_exceptions=True)
async def _handle_leader_connection(self, leader, lock):
"""
Given a leader, connect and wait on data. If connection is lost,
it will attempt to reconnect.
"""
try:
url, token = leader["url"], leader["token"]
websocket_url = f"{url}?token={token}"
@ -339,20 +385,22 @@ class ReplicateController(RPCHandler):
timeout=self.reply_timeout
)
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
# We haven't received data yet. Just check the connection and continue.
# We haven't received data yet. Check the connection and continue.
try:
# ping
ping = await channel.ping()
await asyncio.wait_for(ping, timeout=self.ping_timeout)
logger.info(f"Connection to {url} still alive...")
logger.debug(f"Connection to {url} still alive...")
continue
except Exception:
logger.info(f"Ping error {url} - retrying in {self.sleep_time}s")
logger.info(
f"Ping error {url} - retrying in {self.sleep_time}s")
asyncio.sleep(self.sleep_time)
break
with lock:
# Should we have a lock here?
async with lock:
# Acquire lock so only 1 coro handling at a time
# as we might call the RPC module in the main thread
await self._handle_leader_message(data)
except socket.gaierror:
@ -364,22 +412,12 @@ class ReplicateController(RPCHandler):
await asyncio.sleep(self.sleep_time)
continue
except asyncio.CancelledError:
pass
async def _handle_leader_message(self, message):
type = message.get("type")
type = message.get('data_type')
data = message.get('data')
message_type_handlers = {
LeaderMessageType.Pairlist.value: self._handle_pairlist_message,
LeaderMessageType.Dataframe.value: self._handle_dataframe_message
}
handler = message_type_handlers.get(type, self._handle_default_message)
return await handler(message)
async def _handle_default_message(self, message):
logger.info(f"Default message handled: {message}")
async def _handle_pairlist_message(self, message):
logger.info(f"Pairlist message handled: {message}")
async def _handle_dataframe_message(self, message):
logger.info(f"Dataframe message handled: {message}")
if type == LeaderMessageType.whitelist:
logger.info(f"Received whitelist from Leader: {data}")

View File

@ -1,3 +1,4 @@
import logging
from typing import Type
from freqtrade.rpc.replicate.proxy import WebSocketProxy
@ -5,6 +6,9 @@ from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocke
from freqtrade.rpc.replicate.types import WebSocketType
logger = logging.getLogger(__name__)
class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
@ -85,9 +89,12 @@ class ChannelManager:
"""
if websocket in self.channels.keys():
channel = self.channels[websocket]
logger.debug(f"Disconnecting channel - {channel}")
if not channel.is_closed():
await channel.close()
del channel
del self.channels[websocket]
async def disconnect_all(self):
"""
@ -102,5 +109,15 @@ class ChannelManager:
:param data: The data to send
"""
for channel in self.channels.values():
for websocket, channel in self.channels.items():
try:
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0

View File

@ -1,10 +1,8 @@
from typing import TYPE_CHECKING, Union
from typing import Union
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
if TYPE_CHECKING:
from freqtrade.rpc.replicate.types import WebSocketType
@ -21,6 +19,9 @@ class WebSocketProxy:
"""
Send data on the wrapped websocket
"""
if isinstance(data, str):
data = data.encode()
if hasattr(self._websocket, "send_bytes"):
await self._websocket.send_bytes(data)
else:

View File

@ -33,10 +33,10 @@ class WebSocketSerializer(ABC):
class JSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data: bytes) -> bytes:
def _serialize(self, data):
# json expects string not bytes
return json.dumps(data.decode()).encode()
return json.dumps(data)
def _deserialize(self, data: bytes) -> bytes:
def _deserialize(self, data):
# The WebSocketSerializer gives bytes not string
return json.loads(data).encode()
return json.loads(data)

View File

@ -3,7 +3,5 @@ from typing import TypeVar
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
from freqtrade.rpc.replicate.channel import WebSocketProxy
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket, WebSocketProxy)
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)

View File

@ -59,6 +59,9 @@ class RPCManager:
replicate_rpc = ReplicateController(self._rpc, config, apiserver)
self.registered_modules.append(replicate_rpc)
# Attach the controller to FreqTrade
freqtrade.replicate_controller = replicate_rpc
apiserver.start_api()
def cleanup(self) -> None: