diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 01243b0cc..2454646ea 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,16 +1,14 @@ -import asyncio import logging from typing import Any, Dict from fastapi import APIRouter, Depends -from fastapi.websockets import WebSocket, WebSocketDisconnect +from fastapi.websockets import WebSocket from pydantic import ValidationError -from websockets.exceptions import ConnectionClosed from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc -from freqtrade.rpc.api_server.ws import WebSocketChannel +from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) @@ -23,45 +21,20 @@ logger = logging.getLogger(__name__) router = APIRouter() -class WebSocketChannelClosed(Exception): - """ - General WebSocket exception to signal closing the channel - """ - pass - - async def channel_reader(channel: WebSocketChannel, rpc: RPC): """ Iterate over the messages from the channel and process the request """ - try: - async for message in channel: - await _process_consumer_request(message, channel, rpc) - except ( - RuntimeError, - WebSocketDisconnect, - ConnectionClosed - ): - raise WebSocketChannelClosed - except asyncio.CancelledError: - return + async for message in channel: + await _process_consumer_request(message, channel, rpc) async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): """ Iterate over messages in the message stream and send them """ - try: - async for message in message_stream: - await channel.send(message) - except ( - RuntimeError, - WebSocketDisconnect, - ConnectionClosed - ): - raise WebSocketChannelClosed - except asyncio.CancelledError: - return + async for message in message_stream: + await channel.send(message) async def _process_consumer_request( @@ -103,15 +76,11 @@ async def _process_consumer_request( # Format response response = WSWhitelistMessage(data=whitelist) - # Send it back await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: - limit = None - - if data: - # Limit the amount of candles per dataframe to 'limit' or 1500 - limit = max(data.get('limit', 1500), 1500) + # Limit the amount of candles per dataframe to 'limit' or 1500 + limit = min(data.get('limit', 1500), 1500) if data else None # For every pair in the generator, send a separate message for message in rpc._ws_request_analyzed_df(limit): @@ -127,17 +96,8 @@ async def message_endpoint( rpc: RPC = Depends(get_rpc), message_stream: MessageStream = Depends(get_message_stream) ): - async with WebSocketChannel(websocket).connect() as channel: - try: - logger.info(f"Channel connected - {channel}") - - channel_tasks = asyncio.gather( - channel_reader(channel, rpc), - channel_broadcaster(channel, message_stream) - ) - await channel_tasks - except WebSocketChannelClosed: - pass - finally: - logger.info(f"Channel disconnected - {channel}") - channel_tasks.cancel() + async with create_channel(websocket) as channel: + await channel.run_channel_tasks( + channel_reader(channel, rpc), + channel_broadcaster(channel, message_stream) + ) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index f100a46ef..4a9f089d1 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -94,6 +94,7 @@ class ApiServer(RPCHandler): del ApiServer._rpc if self._server and not self._standalone: logger.info("Stopping API Server") + # self._server.force_exit, self._server.should_exit = True, True self._server.cleanup() @classmethod diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 8e248d368..d4d4d6453 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -29,6 +29,7 @@ class WebSocketChannel: # Internal event to signify a closed websocket self._closed = asyncio.Event() + self._send_timeout_high_limit = 2 # The subscribed message types self._subscriptions: List[str] = [] @@ -36,6 +37,9 @@ class WebSocketChannel: # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) + # The async tasks created for the channel + self._channel_tasks: List[asyncio.Task] = [] + def __repr__(self): return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" @@ -51,7 +55,14 @@ class WebSocketChannel: """ Send a message on the wrapped websocket """ - await self._wrapped_ws.send(message) + + # Without this sleep, messages would send to one channel + # first then another after the first one finished. + # With the sleep call, it gives control to the event + # loop to schedule other channel send methods. + await asyncio.sleep(0) + + return await self._wrapped_ws.send(message) async def recv(self): """ @@ -77,7 +88,6 @@ class WebSocketChannel: """ self._closed.set() - self._relay_task.cancel() try: await self._websocket.close() @@ -106,23 +116,68 @@ class WebSocketChannel: """ return message_type in self._subscriptions + async def run_channel_tasks(self, *tasks, **kwargs): + """ + Create and await on the channel tasks unless an exception + was raised, then cancel them all. + + :params *tasks: All coros or tasks to be run concurrently + :param **kwargs: Any extra kwargs to pass to gather + """ + + # Wrap the coros into tasks if they aren't already + self._channel_tasks = [ + task if isinstance(task, asyncio.Task) else asyncio.create_task(task) + for task in tasks + ] + + try: + await asyncio.gather(*self._channel_tasks, **kwargs) + except Exception: + # If an exception occurred, cancel the rest of the tasks and bubble up + # the error that was caught here + await self.cancel_channel_tasks() + raise + + async def cancel_channel_tasks(self): + """ + Cancel and wait on all channel tasks + """ + for task in self._channel_tasks: + task.cancel() + + # Wait for tasks to finish cancelling + try: + await asyncio.wait(self._channel_tasks) + except asyncio.CancelledError: + pass + + self._channel_tasks = [] + async def __aiter__(self): """ Generator for received messages """ - while True: - try: - yield await self.recv() - except Exception: - break + # We can not catch any errors here as websocket.recv is + # the first to catch any disconnects and bubble it up + # so the connection is garbage collected right away + while not self.is_closed(): + yield await self.recv() - @asynccontextmanager - async def connect(self): - """ - Context manager for safely opening and closing the websocket connection - """ - try: - await self.accept() - yield self - finally: - await self.close() + +@asynccontextmanager +async def create_channel(websocket: WebSocketType, **kwargs): + """ + Context manager for safely opening and closing a WebSocketChannel + """ + channel = WebSocketChannel(websocket, **kwargs) + try: + await channel.accept() + logger.info(f"Connected to channel - {channel}") + + yield channel + except Exception: + pass + finally: + await channel.close() + logger.info(f"Disconnected from channel - {channel}") diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py index f77242719..9592908ab 100644 --- a/freqtrade/rpc/api_server/ws/message_stream.py +++ b/freqtrade/rpc/api_server/ws/message_stream.py @@ -17,7 +17,8 @@ class MessageStream: async def subscribe(self): waiter = self._waiter while True: - message, waiter = await waiter + # Shield the future from being cancelled by a task waiting on it + message, waiter = await asyncio.shield(waiter) yield message __aiter__ = subscribe