minor improvements and pairlist data transmission
This commit is contained in:
parent
9f6bba40af
commit
6834db11f3
@ -7,5 +7,4 @@ class ReplicateModeType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class LeaderMessageType(str, Enum):
|
class LeaderMessageType(str, Enum):
|
||||||
Pairlist = "pairlist"
|
whitelist = "whitelist"
|
||||||
Dataframe = "dataframe"
|
|
||||||
|
@ -75,6 +75,8 @@ class FreqtradeBot(LoggingMixin):
|
|||||||
|
|
||||||
PairLocks.timeframe = self.config['timeframe']
|
PairLocks.timeframe = self.config['timeframe']
|
||||||
|
|
||||||
|
self.replicate_controller = None
|
||||||
|
|
||||||
# RPC runs in separate threads, can start handling external commands just after
|
# RPC runs in separate threads, can start handling external commands just after
|
||||||
# initialization, even before Freqtradebot has a chance to start its throttling,
|
# initialization, even before Freqtradebot has a chance to start its throttling,
|
||||||
# so anything in the Freqtradebot instance should be ready (initialized), including
|
# so anything in the Freqtradebot instance should be ready (initialized), including
|
||||||
@ -264,6 +266,17 @@ class FreqtradeBot(LoggingMixin):
|
|||||||
# Extend active-pair whitelist with pairs of open trades
|
# Extend active-pair whitelist with pairs of open trades
|
||||||
# It ensures that candle (OHLCV) data are downloaded for open trades as well
|
# It ensures that candle (OHLCV) data are downloaded for open trades as well
|
||||||
_whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist])
|
_whitelist.extend([trade.pair for trade in trades if trade.pair not in _whitelist])
|
||||||
|
|
||||||
|
# If replicate leader, broadcast whitelist data
|
||||||
|
if self.replicate_controller:
|
||||||
|
if self.replicate_controller.is_leader():
|
||||||
|
self.replicate_controller.send_message(
|
||||||
|
{
|
||||||
|
"data_type": "whitelist",
|
||||||
|
"data": _whitelist
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return _whitelist
|
return _whitelist
|
||||||
|
|
||||||
def get_free_open_trades(self) -> int:
|
def get_free_open_trades(self) -> int:
|
||||||
|
59
freqtrade/plugins/pairlist/ExternalPairList.py
Normal file
59
freqtrade/plugins/pairlist/ExternalPairList.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
External Pair List provider
|
||||||
|
|
||||||
|
Provides pair list from Leader data
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from freqtrade.plugins.pairlist.IPairList import IPairList
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalPairList(IPairList):
|
||||||
|
|
||||||
|
def __init__(self, exchange, pairlistmanager,
|
||||||
|
config: Dict[str, Any], pairlistconfig: Dict[str, Any],
|
||||||
|
pairlist_pos: int) -> None:
|
||||||
|
super().__init__(exchange, pairlistmanager, config, pairlistconfig, pairlist_pos)
|
||||||
|
|
||||||
|
self._num_assets = self._pairlistconfig.get('num_assets')
|
||||||
|
self._allow_inactive = self._pairlistconfig.get('allow_inactive', False)
|
||||||
|
|
||||||
|
self._leader_pairs: List[str] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needstickers(self) -> bool:
|
||||||
|
"""
|
||||||
|
Boolean property defining if tickers are necessary.
|
||||||
|
If no Pairlist requires tickers, an empty Dict is passed
|
||||||
|
as tickers argument to filter_pairlist
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def short_desc(self) -> str:
|
||||||
|
"""
|
||||||
|
Short whitelist method description - used for startup-messages
|
||||||
|
-> Please overwrite in subclasses
|
||||||
|
"""
|
||||||
|
return f"{self.name}"
|
||||||
|
|
||||||
|
def gen_pairlist(self, tickers: Dict) -> List[str]:
|
||||||
|
"""
|
||||||
|
Generate the pairlist
|
||||||
|
:param tickers: Tickers (from exchange.get_tickers()). May be cached.
|
||||||
|
:return: List of pairs
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def filter_pairlist(self, pairlist: List[str], tickers: Dict) -> List[str]:
|
||||||
|
"""
|
||||||
|
Filters and sorts pairlist and returns the whitelist again.
|
||||||
|
Called on each bot iteration - please use internal caching if necessary
|
||||||
|
:param pairlist: pairlist to filter or sort
|
||||||
|
:param tickers: Tickers (from exchange.get_tickers()). May be cached.
|
||||||
|
:return: new whitelist
|
||||||
|
"""
|
||||||
|
pass
|
@ -5,7 +5,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import socket
|
import socket
|
||||||
from threading import Thread
|
from threading import Event, Thread
|
||||||
from typing import Any, Coroutine, Dict, Union
|
from typing import Any, Coroutine, Dict, Union
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
@ -50,6 +50,9 @@ class ReplicateController(RPCHandler):
|
|||||||
self._thread = None
|
self._thread = None
|
||||||
self._queue = None
|
self._queue = None
|
||||||
|
|
||||||
|
self._stop_event = Event()
|
||||||
|
self._follower_tasks = None
|
||||||
|
|
||||||
self.channel_manager = ChannelManager()
|
self.channel_manager = ChannelManager()
|
||||||
|
|
||||||
self.replicate_config = config.get('replicate', {})
|
self.replicate_config = config.get('replicate', {})
|
||||||
@ -93,10 +96,7 @@ class ReplicateController(RPCHandler):
|
|||||||
|
|
||||||
self.start_threaded_loop()
|
self.start_threaded_loop()
|
||||||
|
|
||||||
if self.mode == ReplicateModeType.follower:
|
self.start()
|
||||||
self.start_follower_mode()
|
|
||||||
elif self.mode == ReplicateModeType.leader:
|
|
||||||
self.start_leader_mode()
|
|
||||||
|
|
||||||
def start_threaded_loop(self):
|
def start_threaded_loop(self):
|
||||||
"""
|
"""
|
||||||
@ -129,6 +129,29 @@ class ReplicateController(RPCHandler):
|
|||||||
logger.error(f"Error running coroutine - {str(e)}")
|
logger.error(f"Error running coroutine - {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def main_loop(self):
|
||||||
|
"""
|
||||||
|
Main loop coro
|
||||||
|
|
||||||
|
Start the loop based on what mode we're in
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.mode == ReplicateModeType.leader:
|
||||||
|
await self.leader_loop()
|
||||||
|
elif self.mode == ReplicateModeType.follower:
|
||||||
|
await self.follower_loop()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._loop.stop()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""
|
||||||
|
Start the controller main loop
|
||||||
|
"""
|
||||||
|
self.submit_coroutine(self.main_loop())
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
"""
|
"""
|
||||||
Cleanup pending module resources.
|
Cleanup pending module resources.
|
||||||
@ -144,27 +167,62 @@ class ReplicateController(RPCHandler):
|
|||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all)
|
self._loop.call_soon_threadsafe(self.channel_manager.disconnect_all)
|
||||||
# This must be called threadsafe, otherwise would not work
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
|
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
|
|
||||||
def send_msg(self, msg: Dict[str, Any]) -> None:
|
def send_msg(self, msg: Dict[str, Any]) -> None:
|
||||||
""" Push message through """
|
"""
|
||||||
|
Support RPC calls
|
||||||
|
"""
|
||||||
if msg["type"] == RPCMessageType.EMIT_DATA:
|
if msg["type"] == RPCMessageType.EMIT_DATA:
|
||||||
self._send_message(
|
self.send_message(
|
||||||
{
|
{
|
||||||
"type": msg["data_type"],
|
"data_type": msg.get("data_type"),
|
||||||
"content": msg["data"]
|
"data": msg.get("data")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def send_message(self, msg: Dict[str, Any]) -> None:
|
||||||
|
""" Push message through """
|
||||||
|
|
||||||
|
if self.channel_manager.has_channels():
|
||||||
|
self._send_message(msg)
|
||||||
|
else:
|
||||||
|
logger.debug("No listening followers, skipping...")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _send_message(self, msg: Dict[Any, Any]):
|
||||||
|
"""
|
||||||
|
Add data to the internal queue to be broadcasted. This func will block
|
||||||
|
if the queue is full. This is meant to be called in the main thread.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._queue:
|
||||||
|
queue = self._queue.sync_q
|
||||||
|
queue.put(msg)
|
||||||
|
else:
|
||||||
|
logger.warning("Can not send data, leader loop has not started yet!")
|
||||||
|
|
||||||
|
def is_leader(self):
|
||||||
|
"""
|
||||||
|
Leader flag
|
||||||
|
"""
|
||||||
|
return self.enabled() and self.mode == ReplicateModeType.leader
|
||||||
|
|
||||||
|
def enabled(self):
|
||||||
|
"""
|
||||||
|
Enabled flag
|
||||||
|
"""
|
||||||
|
return self.replicate_config.get('enabled', False)
|
||||||
|
|
||||||
# ----------------------- LEADER LOGIC ------------------------------
|
# ----------------------- LEADER LOGIC ------------------------------
|
||||||
|
|
||||||
def start_leader_mode(self):
|
async def leader_loop(self):
|
||||||
"""
|
"""
|
||||||
Register the endpoint and start the leader loop
|
Main leader coroutine
|
||||||
|
|
||||||
|
This starts all of the leader coros and registers the endpoint on
|
||||||
|
the ApiServer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info("Running rpc.replicate in Leader mode")
|
logger.info("Running rpc.replicate in Leader mode")
|
||||||
@ -173,30 +231,13 @@ class ReplicateController(RPCHandler):
|
|||||||
logger.info("-" * 15)
|
logger.info("-" * 15)
|
||||||
|
|
||||||
self.register_leader_endpoint()
|
self.register_leader_endpoint()
|
||||||
self.submit_coroutine(self.leader_loop())
|
|
||||||
|
|
||||||
async def leader_loop(self):
|
|
||||||
"""
|
|
||||||
Main leader coroutine
|
|
||||||
At the moment this just broadcasts data that's in the queue to the followers
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
await self._broadcast_queue_data()
|
await self._broadcast_queue_data()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Exception occurred in leader loop: ")
|
logger.error("Exception occurred in leader loop: ")
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
|
||||||
def _send_message(self, data: Dict[Any, Any]):
|
|
||||||
"""
|
|
||||||
Add data to the internal queue to be broadcasted. This func will block
|
|
||||||
if the queue is full. This is meant to be called in the main thread.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self._queue:
|
|
||||||
self._queue.put(data)
|
|
||||||
else:
|
|
||||||
logger.warning("Can not send data, leader loop has not started yet!")
|
|
||||||
|
|
||||||
async def _broadcast_queue_data(self):
|
async def _broadcast_queue_data(self):
|
||||||
"""
|
"""
|
||||||
Loop over queue data and broadcast it
|
Loop over queue data and broadcast it
|
||||||
@ -210,6 +251,8 @@ class ReplicateController(RPCHandler):
|
|||||||
# Get data from queue
|
# Get data from queue
|
||||||
data = await async_queue.get()
|
data = await async_queue.get()
|
||||||
|
|
||||||
|
logger.info(f"Found data - broadcasting: {data}")
|
||||||
|
|
||||||
# Broadcast it to everyone
|
# Broadcast it to everyone
|
||||||
await self.channel_manager.broadcast(data)
|
await self.channel_manager.broadcast(data)
|
||||||
|
|
||||||
@ -263,6 +306,9 @@ class ReplicateController(RPCHandler):
|
|||||||
channel = await self.channel_manager.on_connect(websocket)
|
channel = await self.channel_manager.on_connect(websocket)
|
||||||
|
|
||||||
# Send initial data here
|
# Send initial data here
|
||||||
|
# Data is being broadcasted right away as soon as startup,
|
||||||
|
# we may not have to send initial data at all. Further testing
|
||||||
|
# required.
|
||||||
|
|
||||||
# Keep connection open until explicitly closed, and sleep
|
# Keep connection open until explicitly closed, and sleep
|
||||||
try:
|
try:
|
||||||
@ -286,20 +332,15 @@ class ReplicateController(RPCHandler):
|
|||||||
|
|
||||||
# -------------------------------FOLLOWER LOGIC----------------------------
|
# -------------------------------FOLLOWER LOGIC----------------------------
|
||||||
|
|
||||||
def start_follower_mode(self):
|
|
||||||
"""
|
|
||||||
Start the ReplicateController in Follower mode
|
|
||||||
"""
|
|
||||||
logger.info("Starting rpc.replicate in Follower mode")
|
|
||||||
|
|
||||||
self.submit_coroutine(self.follower_loop())
|
|
||||||
|
|
||||||
async def follower_loop(self):
|
async def follower_loop(self):
|
||||||
"""
|
"""
|
||||||
Main follower coroutine
|
Main follower coroutine
|
||||||
|
|
||||||
This starts all of the leader connection coros
|
This starts all of the follower connection coros
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
logger.info("Starting rpc.replicate in Follower mode")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._connect_to_leaders()
|
await self._connect_to_leaders()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -307,79 +348,76 @@ class ReplicateController(RPCHandler):
|
|||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
|
||||||
async def _connect_to_leaders(self):
|
async def _connect_to_leaders(self):
|
||||||
|
"""
|
||||||
|
For each leader in `self.leaders_list` create a connection and
|
||||||
|
listen for data.
|
||||||
|
"""
|
||||||
rpc_lock = asyncio.Lock()
|
rpc_lock = asyncio.Lock()
|
||||||
|
|
||||||
logger.info("Starting connections to Leaders...")
|
logger.info("Starting connections to Leaders...")
|
||||||
await asyncio.wait(
|
|
||||||
[
|
self.follower_tasks = [
|
||||||
self._handle_leader_connection(leader, rpc_lock)
|
self._loop.create_task(self._handle_leader_connection(leader, rpc_lock))
|
||||||
for leader in self.leaders_list
|
for leader in self.leaders_list
|
||||||
]
|
]
|
||||||
)
|
return await asyncio.gather(*self.follower_tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def _handle_leader_connection(self, leader, lock):
|
async def _handle_leader_connection(self, leader, lock):
|
||||||
"""
|
"""
|
||||||
Given a leader, connect and wait on data. If connection is lost,
|
Given a leader, connect and wait on data. If connection is lost,
|
||||||
it will attempt to reconnect.
|
it will attempt to reconnect.
|
||||||
"""
|
"""
|
||||||
url, token = leader["url"], leader["token"]
|
try:
|
||||||
|
url, token = leader["url"], leader["token"]
|
||||||
|
|
||||||
websocket_url = f"{url}?token={token}"
|
websocket_url = f"{url}?token={token}"
|
||||||
|
|
||||||
logger.info(f"Attempting to connect to leader at: {url}")
|
logger.info(f"Attempting to connect to leader at: {url}")
|
||||||
# TODO: limit the amount of connection retries
|
# TODO: limit the amount of connection retries
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(websocket_url) as ws:
|
async with websockets.connect(websocket_url) as ws:
|
||||||
channel = await self.channel_manager.on_connect(ws)
|
channel = await self.channel_manager.on_connect(ws)
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(
|
|
||||||
channel.recv(),
|
|
||||||
timeout=self.reply_timeout
|
|
||||||
)
|
|
||||||
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
|
|
||||||
# We haven't received data yet. Just check the connection and continue.
|
|
||||||
try:
|
try:
|
||||||
# ping
|
data = await asyncio.wait_for(
|
||||||
ping = await channel.ping()
|
channel.recv(),
|
||||||
await asyncio.wait_for(ping, timeout=self.ping_timeout)
|
timeout=self.reply_timeout
|
||||||
logger.info(f"Connection to {url} still alive...")
|
)
|
||||||
continue
|
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
|
||||||
except Exception:
|
# We haven't received data yet. Check the connection and continue.
|
||||||
logger.info(f"Ping error {url} - retrying in {self.sleep_time}s")
|
try:
|
||||||
asyncio.sleep(self.sleep_time)
|
# ping
|
||||||
break
|
ping = await channel.ping()
|
||||||
|
await asyncio.wait_for(ping, timeout=self.ping_timeout)
|
||||||
|
logger.debug(f"Connection to {url} still alive...")
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
logger.info(
|
||||||
|
f"Ping error {url} - retrying in {self.sleep_time}s")
|
||||||
|
asyncio.sleep(self.sleep_time)
|
||||||
|
break
|
||||||
|
|
||||||
with lock:
|
async with lock:
|
||||||
# Should we have a lock here?
|
# Acquire lock so only 1 coro handling at a time
|
||||||
await self._handle_leader_message(data)
|
# as we might call the RPC module in the main thread
|
||||||
|
await self._handle_leader_message(data)
|
||||||
|
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
logger.info(f"Socket error - retrying connection in {self.sleep_time}s")
|
logger.info(f"Socket error - retrying connection in {self.sleep_time}s")
|
||||||
await asyncio.sleep(self.sleep_time)
|
await asyncio.sleep(self.sleep_time)
|
||||||
continue
|
continue
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
|
logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
|
||||||
await asyncio.sleep(self.sleep_time)
|
await asyncio.sleep(self.sleep_time)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def _handle_leader_message(self, message):
|
async def _handle_leader_message(self, message):
|
||||||
type = message.get("type")
|
type = message.get('data_type')
|
||||||
|
data = message.get('data')
|
||||||
|
|
||||||
message_type_handlers = {
|
if type == LeaderMessageType.whitelist:
|
||||||
LeaderMessageType.Pairlist.value: self._handle_pairlist_message,
|
logger.info(f"Received whitelist from Leader: {data}")
|
||||||
LeaderMessageType.Dataframe.value: self._handle_dataframe_message
|
|
||||||
}
|
|
||||||
|
|
||||||
handler = message_type_handlers.get(type, self._handle_default_message)
|
|
||||||
return await handler(message)
|
|
||||||
|
|
||||||
async def _handle_default_message(self, message):
|
|
||||||
logger.info(f"Default message handled: {message}")
|
|
||||||
|
|
||||||
async def _handle_pairlist_message(self, message):
|
|
||||||
logger.info(f"Pairlist message handled: {message}")
|
|
||||||
|
|
||||||
async def _handle_dataframe_message(self, message):
|
|
||||||
logger.info(f"Dataframe message handled: {message}")
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from freqtrade.rpc.replicate.proxy import WebSocketProxy
|
from freqtrade.rpc.replicate.proxy import WebSocketProxy
|
||||||
@ -5,6 +6,9 @@ from freqtrade.rpc.replicate.serializer import JSONWebSocketSerializer, WebSocke
|
|||||||
from freqtrade.rpc.replicate.types import WebSocketType
|
from freqtrade.rpc.replicate.types import WebSocketType
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketChannel:
|
class WebSocketChannel:
|
||||||
"""
|
"""
|
||||||
Object to help facilitate managing a websocket connection
|
Object to help facilitate managing a websocket connection
|
||||||
@ -85,9 +89,12 @@ class ChannelManager:
|
|||||||
"""
|
"""
|
||||||
if websocket in self.channels.keys():
|
if websocket in self.channels.keys():
|
||||||
channel = self.channels[websocket]
|
channel = self.channels[websocket]
|
||||||
|
|
||||||
|
logger.debug(f"Disconnecting channel - {channel}")
|
||||||
|
|
||||||
if not channel.is_closed():
|
if not channel.is_closed():
|
||||||
await channel.close()
|
await channel.close()
|
||||||
del channel
|
del self.channels[websocket]
|
||||||
|
|
||||||
async def disconnect_all(self):
|
async def disconnect_all(self):
|
||||||
"""
|
"""
|
||||||
@ -102,5 +109,15 @@ class ChannelManager:
|
|||||||
|
|
||||||
:param data: The data to send
|
:param data: The data to send
|
||||||
"""
|
"""
|
||||||
for channel in self.channels.values():
|
for websocket, channel in self.channels.items():
|
||||||
await channel.send(data)
|
try:
|
||||||
|
await channel.send(data)
|
||||||
|
except RuntimeError:
|
||||||
|
# Handle cannot send after close cases
|
||||||
|
await self.on_disconnect(websocket)
|
||||||
|
|
||||||
|
def has_channels(self):
|
||||||
|
"""
|
||||||
|
Flag for more than 0 channels
|
||||||
|
"""
|
||||||
|
return len(self.channels) > 0
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from typing import TYPE_CHECKING, Union
|
from typing import Union
|
||||||
|
|
||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
from websockets import WebSocketClientProtocol as WebSocket
|
from websockets import WebSocketClientProtocol as WebSocket
|
||||||
|
|
||||||
|
from freqtrade.rpc.replicate.types import WebSocketType
|
||||||
if TYPE_CHECKING:
|
|
||||||
from freqtrade.rpc.replicate.types import WebSocketType
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketProxy:
|
class WebSocketProxy:
|
||||||
@ -21,6 +19,9 @@ class WebSocketProxy:
|
|||||||
"""
|
"""
|
||||||
Send data on the wrapped websocket
|
Send data on the wrapped websocket
|
||||||
"""
|
"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode()
|
||||||
|
|
||||||
if hasattr(self._websocket, "send_bytes"):
|
if hasattr(self._websocket, "send_bytes"):
|
||||||
await self._websocket.send_bytes(data)
|
await self._websocket.send_bytes(data)
|
||||||
else:
|
else:
|
||||||
|
@ -33,10 +33,10 @@ class WebSocketSerializer(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class JSONWebSocketSerializer(WebSocketSerializer):
|
class JSONWebSocketSerializer(WebSocketSerializer):
|
||||||
def _serialize(self, data: bytes) -> bytes:
|
def _serialize(self, data):
|
||||||
# json expects string not bytes
|
# json expects string not bytes
|
||||||
return json.dumps(data.decode()).encode()
|
return json.dumps(data)
|
||||||
|
|
||||||
def _deserialize(self, data: bytes) -> bytes:
|
def _deserialize(self, data):
|
||||||
# The WebSocketSerializer gives bytes not string
|
# The WebSocketSerializer gives bytes not string
|
||||||
return json.loads(data).encode()
|
return json.loads(data)
|
||||||
|
@ -3,7 +3,5 @@ from typing import TypeVar
|
|||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
from websockets import WebSocketClientProtocol as WebSocket
|
from websockets import WebSocketClientProtocol as WebSocket
|
||||||
|
|
||||||
from freqtrade.rpc.replicate.channel import WebSocketProxy
|
|
||||||
|
|
||||||
|
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
|
||||||
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket, WebSocketProxy)
|
|
||||||
|
@ -59,6 +59,9 @@ class RPCManager:
|
|||||||
replicate_rpc = ReplicateController(self._rpc, config, apiserver)
|
replicate_rpc = ReplicateController(self._rpc, config, apiserver)
|
||||||
self.registered_modules.append(replicate_rpc)
|
self.registered_modules.append(replicate_rpc)
|
||||||
|
|
||||||
|
# Attach the controller to FreqTrade
|
||||||
|
freqtrade.replicate_controller = replicate_rpc
|
||||||
|
|
||||||
apiserver.start_api()
|
apiserver.start_api()
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user