diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 384bd4115..f6eb59f87 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect +from pydantic import ValidationError # fastapi does not make this available through it, so import directly from starlette from starlette.websockets import WebSocketState @@ -9,9 +10,8 @@ from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import get_ws_token from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel -from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, - WSMessageSchema, WSRequestSchema, - WSWhitelistMessage) +from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, + WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC @@ -102,13 +102,11 @@ async def message_endpoint( Message WebSocket endpoint, facilitates sending RPC messages """ try: - if is_websocket_alive(ws): - # TODO: - # Return a channel ID, pass that instead of ws to the rest of the methods - channel = await channel_manager.on_connect(ws) + # TODO: + # Return a channel ID, pass that instead of ws to the rest of the methods + channel = await channel_manager.on_connect(ws) - if not channel: - return + if await is_websocket_alive(ws): logger.info(f"Consumer connected - {channel}") @@ -131,6 +129,9 @@ async def message_endpoint( # RuntimeError('Cannot call "send" once a closed message has been sent') await channel_manager.on_disconnect(ws) + else: + ws.close() + except Exception as e: logger.error(f"Failed to serve - {ws.client}") # Log tracebacks to keep track of what errors are happening diff --git a/freqtrade/rpc/api_server/ws/schema.py b/freqtrade/rpc/api_server/ws_schemas.py similarity index 92% rename from freqtrade/rpc/api_server/ws/schema.py rename to freqtrade/rpc/api_server/ws_schemas.py index 0baa8d233..255226d84 100644 --- a/freqtrade/rpc/api_server/ws/schema.py +++ b/freqtrade/rpc/api_server/ws_schemas.py @@ -2,15 +2,12 @@ from datetime import datetime from typing import Any, Dict, List, Optional from pandas import DataFrame -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from freqtrade.constants import PairWithTimeframe from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType -__all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError') - - class BaseArbitraryModel(BaseModel): class Config: arbitrary_types_allowed = True diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index 525f4282c..abeedb0a4 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -11,15 +11,16 @@ from threading import Thread from typing import TYPE_CHECKING, Any, Dict, List import websockets +from pydantic import ValidationError from freqtrade.data.dataprovider import DataProvider from freqtrade.enums import RPCMessageType from freqtrade.misc import remove_entry_exit_signals from freqtrade.rpc.api_server.ws.channel import WebSocketChannel -from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, - WSAnalyzedDFRequest, WSMessageSchema, - WSRequestSchema, WSSubscribeRequest, - WSWhitelistMessage, WSWhitelistRequest) +from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSAnalyzedDFRequest, + WSMessageSchema, WSRequestSchema, + WSSubscribeRequest, WSWhitelistMessage, + WSWhitelistRequest) if TYPE_CHECKING: diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index ccfe31424..de093b66f 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -56,8 +56,8 @@ def botclient(default_conf, mocker): apiserver.add_rpc_handler(rpc) yield ftbot, TestClient(apiserver.app) # Cleanup ... ? - apiserver.cleanup() finally: + apiserver.cleanup() ApiServer.shutdown() @@ -171,7 +171,7 @@ def test_api_ws_auth(botclient): url = f"/api/v1/message/ws?token={good_token}" with client.websocket_connect(url) as websocket: - websocket.send(1) + pass def test_api_unauthorized(botclient): @@ -1685,3 +1685,23 @@ def test_health(botclient): ret = rc.json() assert ret['last_process_ts'] == 0 assert ret['last_process'] == '1970-01-01T00:00:00+00:00' + + +def test_api_ws_subscribe(botclient, mocker): + ftbot, client = botclient + ws_url = f"/api/v1/message/ws?token={_TEST_WS_TOKEN}" + + sub_mock = mocker.patch( + 'freqtrade.rpc.api_server.ws.channel.WebSocketChannel.set_subscriptions', MagicMock()) + + with client.websocket_connect(ws_url) as ws: + ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) + + # Check call count is now 1 as we sent a valid subscribe request + assert sub_mock.call_count == 1 + + with client.websocket_connect(ws_url) as ws: + ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) + + # Call count hasn't changed as the subscribe request was invalid + assert sub_mock.call_count == 1