diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index f55b2dbd3..b60210143 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,6 +1,7 @@ import logging from typing import Any, Dict +import websockets from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi.websockets import WebSocket, WebSocketState from pydantic import ValidationError @@ -102,7 +103,6 @@ async def message_endpoint( """ try: channel = await channel_manager.on_connect(ws) - if await is_websocket_alive(ws): logger.info(f"Consumer connected - {channel}") @@ -115,26 +115,34 @@ async def message_endpoint( # Process the request here await _process_consumer_request(request, channel, rpc) - except WebSocketDisconnect: + except ( + WebSocketDisconnect, + websockets.exceptions.WebSocketException + ): # Handle client disconnects logger.info(f"Consumer disconnected - {channel}") - await channel_manager.on_disconnect(ws) - except Exception as e: - logger.info(f"Consumer connection failed - {channel}") - logger.exception(e) + except RuntimeError: # Handle cases like - # RuntimeError('Cannot call "send" once a closed message has been sent') + pass + except Exception as e: + logger.info(f"Consumer connection failed - {channel}") + logger.debug(e, exc_info=e) + finally: await channel_manager.on_disconnect(ws) else: + if channel: + await channel_manager.on_disconnect(ws) await ws.close() except RuntimeError: # WebSocket was closed - await channel_manager.on_disconnect(ws) - + # Do nothing + pass except Exception as e: logger.error(f"Failed to serve - {ws.client}") # Log tracebacks to keep track of what errors are happening logger.exception(e) + finally: await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 53af91477..4a09fd78e 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -245,6 +245,7 @@ class ApiServer(RPCHandler): use_colors=False, log_config=None, access_log=True if verbosity != 'error' else False, + ws_ping_interval=None # We do this explicitly ourselves ) try: self._server = UvicornServer(uvconfig) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 69a32e266..8c6c56d6e 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,6 +1,7 @@ +import asyncio import logging from threading import RLock -from typing import List, Optional, Type +from typing import Any, Dict, List, Optional, Type from uuid import uuid4 from fastapi import WebSocket as FastAPIWebSocket @@ -34,6 +35,8 @@ class WebSocketChannel: self._serializer_cls = serializer_cls self._subscriptions: List[str] = [] + self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + self._relay_task = asyncio.create_task(self.relay()) # Internal event to signify a closed websocket self._closed = False @@ -72,6 +75,7 @@ class WebSocketChannel: """ self._closed = True + self._relay_task.cancel() def is_closed(self) -> bool: """ @@ -95,6 +99,20 @@ class WebSocketChannel: """ return message_type in self._subscriptions + async def relay(self): + """ + Relay messages from the channel's queue and send them out. This is started + as a task. + """ + while True: + message = await self.queue.get() + try: + await self.send(message) + self.queue.task_done() + except RuntimeError: + # The connection was closed, just exit the task + return + class ChannelManager: def __init__(self): @@ -155,12 +173,12 @@ class ChannelManager: with self._lock: message_type = data.get('type') for websocket, channel in self.channels.copy().items(): - try: - if channel.subscribed_to(message_type): - await channel.send(data) - except RuntimeError: - # Handle cannot send after close cases - await self.on_disconnect(websocket) + if channel.subscribed_to(message_type): + if not channel.queue.full(): + channel.queue.put_nowait(data) + else: + logger.info(f"Channel {channel} is too far behind, disconnecting") + await self.on_disconnect(websocket) async def send_direct(self, channel, data): """