Merge pull request #7558 from wizrds/feat/queue-per-client-ws

Refactor broadcasting in Message Websocket
This commit is contained in:
Matthias 2022-10-13 09:52:29 +02:00 committed by GitHub
commit e3ca740704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 21 deletions

View File

@ -4,6 +4,7 @@ from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError
from websockets.exceptions import WebSocketException
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token
@ -102,7 +103,6 @@ async def message_endpoint(
"""
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}")
@ -115,26 +115,31 @@ async def message_endpoint(
# Process the request here
await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect:
except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.exception(e)
except RuntimeError:
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws)
else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close()
except RuntimeError:
# WebSocket was closed
await channel_manager.on_disconnect(ws)
# Do nothing
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
await channel_manager.on_disconnect(ws)

View File

@ -198,10 +198,6 @@ class ApiServer(RPCHandler):
logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it
await self._ws_channel_manager.broadcast(message)
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
await asyncio.sleep(0.001)
except asyncio.CancelledError:
pass
@ -245,6 +241,7 @@ class ApiServer(RPCHandler):
use_colors=False,
log_config=None,
access_log=True if verbosity != 'error' else False,
ws_ping_interval=None # We do this explicitly ourselves
)
try:
self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging
from threading import RLock
from typing import List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
@ -34,6 +35,8 @@ class WebSocketChannel:
self._serializer_cls = serializer_cls
self._subscriptions: List[str] = []
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = False
@ -48,12 +51,18 @@ class WebSocketChannel:
def remote_addr(self):
return self._websocket.remote_addr
async def send(self, data):
async def _send(self, data):
"""
Send data on the wrapped websocket
"""
await self._wrapped_ws.send(data)
async def send(self, data):
"""
Add the data to the queue to be sent
"""
self.queue.put_nowait(data)
async def recv(self):
"""
Receive data on the wrapped websocket
@ -72,6 +81,7 @@ class WebSocketChannel:
"""
self._closed = True
self._relay_task.cancel()
def is_closed(self) -> bool:
"""
@ -95,6 +105,26 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def relay(self):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
"""
while True:
message = await self.queue.get()
try:
await self._send(message)
self.queue.task_done()
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
# 0.001 = 1000/s
await asyncio.sleep(0.001)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
@ -155,11 +185,11 @@ class ChannelManager:
with self._lock:
message_type = data.get('type')
for websocket, channel in self.channels.copy().items():
try:
if channel.subscribed_to(message_type):
if not channel.queue.full():
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
else:
logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):

View File

@ -62,7 +62,7 @@ class ExternalMessageConsumer:
self.enabled = self._emc_config.get('enabled', False)
self.producers: List[Producer] = self._emc_config.get('producers', [])
self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds
self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
@ -174,6 +174,7 @@ class ExternalMessageConsumer:
:param producer: Dictionary containing producer info
:param lock: An asyncio Lock
"""
channel = None
while self._running:
try:
host, port = producer['host'], producer['port']
@ -182,7 +183,11 @@ class ExternalMessageConsumer:
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
# This will raise InvalidURI if the url is bad
async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws:
async with websockets.connect(
ws_url,
max_size=self.message_size_limit,
ping_interval=None
) as ws:
channel = WebSocketChannel(ws, channel_id=name)
logger.info(f"Producer connection success - {channel}")
@ -224,6 +229,10 @@ class ExternalMessageConsumer:
logger.exception(e)
continue
finally:
if channel:
await channel.close()
async def _receive_messages(
self,
channel: WebSocketChannel,