diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 464ea22b2..405beed79 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -2,6 +2,7 @@ import logging from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect +from freqtrade.enums import RPCMessageType from freqtrade.rpc.api_server.deps import get_channel_manager from freqtrade.rpc.api_server.ws.utils import is_websocket_alive @@ -34,7 +35,15 @@ async def message_endpoint( # be a list of topics to subscribe too. List[str] # Maybe allow the consumer to update the topics subscribed # during runtime? - logger.info(f"Consumer request - {request}") + + # If the request isn't a list then skip it + if not isinstance(request, list): + continue + + # Check if all topics listed are an RPCMessageType + if all([any(x.value == topic for x in RPCMessageType) for topic in request]): + logger.debug(f"{ws.client} subscribed to topics: {request}") + channel.set_subscriptions(request) except WebSocketDisconnect: # Handle client disconnects diff --git a/freqtrade/rpc/api_server/ws/__init__.py b/freqtrade/rpc/api_server/ws/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 486e8657b..f24713a77 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,6 +1,6 @@ import logging from threading import RLock -from typing import Type +from typing import List, Type from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer @@ -25,6 +25,8 @@ class WebSocketChannel: # The Serializing class for the WebSocket object self._serializer_cls = serializer_cls + self._subscriptions: List[str] = [] + # Internal event to signify a closed websocket self._closed = False @@ -57,9 +59,28 @@ class WebSocketChannel: self._closed = True - def is_closed(self): + def is_closed(self) -> bool: + """ + Closed flag + """ return self._closed + def set_subscriptions(self, subscriptions: List[str] = []) -> None: + """ + Set which subscriptions this channel is subscribed to + + :param subscriptions: List of subscriptions, List[str] + """ + self._subscriptions = subscriptions + + def subscribed_to(self, message_type: str) -> bool: + """ + Check if this channel is subscribed to the message_type + + :param message_type: The message type to check + """ + return message_type in self._subscriptions + class ChannelManager: def __init__(self): @@ -120,10 +141,12 @@ class ChannelManager: :param data: The data to send """ with self._lock: - logger.debug(f"Broadcasting data: {data}") + message_type = data.get('type') + logger.debug(f"Broadcasting data: {message_type} - {data}") for websocket, channel in self.channels.items(): try: - await channel.send(data) + if channel.subscribed_to(message_type): + await channel.send(data) except RuntimeError: # Handle cannot send after close cases await self.on_disconnect(websocket) diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py index 40cbbfad7..ae2857f0b 100644 --- a/freqtrade/rpc/api_server/ws/serializer.py +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -54,7 +54,7 @@ class ORJSONWebSocketSerializer(WebSocketSerializer): return orjson.dumps(data, option=self.ORJSON_OPTIONS) def _deserialize(self, data): - return orjson.loads(data, option=self.ORJSON_OPTIONS) + return orjson.loads(data) class MsgPackWebSocketSerializer(WebSocketSerializer): diff --git a/scripts/test_ws_client.py b/scripts/test_ws_client.py index caa495a19..2c64ae867 100644 --- a/scripts/test_ws_client.py +++ b/scripts/test_ws_client.py @@ -4,22 +4,31 @@ import socket import websockets +from freqtrade.enums import RPCMessageType +from freqtrade.rpc.api_server.ws.channel import WebSocketChannel + logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) async def _client(): + subscribe_topics = [RPCMessageType.WHITELIST] try: while True: try: url = "ws://localhost:8080/api/v1/message/ws?token=testtoken" async with websockets.connect(url) as ws: + channel = WebSocketChannel(ws) + logger.info("Connection successful") + # Tell the producer we only want these topics + await channel.send(subscribe_topics) + while True: try: data = await asyncio.wait_for( - ws.recv(), + channel.recv(), timeout=5 ) logger.info(f"Data received - {data}") @@ -27,14 +36,14 @@ async def _client(): # We haven't received data yet. Check the connection and continue. try: # ping - ping = await ws.ping() + ping = await channel.ping() await asyncio.wait_for(ping, timeout=2) logger.debug(f"Connection to {url} still alive...") continue except Exception: logger.info( f"Ping error {url} - retrying in 5s") - asyncio.sleep(2) + await asyncio.sleep(2) break except (socket.gaierror, ConnectionRefusedError):