consumer subscriptions, fix serializer bug

This commit is contained in:
Timothy Pogue 2022-08-29 15:48:29 -06:00
parent 7952e0df25
commit 47f7c384fb
5 changed files with 50 additions and 9 deletions

View File

@ -2,6 +2,7 @@ import logging
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from freqtrade.enums import RPCMessageType
from freqtrade.rpc.api_server.deps import get_channel_manager
from freqtrade.rpc.api_server.ws.utils import is_websocket_alive
@ -34,7 +35,15 @@ async def message_endpoint(
# be a list of topics to subscribe too. List[str]
# Maybe allow the consumer to update the topics subscribed
# during runtime?
logger.info(f"Consumer request - {request}")
# If the request isn't a list then skip it
if not isinstance(request, list):
continue
# Check if all topics listed are an RPCMessageType
if all([any(x.value == topic for x in RPCMessageType) for topic in request]):
logger.debug(f"{ws.client} subscribed to topics: {request}")
channel.set_subscriptions(request)
except WebSocketDisconnect:
# Handle client disconnects

View File

View File

@ -1,6 +1,6 @@
import logging
from threading import RLock
from typing import Type
from typing import List, Type
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer
@ -25,6 +25,8 @@ class WebSocketChannel:
# The Serializing class for the WebSocket object
self._serializer_cls = serializer_cls
self._subscriptions: List[str] = []
# Internal event to signify a closed websocket
self._closed = False
@ -57,9 +59,28 @@ class WebSocketChannel:
self._closed = True
def is_closed(self):
def is_closed(self) -> bool:
"""
Closed flag
"""
return self._closed
def set_subscriptions(self, subscriptions: List[str] = []) -> None:
"""
Set which subscriptions this channel is subscribed to
:param subscriptions: List of subscriptions, List[str]
"""
self._subscriptions = subscriptions
def subscribed_to(self, message_type: str) -> bool:
"""
Check if this channel is subscribed to the message_type
:param message_type: The message type to check
"""
return message_type in self._subscriptions
class ChannelManager:
def __init__(self):
@ -120,10 +141,12 @@ class ChannelManager:
:param data: The data to send
"""
with self._lock:
logger.debug(f"Broadcasting data: {data}")
message_type = data.get('type')
logger.debug(f"Broadcasting data: {message_type} - {data}")
for websocket, channel in self.channels.items():
try:
await channel.send(data)
if channel.subscribed_to(message_type):
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)

View File

@ -54,7 +54,7 @@ class ORJSONWebSocketSerializer(WebSocketSerializer):
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
def _deserialize(self, data):
return orjson.loads(data, option=self.ORJSON_OPTIONS)
return orjson.loads(data)
class MsgPackWebSocketSerializer(WebSocketSerializer):

View File

@ -4,22 +4,31 @@ import socket
import websockets
from freqtrade.enums import RPCMessageType
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
async def _client():
subscribe_topics = [RPCMessageType.WHITELIST]
try:
while True:
try:
url = "ws://localhost:8080/api/v1/message/ws?token=testtoken"
async with websockets.connect(url) as ws:
channel = WebSocketChannel(ws)
logger.info("Connection successful")
# Tell the producer we only want these topics
await channel.send(subscribe_topics)
while True:
try:
data = await asyncio.wait_for(
ws.recv(),
channel.recv(),
timeout=5
)
logger.info(f"Data received - {data}")
@ -27,14 +36,14 @@ async def _client():
# We haven't received data yet. Check the connection and continue.
try:
# ping
ping = await ws.ping()
ping = await channel.ping()
await asyncio.wait_for(ping, timeout=2)
logger.debug(f"Connection to {url} still alive...")
continue
except Exception:
logger.info(
f"Ping error {url} - retrying in 5s")
asyncio.sleep(2)
await asyncio.sleep(2)
break
except (socket.gaierror, ConnectionRefusedError):