minor improvements and pairlist data transmission

Timothy Pogue 2022-08-19 00:06:19 -06:00
parent 9f6bba40af
commit 6834db11f3
9 changed files with 243 additions and 115 deletions

View File

@@ -7,5 +7,4 @@ class ReplicateModeType(str, Enum):
 class LeaderMessageType(str, Enum):
-    Pairlist = "pairlist"
-    Dataframe = "dataframe"
+    whitelist = "whitelist"

View File

@@ -75,6 +75,8 @@ class FreqtradeBot(LoggingMixin):
         PairLocks.timeframe = self.config['timeframe']
 
+        self.replicate_controller = None
+
         # RPC runs in separate threads, can start handling external commands just after
         # initialization, even before Freqtradebot has a chance to start its throttling,
         # so anything in the Freqtradebot instance should be ready (initialized), including
@@ -264,6 +266,17 @@ class FreqtradeBot(LoggingMixin):
             # Extend active-pair whitelist with pairs of open trades
             # It ensures that candle (OHLCV) data are downloaded for open trades as well
             _whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist])
+
+        # If replicate leader, broadcast whitelist data
+        if self.replicate_controller:
+            if self.replicate_controller.is_leader():
+                self.replicate_controller.send_message(
+                    {
+                        "data_type": "whitelist",
+                        "data": _whitelist
+                    }
+                )
+
         return _whitelist
 
     def get_free_open_trades(self) -> int:
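Note: `ReplicateController.send_msg()` (further down in this commit) unpacks `RPCMessageType.EMIT_DATA` messages into the same `data_type`/`data` shape, so the same broadcast could in principle be routed through the generic RPC layer instead of calling the controller directly. A hedged sketch of that alternative (not what this commit does):

        # Hypothetical alternative: push the payload through RPCManager and let
        # ReplicateController.send_msg() unpack the EMIT_DATA message into
        # {"data_type": ..., "data": ...}.
        self.rpc.send_msg({
            'type': RPCMessageType.EMIT_DATA,
            'data_type': 'whitelist',
            'data': _whitelist,
        })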

View File

@@ -0,0 +1,59 @@
+"""
+External Pair List provider
+
+Provides pair list from Leader data
+"""
+import logging
+from typing import Any, Dict, List
+
+from freqtrade.plugins.pairlist.IPairList import IPairList
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalPairList(IPairList):
+
+    def __init__(self, exchange, pairlistmanager,
+                 config: Dict[str, Any], pairlistconfig: Dict[str, Any],
+                 pairlist_pos: int) -> None:
+        super().__init__(exchange, pairlistmanager, config, pairlistconfig, pairlist_pos)
+
+        self._num_assets = self._pairlistconfig.get('num_assets')
+        self._allow_inactive = self._pairlistconfig.get('allow_inactive', False)
+
+        self._leader_pairs: List[str] = []
+
+    @property
+    def needstickers(self) -> bool:
+        """
+        Boolean property defining if tickers are necessary.
+        If no Pairlist requires tickers, an empty Dict is passed
+        as tickers argument to filter_pairlist
+        """
+        return False
+
+    def short_desc(self) -> str:
+        """
+        Short whitelist method description - used for startup-messages
+        -> Please overwrite in subclasses
+        """
+        return f"{self.name}"
+
+    def gen_pairlist(self, tickers: Dict) -> List[str]:
+        """
+        Generate the pairlist
+        :param tickers: Tickers (from exchange.get_tickers()). May be cached.
+        :return: List of pairs
+        """
+        pass
+
+    def filter_pairlist(self, pairlist: List[str], tickers: Dict) -> List[str]:
+        """
+        Filters and sorts pairlist and returns the whitelist again.
+        Called on each bot iteration - please use internal caching if necessary
+        :param pairlist: pairlist to filter or sort
+        :param tickers: Tickers (from exchange.get_tickers()). May be cached.
+        :return: new whitelist
+        """
+        pass
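`gen_pairlist()` and `filter_pairlist()` are still stubs at this point in the commit. A minimal, stand-alone sketch of how they might serve the Leader-provided pairs (the `add_pairlist_data()` hook and the `num_assets` cap behaviour are assumptions, not part of this diff):

from typing import Dict, List


class ExternalPairListSketch:
    """Stand-alone illustration; mirrors the attributes set in __init__ above."""

    def __init__(self) -> None:
        self._leader_pairs: List[str] = []
        self._num_assets = 0

    def add_pairlist_data(self, pairlist: List[str]) -> None:
        # Hypothetical entry point for the follower side when a
        # "whitelist" message arrives from the Leader.
        for pair in pairlist:
            if pair not in self._leader_pairs:
                self._leader_pairs.append(pair)

    def gen_pairlist(self, tickers: Dict) -> List[str]:
        # Serve whatever the Leader sent, optionally capped to num_assets.
        pairs = list(self._leader_pairs)
        return pairs[:self._num_assets] if self._num_assets else pairs

    def filter_pairlist(self, pairlist: List[str], tickers: Dict) -> List[str]:
        # Nothing to filter locally; the Leader already applied its own filters.
        return pairlist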

View File

@@ -5,7 +5,7 @@ import asyncio
 import logging
 import secrets
 import socket
-from threading import Thread
+from threading import Event, Thread
 from typing import Any, Coroutine, Dict, Union
 
 import websockets
@@ -50,6 +50,9 @@ class ReplicateController(RPCHandler):
         self._thread = None
         self._queue = None
+        self._stop_event = Event()
+        self._follower_tasks = None
+
         self.channel_manager = ChannelManager()
 
         self.replicate_config = config.get('replicate', {})
@@ -93,10 +96,7 @@ class ReplicateController(RPCHandler):
         self.start_threaded_loop()
 
-        if self.mode == ReplicateModeType.follower:
-            self.start_follower_mode()
-        elif self.mode == ReplicateModeType.leader:
-            self.start_leader_mode()
+        self.start()
 
     def start_threaded_loop(self):
         """
@@ -129,6 +129,29 @@ class ReplicateController(RPCHandler):
             logger.error(f"Error running coroutine - {str(e)}")
             return None
 
+    async def main_loop(self):
+        """
+        Main loop coro
+
+        Start the loop based on what mode we're in
+        """
+        try:
+            if self.mode == ReplicateModeType.leader:
+                await self.leader_loop()
+            elif self.mode == ReplicateModeType.follower:
+                await self.follower_loop()
+        except asyncio.CancelledError:
+            pass
+        finally:
+            self._loop.stop()
+
+    def start(self):
+        """
+        Start the controller main loop
+        """
+        self.submit_coroutine(self.main_loop())
+
     def cleanup(self) -> None:
         """
         Cleanup pending module resources.
@@ -144,27 +167,62 @@ class ReplicateController(RPCHandler):
                 task.cancel()
 
             self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all)
 
-            # This must be called threadsafe, otherwise would not work
-            self._loop.call_soon_threadsafe(self._loop.stop)
-
             self._thread.join()
 
     def send_msg(self, msg: Dict[str, Any]) -> None:
-        """ Push message through """
+        """
+        Support RPC calls
+        """
         if msg["type"] == RPCMessageType.EMIT_DATA:
-            self._send_message(
+            self.send_message(
                 {
-                    "type": msg["data_type"],
-                    "content": msg["data"]
+                    "data_type": msg.get("data_type"),
+                    "data": msg.get("data")
                 }
             )
 
+    def send_message(self, msg: Dict[str, Any]) -> None:
+        """ Push message through """
+
+        if self.channel_manager.has_channels():
+            self._send_message(msg)
+        else:
+            logger.debug("No listening followers, skipping...")
+            pass
+
+    def _send_message(self, msg: Dict[Any, Any]):
+        """
+        Add data to the internal queue to be broadcasted. This func will block
+        if the queue is full. This is meant to be called in the main thread.
+        """
+        if self._queue:
+            queue = self._queue.sync_q
+            queue.put(msg)
+        else:
+            logger.warning("Can not send data, leader loop has not started yet!")
+
+    def is_leader(self):
+        """
+        Leader flag
+        """
+        return self.enabled() and self.mode == ReplicateModeType.leader
+
+    def enabled(self):
+        """
+        Enabled flag
+        """
+        return self.replicate_config.get('enabled', False)
+
     # ----------------------- LEADER LOGIC ------------------------------
 
-    def start_leader_mode(self):
+    async def leader_loop(self):
         """
-        Register the endpoint and start the leader loop
+        Main leader coroutine
+
+        This starts all of the leader coros and registers the endpoint on
+        the ApiServer
         """
 
         logger.info("Running rpc.replicate in Leader mode")
@@ -173,30 +231,13 @@
         logger.info("-" * 15)
 
         self.register_leader_endpoint()
-        self.submit_coroutine(self.leader_loop())
-
-    async def leader_loop(self):
-        """
-        Main leader coroutine
-
-        At the moment this just broadcasts data that's in the queue to the followers
-        """
 
         try:
             await self._broadcast_queue_data()
         except Exception as e:
             logger.error("Exception occurred in leader loop: ")
             logger.exception(e)
 
-    def _send_message(self, data: Dict[Any, Any]):
-        """
-        Add data to the internal queue to be broadcasted. This func will block
-        if the queue is full. This is meant to be called in the main thread.
-        """
-        if self._queue:
-            self._queue.put(data)
-        else:
-            logger.warning("Can not send data, leader loop has not started yet!")
-
     async def _broadcast_queue_data(self):
         """
         Loop over queue data and broadcast it
@@ -210,6 +251,8 @@
                 # Get data from queue
                 data = await async_queue.get()
 
+                logger.info(f"Found data - broadcasting: {data}")
+
                 # Broadcast it to everyone
                 await self.channel_manager.broadcast(data)
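The `sync_q` access in `_send_message()` together with the awaited `get()` here points at a dual sync/async queue shared between the bot thread and the event loop. A small stand-alone sketch of that pattern, assuming the `janus` package (the actual queue construction is not shown in this diff):

import asyncio

import janus  # assumed dependency: exposes .sync_q (thread side) and .async_q (loop side)


async def main() -> None:
    queue = janus.Queue()

    # What _send_message() does from the bot thread (sync side).
    queue.sync_q.put({"data_type": "whitelist", "data": ["BTC/USDT"]})

    # What _broadcast_queue_data() does on the event loop (async side),
    # minus the channel manager.
    data = await queue.async_q.get()
    print(f"broadcasting: {data}")


asyncio.run(main())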
@@ -263,6 +306,9 @@
             channel = await self.channel_manager.on_connect(websocket)
 
             # Send initial data here
+            # Data is being broadcasted right away as soon as startup,
+            # we may not have to send initial data at all. Further testing
+            # required.
 
             # Keep connection open until explicitly closed, and sleep
             try:
@@ -286,20 +332,15 @@
     # -------------------------------FOLLOWER LOGIC----------------------------
 
-    def start_follower_mode(self):
-        """
-        Start the ReplicateController in Follower mode
-        """
-        logger.info("Starting rpc.replicate in Follower mode")
-
-        self.submit_coroutine(self.follower_loop())
-
     async def follower_loop(self):
         """
         Main follower coroutine
 
-        This starts all of the leader connection coros
+        This starts all of the follower connection coros
         """
+        logger.info("Starting rpc.replicate in Follower mode")
+
         try:
             await self._connect_to_leaders()
         except Exception as e:
@@ -307,79 +348,76 @@
             logger.exception(e)
 
     async def _connect_to_leaders(self):
+        """
+        For each leader in `self.leaders_list` create a connection and
+        listen for data.
+        """
         rpc_lock = asyncio.Lock()
 
         logger.info("Starting connections to Leaders...")
 
-        await asyncio.wait(
-            [
-                self._handle_leader_connection(leader, rpc_lock)
-                for leader in self.leaders_list
-            ]
-        )
+        self.follower_tasks = [
+            self._loop.create_task(self._handle_leader_connection(leader, rpc_lock))
+            for leader in self.leaders_list
+        ]
+        return await asyncio.gather(*self.follower_tasks, return_exceptions=True)
 
     async def _handle_leader_connection(self, leader, lock):
         """
         Given a leader, connect and wait on data. If connection is lost,
         it will attempt to reconnect.
        """
-        url, token = leader["url"], leader["token"]
-
-        websocket_url = f"{url}?token={token}"
-
-        logger.info(f"Attempting to connect to leader at: {url}")
-        # TODO: limit the amount of connection retries
-        while True:
-            try:
-                async with websockets.connect(websocket_url) as ws:
-                    channel = await self.channel_manager.on_connect(ws)
-                    while True:
-                        try:
-                            data = await asyncio.wait_for(
-                                channel.recv(),
-                                timeout=self.reply_timeout
-                            )
-                        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
-                            # We haven't received data yet. Just check the connection and continue.
-                            try:
-                                # ping
-                                ping = await channel.ping()
-                                await asyncio.wait_for(ping, timeout=self.ping_timeout)
-                                logger.info(f"Connection to {url} still alive...")
-                                continue
-                            except Exception:
-                                logger.info(f"Ping error {url} - retrying in {self.sleep_time}s")
-                                asyncio.sleep(self.sleep_time)
-                                break
-
-                        with lock:
-                            # Should we have a lock here?
-                            await self._handle_leader_message(data)
-
-            except socket.gaierror:
-                logger.info(f"Socket error - retrying connection in {self.sleep_time}s")
-                await asyncio.sleep(self.sleep_time)
-                continue
-            except ConnectionRefusedError:
-                logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
-                await asyncio.sleep(self.sleep_time)
-                continue
+        try:
+            url, token = leader["url"], leader["token"]
+            websocket_url = f"{url}?token={token}"
+
+            logger.info(f"Attempting to connect to leader at: {url}")
+            # TODO: limit the amount of connection retries
+            while True:
+                try:
+                    async with websockets.connect(websocket_url) as ws:
+                        channel = await self.channel_manager.on_connect(ws)
+                        while True:
+                            try:
+                                data = await asyncio.wait_for(
+                                    channel.recv(),
+                                    timeout=self.reply_timeout
+                                )
+                            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
+                                # We haven't received data yet. Check the connection and continue.
+                                try:
+                                    # ping
+                                    ping = await channel.ping()
+                                    await asyncio.wait_for(ping, timeout=self.ping_timeout)
+                                    logger.debug(f"Connection to {url} still alive...")
+                                    continue
+                                except Exception:
+                                    logger.info(
+                                        f"Ping error {url} - retrying in {self.sleep_time}s")
+                                    asyncio.sleep(self.sleep_time)
+                                    break
+
+                            async with lock:
+                                # Acquire lock so only 1 coro handling at a time
+                                # as we might call the RPC module in the main thread
+                                await self._handle_leader_message(data)
+
+                except socket.gaierror:
+                    logger.info(f"Socket error - retrying connection in {self.sleep_time}s")
+                    await asyncio.sleep(self.sleep_time)
+                    continue
+                except ConnectionRefusedError:
+                    logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
+                    await asyncio.sleep(self.sleep_time)
+                    continue
+        except asyncio.CancelledError:
+            pass
 
     async def _handle_leader_message(self, message):
-        type = message.get("type")
-
-        message_type_handlers = {
-            LeaderMessageType.Pairlist.value: self._handle_pairlist_message,
-            LeaderMessageType.Dataframe.value: self._handle_dataframe_message
-        }
-
-        handler = message_type_handlers.get(type, self._handle_default_message)
-        return await handler(message)
-
-    async def _handle_default_message(self, message):
-        logger.info(f"Default message handled: {message}")
-
-    async def _handle_pairlist_message(self, message):
-        logger.info(f"Pairlist message handled: {message}")
-
-    async def _handle_dataframe_message(self, message):
-        logger.info(f"Dataframe message handled: {message}")
+        type = message.get('data_type')
+        data = message.get('data')
+
+        if type == LeaderMessageType.whitelist:
+            logger.info(f"Received whitelist from Leader: {data}")

View File

@@ -1,3 +1,4 @@
+import logging
 from typing import Type
 
 from freqtrade.rpc.replicate.proxy import WebSocketProxy
@@ -5,6 +6,9 @@ from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocketSerializer
 from freqtrade.rpc.replicate.types import WebSocketType
 
 
+logger = logging.getLogger(__name__)
+
+
 class WebSocketChannel:
     """
     Object to help facilitate managing a websocket connection
@@ -85,9 +89,12 @@ class ChannelManager:
         """
         if websocket in self.channels.keys():
             channel = self.channels[websocket]
+
+            logger.debug(f"Disconnecting channel - {channel}")
+
             if not channel.is_closed():
                 await channel.close()
 
-            del channel
+            del self.channels[websocket]
 
     async def disconnect_all(self):
         """
@@ -102,5 +109,15 @@
         :param data: The data to send
         """
-        for channel in self.channels.values():
-            await channel.send(data)
+        for websocket, channel in self.channels.items():
+            try:
+                await channel.send(data)
+            except RuntimeError:
+                # Handle cannot send after close cases
+                await self.on_disconnect(websocket)
+
+    def has_channels(self):
+        """
+        Flag for more than 0 channels
+        """
+        return len(self.channels) > 0

View File

@@ -1,11 +1,9 @@
-from typing import TYPE_CHECKING, Union
+from typing import Union
 
 from fastapi import WebSocket as FastAPIWebSocket
 from websockets import WebSocketClientProtocol as WebSocket
 
-if TYPE_CHECKING:
-    from freqtrade.rpc.replicate.types import WebSocketType
+from freqtrade.rpc.replicate.types import WebSocketType
 
 
 class WebSocketProxy:
@@ -21,6 +19,9 @@ class WebSocketProxy:
         """
         Send data on the wrapped websocket
         """
+        if isinstance(data, str):
+            data = data.encode()
+
         if hasattr(self._websocket, "send_bytes"):
             await self._websocket.send_bytes(data)
         else:

View File

@@ -33,10 +33,10 @@ class WebSocketSerializer(ABC):
 class JSONWebSocketSerializer(WebSocketSerializer):
-    def _serialize(self, data: bytes) -> bytes:
+    def _serialize(self, data):
         # json expects string not bytes
-        return json.dumps(data.decode()).encode()
+        return json.dumps(data)
 
-    def _deserialize(self, data: bytes) -> bytes:
+    def _deserialize(self, data):
         # The WebSocketSerializer gives bytes not string
-        return json.loads(data).encode()
+        return json.loads(data)
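With `_serialize`/`_deserialize` now operating on plain objects rather than bytes, the round trip is symmetric. A small stand-alone check (the payload values are examples only; the `WebSocketSerializer` base class itself is not part of this diff):

import json


class JSONWebSocketSerializerSketch:
    # Copy of the two new methods, isolated so the round trip can be run directly.
    def _serialize(self, data):
        return json.dumps(data)

    def _deserialize(self, data):
        return json.loads(data)


s = JSONWebSocketSerializerSketch()
payload = {"data_type": "whitelist", "data": ["BTC/USDT", "ETH/USDT"]}
wire = s._serialize(payload)            # '{"data_type": "whitelist", ...}'
assert s._deserialize(wire) == payload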

View File

@@ -3,7 +3,5 @@ from typing import TypeVar
 from fastapi import WebSocket as FastAPIWebSocket
 from websockets import WebSocketClientProtocol as WebSocket
 
-from freqtrade.rpc.replicate.channel import WebSocketProxy
-
-WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket, WebSocketProxy)
+WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)

View File

@@ -59,6 +59,9 @@ class RPCManager:
                 replicate_rpc = ReplicateController(self._rpc, config, apiserver)
                 self.registered_modules.append(replicate_rpc)
 
+                # Attach the controller to FreqTrade
+                freqtrade.replicate_controller = replicate_rpc
+
             apiserver.start_api()
 
     def cleanup(self) -> None: