Refactoring, minor improvements, data provider improvements

This commit is contained in:
Timothy Pogue
2022-08-26 23:40:13 -06:00
parent a998d6d773
commit 2b5f067877
14 changed files with 218 additions and 98 deletions

View File

@@ -1,4 +1,5 @@
import logging
from threading import RLock
from typing import Type
from freqtrade.rpc.external_signal.proxy import WebSocketProxy
@@ -63,6 +64,7 @@ class WebSocketChannel:
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
"""
@@ -78,7 +80,9 @@ class ChannelManager:
return
ws_channel = WebSocketChannel(websocket)
self.channels[websocket] = ws_channel
with self._lock:
self.channels[websocket] = ws_channel
return ws_channel
@@ -88,21 +92,26 @@ class ChannelManager:
:param websocket: The WebSocket objet attached to the Channel
"""
if websocket in self.channels.keys():
channel = self.channels[websocket]
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.debug(f"Disconnecting channel - {channel}")
logger.debug(f"Disconnecting channel - {channel}")
if not channel.is_closed():
await channel.close()
if not channel.is_closed():
await channel.close()
del self.channels[websocket]
del self.channels[websocket]
async def disconnect_all(self):
"""
Disconnect all Channels
"""
for websocket in self.channels.keys():
await self.on_disconnect(websocket)
with self._lock:
for websocket, channel in self.channels.items():
if not channel.is_closed():
await channel.close()
self.channels = dict()
async def broadcast(self, data):
"""
@@ -110,12 +119,13 @@ class ChannelManager:
:param data: The data to send
"""
for websocket, channel in self.channels.items():
try:
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
with self._lock:
for websocket, channel in self.channels.items():
try:
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):
"""

View File

@@ -6,7 +6,7 @@ import logging
import secrets
import socket
from threading import Thread
from typing import Any, Coroutine, Dict, Union
from typing import Any, Callable, Coroutine, Dict, Union
import websockets
from fastapi import Depends
@@ -56,8 +56,13 @@ class ExternalSignalController(RPCHandler):
self._main_task = None
self._sub_tasks = None
self.channel_manager = ChannelManager()
self._message_handlers = {
LeaderMessageType.pairlist: self._rpc._handle_pairlist_message,
LeaderMessageType.analyzed_df: self._rpc._handle_analyzed_df_message,
LeaderMessageType.default: self._rpc._handle_default_message
}
self.channel_manager = ChannelManager()
self.external_signal_config = config.get('external_signal', {})
# What the config should look like
@@ -89,6 +94,8 @@ class ExternalSignalController(RPCHandler):
self.ping_timeout = self.external_signal_config.get('follower_ping_timeout', 2)
self.sleep_time = self.external_signal_config.get('follower_sleep_time', 5)
# Validate external_signal_config here?
if self.mode == ExternalSignalModeType.follower and len(self.leaders_list) == 0:
raise ValueError("You must specify at least 1 leader in follower mode.")
@@ -99,7 +106,6 @@ class ExternalSignalController(RPCHandler):
default_api_key = secrets.token_urlsafe(16)
self.secret_api_key = self.external_signal_config.get('api_token', default_api_key)
self.start_threaded_loop()
self.start()
def is_leader(self):
@@ -114,6 +120,12 @@ class ExternalSignalController(RPCHandler):
"""
return self.external_signal_config.get('enabled', False)
def num_leaders(self):
"""
The number of leaders we should be connected to
"""
return len(self.leaders_list)
def start_threaded_loop(self):
"""
Start the main internal loop in another thread to run coroutines
@@ -144,6 +156,7 @@ class ExternalSignalController(RPCHandler):
"""
Start the controller main loop
"""
self.start_threaded_loop()
self._main_task = self.submit_coroutine(self.main())
async def shutdown(self):
@@ -242,23 +255,20 @@ class ExternalSignalController(RPCHandler):
async def send_initial_data(self, channel):
logger.info("Sending initial data through channel")
# We first send pairlist data
# We should move this to a func in the RPC object
initial_data = {
"data_type": LeaderMessageType.pairlist,
"data": self.freqtrade.pairlists.whitelist
}
data = self._rpc._initial_leader_data()
await channel.send(initial_data)
for message in data:
await channel.send(message)
async def _handle_leader_message(self, message: MessageType):
"""
Handle message received from a Leader
"""
type = message.get("data_type")
type = message.get("data_type", LeaderMessageType.default)
data = message.get("data")
self._rpc._handle_emitted_data(type, data)
handler: Callable = self._message_handlers[type]
handler(type, data)
# ----------------------------------------------------------------------

View File

@@ -1,5 +1,8 @@
from pandas import DataFrame
from starlette.websockets import WebSocket, WebSocketState
from freqtrade.enums.signaltype import SignalTagType, SignalType
async def is_websocket_alive(ws: WebSocket) -> bool:
if (
@@ -8,3 +11,12 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
):
return True
return False
def remove_entry_exit_signals(dataframe: DataFrame):
dataframe[SignalType.ENTER_LONG.value] = 0
dataframe[SignalType.EXIT_LONG.value] = 0
dataframe[SignalType.ENTER_SHORT.value] = 0
dataframe[SignalType.EXIT_SHORT.value] = 0
dataframe[SignalTagType.ENTER_TAG.value] = None
dataframe[SignalTagType.EXIT_TAG.value] = None