diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 4a9f089d1..e4eb3895d 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -212,7 +212,6 @@ class ApiServer(RPCHandler): if self._standalone: self._server.run() else: - # self.start_message_queue() self._server.run_in_thread() except Exception: logger.exception("Api server failed to start.") diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 8699de66c..9dea21f3b 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -6,6 +6,9 @@ from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union from uuid import uuid4 +from fastapi import WebSocketDisconnect +from websockets.exceptions import ConnectionClosed + from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) @@ -189,7 +192,16 @@ class WebSocketChannel: task.cancel() # Wait for tasks to finish cancelling - await asyncio.wait(self._channel_tasks) + try: + await task + except ( + asyncio.CancelledError, + WebSocketDisconnect, + ConnectionClosed + ): + pass + except Exception as e: + logger.info(f"Encountered unknown exception: {e}", exc_info=e) self._channel_tasks = [] diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 969728b6f..25d6a32e3 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -57,7 +57,10 @@ def botclient(default_conf, mocker): try: apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(rpc) - yield ftbot, TestClient(apiserver.app) + # We need to use the TestClient as a context manager to + # handle lifespan events correctly + with TestClient(apiserver.app) as client: + yield ftbot, client # Cleanup ... ? finally: if apiserver: @@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog): apiserver.cleanup() assert apiserver._server.cleanup.call_count == 1 assert log_has("Stopping API Server", caplog) - assert log_has("Stopping API Server background tasks", caplog) ApiServer.shutdown() @@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker): with client.websocket_connect(ws_url) as ws: ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) + time.sleep(1) # Check call count is now 1 as we sent a valid subscribe request assert sub_mock.call_count == 1 with client.websocket_connect(ws_url) as ws: ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) + time.sleep(1) # Call count hasn't changed as the subscribe request was invalid assert sub_mock.call_count == 1 @@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog): mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api') apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf))) - apiserver.start_message_queue() - # Give the queue thread time to start - time.sleep(0.2) - # Test message_queue coro receives the message - test_message = {"type": "status", "data": "test"} - apiserver.send_msg(test_message) - time.sleep(0.1) # Not sure how else to wait for the coro to receive the data - assert log_has("Found message of type: status", caplog) + # Start test client context manager to run lifespan events + with TestClient(apiserver.app): + # Test message is published on the Message Stream + test_message = {"type": "status", "data": "test"} + first_waiter = apiserver._message_stream._waiter + apiserver.send_msg(test_message) + assert first_waiter.result()[0] == test_message - # Test if exception logged when error occurs in sending - mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast', - side_effect=Exception) - - apiserver.send_msg(test_message) - time.sleep(0.1) # Not sure how else to wait for the coro to receive the data - assert log_has_re(r"Exception happened in background task.*", caplog) + second_waiter = apiserver._message_stream._waiter + apiserver.send_msg(test_message) + assert first_waiter != second_waiter finally: - apiserver.cleanup() ApiServer.shutdown()