Refactoring, minor improvements, data provider improvements
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from threading import RLock
|
||||
from typing import Type
|
||||
|
||||
from freqtrade.rpc.external_signal.proxy import WebSocketProxy
|
||||
@@ -63,6 +64,7 @@ class WebSocketChannel:
|
||||
class ChannelManager:
|
||||
def __init__(self):
|
||||
self.channels = dict()
|
||||
self._lock = RLock() # Re-entrant Lock
|
||||
|
||||
async def on_connect(self, websocket: WebSocketType):
|
||||
"""
|
||||
@@ -78,7 +80,9 @@ class ChannelManager:
|
||||
return
|
||||
|
||||
ws_channel = WebSocketChannel(websocket)
|
||||
self.channels[websocket] = ws_channel
|
||||
|
||||
with self._lock:
|
||||
self.channels[websocket] = ws_channel
|
||||
|
||||
return ws_channel
|
||||
|
||||
@@ -88,21 +92,26 @@ class ChannelManager:
|
||||
|
||||
:param websocket: The WebSocket objet attached to the Channel
|
||||
"""
|
||||
if websocket in self.channels.keys():
|
||||
channel = self.channels[websocket]
|
||||
with self._lock:
|
||||
channel = self.channels.get(websocket)
|
||||
if channel:
|
||||
logger.debug(f"Disconnecting channel - {channel}")
|
||||
|
||||
logger.debug(f"Disconnecting channel - {channel}")
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
del self.channels[websocket]
|
||||
del self.channels[websocket]
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""
|
||||
Disconnect all Channels
|
||||
"""
|
||||
for websocket in self.channels.keys():
|
||||
await self.on_disconnect(websocket)
|
||||
with self._lock:
|
||||
for websocket, channel in self.channels.items():
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
|
||||
self.channels = dict()
|
||||
|
||||
async def broadcast(self, data):
|
||||
"""
|
||||
@@ -110,12 +119,13 @@ class ChannelManager:
|
||||
|
||||
:param data: The data to send
|
||||
"""
|
||||
for websocket, channel in self.channels.items():
|
||||
try:
|
||||
await channel.send(data)
|
||||
except RuntimeError:
|
||||
# Handle cannot send after close cases
|
||||
await self.on_disconnect(websocket)
|
||||
with self._lock:
|
||||
for websocket, channel in self.channels.items():
|
||||
try:
|
||||
await channel.send(data)
|
||||
except RuntimeError:
|
||||
# Handle cannot send after close cases
|
||||
await self.on_disconnect(websocket)
|
||||
|
||||
async def send_direct(self, channel, data):
|
||||
"""
|
||||
|
@@ -6,7 +6,7 @@ import logging
|
||||
import secrets
|
||||
import socket
|
||||
from threading import Thread
|
||||
from typing import Any, Coroutine, Dict, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, Union
|
||||
|
||||
import websockets
|
||||
from fastapi import Depends
|
||||
@@ -56,8 +56,13 @@ class ExternalSignalController(RPCHandler):
|
||||
self._main_task = None
|
||||
self._sub_tasks = None
|
||||
|
||||
self.channel_manager = ChannelManager()
|
||||
self._message_handlers = {
|
||||
LeaderMessageType.pairlist: self._rpc._handle_pairlist_message,
|
||||
LeaderMessageType.analyzed_df: self._rpc._handle_analyzed_df_message,
|
||||
LeaderMessageType.default: self._rpc._handle_default_message
|
||||
}
|
||||
|
||||
self.channel_manager = ChannelManager()
|
||||
self.external_signal_config = config.get('external_signal', {})
|
||||
|
||||
# What the config should look like
|
||||
@@ -89,6 +94,8 @@ class ExternalSignalController(RPCHandler):
|
||||
self.ping_timeout = self.external_signal_config.get('follower_ping_timeout', 2)
|
||||
self.sleep_time = self.external_signal_config.get('follower_sleep_time', 5)
|
||||
|
||||
# Validate external_signal_config here?
|
||||
|
||||
if self.mode == ExternalSignalModeType.follower and len(self.leaders_list) == 0:
|
||||
raise ValueError("You must specify at least 1 leader in follower mode.")
|
||||
|
||||
@@ -99,7 +106,6 @@ class ExternalSignalController(RPCHandler):
|
||||
default_api_key = secrets.token_urlsafe(16)
|
||||
self.secret_api_key = self.external_signal_config.get('api_token', default_api_key)
|
||||
|
||||
self.start_threaded_loop()
|
||||
self.start()
|
||||
|
||||
def is_leader(self):
|
||||
@@ -114,6 +120,12 @@ class ExternalSignalController(RPCHandler):
|
||||
"""
|
||||
return self.external_signal_config.get('enabled', False)
|
||||
|
||||
def num_leaders(self):
|
||||
"""
|
||||
The number of leaders we should be connected to
|
||||
"""
|
||||
return len(self.leaders_list)
|
||||
|
||||
def start_threaded_loop(self):
|
||||
"""
|
||||
Start the main internal loop in another thread to run coroutines
|
||||
@@ -144,6 +156,7 @@ class ExternalSignalController(RPCHandler):
|
||||
"""
|
||||
Start the controller main loop
|
||||
"""
|
||||
self.start_threaded_loop()
|
||||
self._main_task = self.submit_coroutine(self.main())
|
||||
|
||||
async def shutdown(self):
|
||||
@@ -242,23 +255,20 @@ class ExternalSignalController(RPCHandler):
|
||||
async def send_initial_data(self, channel):
|
||||
logger.info("Sending initial data through channel")
|
||||
|
||||
# We first send pairlist data
|
||||
# We should move this to a func in the RPC object
|
||||
initial_data = {
|
||||
"data_type": LeaderMessageType.pairlist,
|
||||
"data": self.freqtrade.pairlists.whitelist
|
||||
}
|
||||
data = self._rpc._initial_leader_data()
|
||||
|
||||
await channel.send(initial_data)
|
||||
for message in data:
|
||||
await channel.send(message)
|
||||
|
||||
async def _handle_leader_message(self, message: MessageType):
|
||||
"""
|
||||
Handle message received from a Leader
|
||||
"""
|
||||
type = message.get("data_type")
|
||||
type = message.get("data_type", LeaderMessageType.default)
|
||||
data = message.get("data")
|
||||
|
||||
self._rpc._handle_emitted_data(type, data)
|
||||
handler: Callable = self._message_handlers[type]
|
||||
handler(type, data)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
@@ -1,5 +1,8 @@
|
||||
from pandas import DataFrame
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
|
||||
from freqtrade.enums.signaltype import SignalTagType, SignalType
|
||||
|
||||
|
||||
async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
if (
|
||||
@@ -8,3 +11,12 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def remove_entry_exit_signals(dataframe: DataFrame):
|
||||
dataframe[SignalType.ENTER_LONG.value] = 0
|
||||
dataframe[SignalType.EXIT_LONG.value] = 0
|
||||
dataframe[SignalType.ENTER_SHORT.value] = 0
|
||||
dataframe[SignalType.EXIT_SHORT.value] = 0
|
||||
dataframe[SignalTagType.ENTER_TAG.value] = None
|
||||
dataframe[SignalTagType.EXIT_TAG.value] = None
|
||||
|
Reference in New Issue
Block a user