refactor broadcasting to queue per client, only send most recent candles

This commit is contained in:
Timothy Pogue 2022-10-08 18:20:07 -06:00
parent e337d4b78a
commit f9b3b0ef77
5 changed files with 83 additions and 27 deletions

View File

@ -9,7 +9,7 @@ from collections import deque
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from pandas import DataFrame from pandas import DataFrame, concat
from freqtrade.configuration import TimeRange from freqtrade.configuration import TimeRange
from freqtrade.constants import Config, ListPairsWithTimeframes, PairWithTimeframe from freqtrade.constants import Config, ListPairsWithTimeframes, PairWithTimeframe
@ -118,13 +118,13 @@ class DataProvider:
'type': RPCMessageType.ANALYZED_DF, 'type': RPCMessageType.ANALYZED_DF,
'data': { 'data': {
'key': pair_key, 'key': pair_key,
'df': dataframe, 'df': dataframe.tail(1),
'la': datetime.now(timezone.utc) 'la': datetime.now(timezone.utc)
} }
} }
) )
def _add_external_df( def _add_producer_df(
self, self,
pair: str, pair: str,
dataframe: DataFrame, dataframe: DataFrame,
@ -147,7 +147,16 @@ class DataProvider:
_last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed _last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed) if pair_key not in self.__producer_pairs_df[producer_name]:
# This is the first message, set the dataframe in that pair key
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed)
else:
# These are new candles, append them to the dataframe
existing_df, _ = self.__producer_pairs_df[producer_name][pair_key]
existing_df = self._append_candle_to_dataframe(existing_df, dataframe)
self.__producer_pairs_df[producer_name][pair_key] = (existing_df, _last_analyzed)
logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.") logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.")
def get_producer_df( def get_producer_df(
@ -184,6 +193,24 @@ class DataProvider:
df, la = self.__producer_pairs_df[producer_name][pair_key] df, la = self.__producer_pairs_df[producer_name][pair_key]
return (df.copy(), la) return (df.copy(), la)
def _append_candle_to_dataframe(self, existing: DataFrame, new: DataFrame):
"""
Append the `new` dataframe to the `existing` dataframe
:param existing: The full dataframe you want appended to
:param new: The new dataframe containing the data you want appended
:returns:The dataframe with the new data in it
"""
if existing.iloc[-1]['date'] != new.iloc[-1]['date']:
existing = concat([existing, new])
# Only keep the last 1000 candles in memory
# TODO: Do this better
if len(existing) > 1000:
existing = existing[-1000:]
return existing
def add_pairlisthandler(self, pairlists) -> None: def add_pairlisthandler(self, pairlists) -> None:
""" """
Allow adding pairlisthandler after initialization Allow adding pairlisthandler after initialization

View File

@ -1,6 +1,8 @@
import asyncio
import logging import logging
from typing import Any, Dict from typing import Any, Dict
import websockets
from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError from pydantic import ValidationError
@ -90,6 +92,20 @@ async def _process_consumer_request(
await channel.send(response.dict(exclude_none=True)) await channel.send(response.dict(exclude_none=True))
async def relay(channel, queue):
"""
Relay messages in the queue to the channel
"""
while True:
message = await queue.get()
try:
await channel.send(message)
queue.task_done()
except RuntimeError:
# What do we do here?
return
@router.websocket("/message/ws") @router.websocket("/message/ws")
async def message_endpoint( async def message_endpoint(
ws: WebSocket, ws: WebSocket,
@ -100,12 +116,13 @@ async def message_endpoint(
""" """
Message WebSocket endpoint, facilitates sending RPC messages Message WebSocket endpoint, facilitates sending RPC messages
""" """
relay_task = None
try: try:
channel = await channel_manager.on_connect(ws) channel, queue = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws): if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}") logger.info(f"Consumer connected - {channel}")
relay_task = asyncio.create_task(relay(channel, queue))
# Keep connection open until explicitly closed, and process requests # Keep connection open until explicitly closed, and process requests
try: try:
@ -115,26 +132,32 @@ async def message_endpoint(
# Process the request here # Process the request here
await _process_consumer_request(request, channel, rpc) await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect: except (
WebSocketDisconnect,
websockets.exceptions.ConnectionClosed
):
# Handle client disconnects # Handle client disconnects
logger.info(f"Consumer disconnected - {channel}") logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws)
except Exception as e: except Exception as e:
logger.info(f"Consumer connection failed - {channel}") logger.info(f"Consumer connection failed - {channel}")
logger.exception(e) logger.exception(e)
# Handle cases like - # Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent') # RuntimeError('Cannot call "send" once a closed message has been sent')
finally:
relay_task.cancel()
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)
else: else:
await ws.close() await ws.close()
except RuntimeError: except RuntimeError:
# WebSocket was closed # We don't want to log these
await channel_manager.on_disconnect(ws) pass
except Exception as e: except Exception as e:
logger.error(f"Failed to serve - {ws.client}") logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening # Log tracebacks to keep track of what errors are happening
logger.exception(e) logger.exception(e)
finally:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)
if relay_task:
relay_task.cancel()

View File

@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
use_colors=False, use_colors=False,
log_config=None, log_config=None,
access_log=True if verbosity != 'error' else False, access_log=True if verbosity != 'error' else False,
ws_ping_interval=None
) )
try: try:
self._server = UvicornServer(uvconfig) self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging import logging
from threading import RLock from threading import RLock
from typing import List, Optional, Type from typing import Any, Dict, List, Optional, Type
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket from fastapi import WebSocket as FastAPIWebSocket
@ -52,7 +53,7 @@ class WebSocketChannel:
""" """
Send data on the wrapped websocket Send data on the wrapped websocket
""" """
await self._wrapped_ws.send(data) return await self._wrapped_ws.send(data)
async def recv(self): async def recv(self):
""" """
@ -115,11 +116,12 @@ class ChannelManager:
return return
ws_channel = WebSocketChannel(websocket) ws_channel = WebSocketChannel(websocket)
ws_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
with self._lock: with self._lock:
self.channels[websocket] = ws_channel self.channels[websocket] = (ws_channel, ws_queue)
return ws_channel return ws_channel, ws_queue
async def on_disconnect(self, websocket: WebSocketType): async def on_disconnect(self, websocket: WebSocketType):
""" """
@ -128,7 +130,7 @@ class ChannelManager:
:param websocket: The WebSocket objet attached to the Channel :param websocket: The WebSocket objet attached to the Channel
""" """
with self._lock: with self._lock:
channel = self.channels.get(websocket) channel, _ = self.channels.get(websocket, (None, None))
if channel: if channel:
if not channel.is_closed(): if not channel.is_closed():
await channel.close() await channel.close()
@ -140,7 +142,7 @@ class ChannelManager:
Disconnect all Channels Disconnect all Channels
""" """
with self._lock: with self._lock:
for websocket, channel in self.channels.copy().items(): for websocket, (channel, _) in self.channels.copy().items():
if not channel.is_closed(): if not channel.is_closed():
await channel.close() await channel.close()
@ -154,13 +156,12 @@ class ChannelManager:
""" """
with self._lock: with self._lock:
message_type = data.get('type') message_type = data.get('type')
for websocket, channel in self.channels.copy().items(): for websocket, (channel, queue) in self.channels.copy().items():
try: if channel.subscribed_to(message_type):
if channel.subscribed_to(message_type): if not queue.full():
await channel.send(data) queue.put_nowait(data)
except RuntimeError: else:
# Handle cannot send after close cases await self.on_disconnect(websocket)
await self.on_disconnect(websocket)
async def send_direct(self, channel, data): async def send_direct(self, channel, data):
""" """

View File

@ -62,7 +62,7 @@ class ExternalMessageConsumer:
self.enabled = self._emc_config.get('enabled', False) self.enabled = self._emc_config.get('enabled', False)
self.producers: List[Producer] = self._emc_config.get('producers', []) self.producers: List[Producer] = self._emc_config.get('producers', [])
self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
@ -182,7 +182,11 @@ class ExternalMessageConsumer:
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}" ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
# This will raise InvalidURI if the url is bad # This will raise InvalidURI if the url is bad
async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws: async with websockets.connect(
ws_url,
max_size=self.message_size_limit,
ping_interval=None
) as ws:
channel = WebSocketChannel(ws, channel_id=name) channel = WebSocketChannel(ws, channel_id=name)
logger.info(f"Producer connection success - {channel}") logger.info(f"Producer connection success - {channel}")
@ -325,7 +329,7 @@ class ExternalMessageConsumer:
df = remove_entry_exit_signals(df) df = remove_entry_exit_signals(df)
# Add the dataframe to the dataprovider # Add the dataframe to the dataprovider
self._dp._add_external_df(pair, df, self._dp._add_producer_df(pair, df,
last_analyzed=la, last_analyzed=la,
timeframe=timeframe, timeframe=timeframe,
candle_type=candle_type, candle_type=candle_type,