refactor broadcasting to queue per client, only send most recent candles

This commit is contained in:
Timothy Pogue 2022-10-08 18:20:07 -06:00
parent e337d4b78a
commit f9b3b0ef77
5 changed files with 83 additions and 27 deletions

View File

@ -9,7 +9,7 @@ from collections import deque
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from pandas import DataFrame
from pandas import DataFrame, concat
from freqtrade.configuration import TimeRange
from freqtrade.constants import Config, ListPairsWithTimeframes, PairWithTimeframe
@ -118,13 +118,13 @@ class DataProvider:
'type': RPCMessageType.ANALYZED_DF,
'data': {
'key': pair_key,
'df': dataframe,
'df': dataframe.tail(1),
'la': datetime.now(timezone.utc)
}
}
)
def _add_external_df(
def _add_producer_df(
self,
pair: str,
dataframe: DataFrame,
@ -147,7 +147,16 @@ class DataProvider:
_last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed)
if pair_key not in self.__producer_pairs_df[producer_name]:
# This is the first message, set the dataframe in that pair key
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed)
else:
# These are new candles, append them to the dataframe
existing_df, _ = self.__producer_pairs_df[producer_name][pair_key]
existing_df = self._append_candle_to_dataframe(existing_df, dataframe)
self.__producer_pairs_df[producer_name][pair_key] = (existing_df, _last_analyzed)
logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.")
def get_producer_df(
@ -184,6 +193,24 @@ class DataProvider:
df, la = self.__producer_pairs_df[producer_name][pair_key]
return (df.copy(), la)
def _append_candle_to_dataframe(self, existing: DataFrame, new: DataFrame):
"""
Append the `new` dataframe to the `existing` dataframe
:param existing: The full dataframe you want appended to
:param new: The new dataframe containing the data you want appended
:returns:The dataframe with the new data in it
"""
if existing.iloc[-1]['date'] != new.iloc[-1]['date']:
existing = concat([existing, new])
# Only keep the last 1000 candles in memory
# TODO: Do this better
if len(existing) > 1000:
existing = existing[-1000:]
return existing
def add_pairlisthandler(self, pairlists) -> None:
"""
Allow adding pairlisthandler after initialization

View File

@ -1,6 +1,8 @@
import asyncio
import logging
from typing import Any, Dict
import websockets
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError
@ -90,6 +92,20 @@ async def _process_consumer_request(
await channel.send(response.dict(exclude_none=True))
async def relay(channel, queue):
"""
Relay messages in the queue to the channel
"""
while True:
message = await queue.get()
try:
await channel.send(message)
queue.task_done()
except RuntimeError:
# What do we do here?
return
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
@ -100,12 +116,13 @@ async def message_endpoint(
"""
Message WebSocket endpoint, facilitates sending RPC messages
"""
relay_task = None
try:
channel = await channel_manager.on_connect(ws)
channel, queue = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}")
relay_task = asyncio.create_task(relay(channel, queue))
# Keep connection open until explicitly closed, and process requests
try:
@ -115,26 +132,32 @@ async def message_endpoint(
# Process the request here
await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect:
except (
WebSocketDisconnect,
websockets.exceptions.ConnectionClosed
):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.exception(e)
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
finally:
relay_task.cancel()
await channel_manager.on_disconnect(ws)
else:
await ws.close()
except RuntimeError:
# WebSocket was closed
await channel_manager.on_disconnect(ws)
# We don't want to log these
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
await channel_manager.on_disconnect(ws)
if relay_task:
relay_task.cancel()

View File

@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
use_colors=False,
log_config=None,
access_log=True if verbosity != 'error' else False,
ws_ping_interval=None
)
try:
self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging
from threading import RLock
from typing import List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
@ -52,7 +53,7 @@ class WebSocketChannel:
"""
Send data on the wrapped websocket
"""
await self._wrapped_ws.send(data)
return await self._wrapped_ws.send(data)
async def recv(self):
"""
@ -115,11 +116,12 @@ class ChannelManager:
return
ws_channel = WebSocketChannel(websocket)
ws_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
with self._lock:
self.channels[websocket] = ws_channel
self.channels[websocket] = (ws_channel, ws_queue)
return ws_channel
return ws_channel, ws_queue
async def on_disconnect(self, websocket: WebSocketType):
"""
@ -128,7 +130,7 @@ class ChannelManager:
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
channel, _ = self.channels.get(websocket, (None, None))
if channel:
if not channel.is_closed():
await channel.close()
@ -140,7 +142,7 @@ class ChannelManager:
Disconnect all Channels
"""
with self._lock:
for websocket, channel in self.channels.copy().items():
for websocket, (channel, _) in self.channels.copy().items():
if not channel.is_closed():
await channel.close()
@ -154,13 +156,12 @@ class ChannelManager:
"""
with self._lock:
message_type = data.get('type')
for websocket, channel in self.channels.copy().items():
try:
if channel.subscribed_to(message_type):
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
for websocket, (channel, queue) in self.channels.copy().items():
if channel.subscribed_to(message_type):
if not queue.full():
queue.put_nowait(data)
else:
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):
"""

View File

@ -62,7 +62,7 @@ class ExternalMessageConsumer:
self.enabled = self._emc_config.get('enabled', False)
self.producers: List[Producer] = self._emc_config.get('producers', [])
self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds
self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
@ -182,7 +182,11 @@ class ExternalMessageConsumer:
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
# This will raise InvalidURI if the url is bad
async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws:
async with websockets.connect(
ws_url,
max_size=self.message_size_limit,
ping_interval=None
) as ws:
channel = WebSocketChannel(ws, channel_id=name)
logger.info(f"Producer connection success - {channel}")
@ -325,7 +329,7 @@ class ExternalMessageConsumer:
df = remove_entry_exit_signals(df)
# Add the dataframe to the dataprovider
self._dp._add_external_df(pair, df,
self._dp._add_producer_df(pair, df,
last_analyzed=la,
timeframe=timeframe,
candle_type=candle_type,