consumer subscriptions, fix serializer bug
This commit is contained in:
parent
7952e0df25
commit
47f7c384fb
@ -2,6 +2,7 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
from freqtrade.enums import RPCMessageType
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager
|
from freqtrade.rpc.api_server.deps import get_channel_manager
|
||||||
from freqtrade.rpc.api_server.ws.utils import is_websocket_alive
|
from freqtrade.rpc.api_server.ws.utils import is_websocket_alive
|
||||||
|
|
||||||
@ -34,7 +35,15 @@ async def message_endpoint(
|
|||||||
# be a list of topics to subscribe too. List[str]
|
# be a list of topics to subscribe too. List[str]
|
||||||
# Maybe allow the consumer to update the topics subscribed
|
# Maybe allow the consumer to update the topics subscribed
|
||||||
# during runtime?
|
# during runtime?
|
||||||
logger.info(f"Consumer request - {request}")
|
|
||||||
|
# If the request isn't a list then skip it
|
||||||
|
if not isinstance(request, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if all topics listed are an RPCMessageType
|
||||||
|
if all([any(x.value == topic for x in RPCMessageType) for topic in request]):
|
||||||
|
logger.debug(f"{ws.client} subscribed to topics: {request}")
|
||||||
|
channel.set_subscriptions(request)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
# Handle client disconnects
|
# Handle client disconnects
|
||||||
|
0
freqtrade/rpc/api_server/ws/__init__.py
Normal file
0
freqtrade/rpc/api_server/ws/__init__.py
Normal file
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import Type
|
from typing import List, Type
|
||||||
|
|
||||||
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
||||||
from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer
|
from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer
|
||||||
@ -25,6 +25,8 @@ class WebSocketChannel:
|
|||||||
# The Serializing class for the WebSocket object
|
# The Serializing class for the WebSocket object
|
||||||
self._serializer_cls = serializer_cls
|
self._serializer_cls = serializer_cls
|
||||||
|
|
||||||
|
self._subscriptions: List[str] = []
|
||||||
|
|
||||||
# Internal event to signify a closed websocket
|
# Internal event to signify a closed websocket
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
@ -57,9 +59,28 @@ class WebSocketChannel:
|
|||||||
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
|
|
||||||
def is_closed(self):
|
def is_closed(self) -> bool:
|
||||||
|
"""
|
||||||
|
Closed flag
|
||||||
|
"""
|
||||||
return self._closed
|
return self._closed
|
||||||
|
|
||||||
|
def set_subscriptions(self, subscriptions: List[str] = []) -> None:
|
||||||
|
"""
|
||||||
|
Set which subscriptions this channel is subscribed to
|
||||||
|
|
||||||
|
:param subscriptions: List of subscriptions, List[str]
|
||||||
|
"""
|
||||||
|
self._subscriptions = subscriptions
|
||||||
|
|
||||||
|
def subscribed_to(self, message_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this channel is subscribed to the message_type
|
||||||
|
|
||||||
|
:param message_type: The message type to check
|
||||||
|
"""
|
||||||
|
return message_type in self._subscriptions
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -120,10 +141,12 @@ class ChannelManager:
|
|||||||
:param data: The data to send
|
:param data: The data to send
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
logger.debug(f"Broadcasting data: {data}")
|
message_type = data.get('type')
|
||||||
|
logger.debug(f"Broadcasting data: {message_type} - {data}")
|
||||||
for websocket, channel in self.channels.items():
|
for websocket, channel in self.channels.items():
|
||||||
try:
|
try:
|
||||||
await channel.send(data)
|
if channel.subscribed_to(message_type):
|
||||||
|
await channel.send(data)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# Handle cannot send after close cases
|
# Handle cannot send after close cases
|
||||||
await self.on_disconnect(websocket)
|
await self.on_disconnect(websocket)
|
||||||
|
@ -54,7 +54,7 @@ class ORJSONWebSocketSerializer(WebSocketSerializer):
|
|||||||
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
|
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
|
||||||
|
|
||||||
def _deserialize(self, data):
|
def _deserialize(self, data):
|
||||||
return orjson.loads(data, option=self.ORJSON_OPTIONS)
|
return orjson.loads(data)
|
||||||
|
|
||||||
|
|
||||||
class MsgPackWebSocketSerializer(WebSocketSerializer):
|
class MsgPackWebSocketSerializer(WebSocketSerializer):
|
||||||
|
@ -4,22 +4,31 @@ import socket
|
|||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
|
from freqtrade.enums import RPCMessageType
|
||||||
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def _client():
|
async def _client():
|
||||||
|
subscribe_topics = [RPCMessageType.WHITELIST]
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
url = "ws://localhost:8080/api/v1/message/ws?token=testtoken"
|
url = "ws://localhost:8080/api/v1/message/ws?token=testtoken"
|
||||||
async with websockets.connect(url) as ws:
|
async with websockets.connect(url) as ws:
|
||||||
|
channel = WebSocketChannel(ws)
|
||||||
|
|
||||||
logger.info("Connection successful")
|
logger.info("Connection successful")
|
||||||
|
# Tell the producer we only want these topics
|
||||||
|
await channel.send(subscribe_topics)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = await asyncio.wait_for(
|
data = await asyncio.wait_for(
|
||||||
ws.recv(),
|
channel.recv(),
|
||||||
timeout=5
|
timeout=5
|
||||||
)
|
)
|
||||||
logger.info(f"Data received - {data}")
|
logger.info(f"Data received - {data}")
|
||||||
@ -27,14 +36,14 @@ async def _client():
|
|||||||
# We haven't received data yet. Check the connection and continue.
|
# We haven't received data yet. Check the connection and continue.
|
||||||
try:
|
try:
|
||||||
# ping
|
# ping
|
||||||
ping = await ws.ping()
|
ping = await channel.ping()
|
||||||
await asyncio.wait_for(ping, timeout=2)
|
await asyncio.wait_for(ping, timeout=2)
|
||||||
logger.debug(f"Connection to {url} still alive...")
|
logger.debug(f"Connection to {url} still alive...")
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Ping error {url} - retrying in 5s")
|
f"Ping error {url} - retrying in 5s")
|
||||||
asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
break
|
break
|
||||||
|
|
||||||
except (socket.gaierror, ConnectionRefusedError):
|
except (socket.gaierror, ConnectionRefusedError):
|
||||||
|
Loading…
Reference in New Issue
Block a user