moved ws_schemas, first ws tests
This commit is contained in:
		| @@ -2,6 +2,7 @@ import logging | |||||||
| from typing import Any, Dict | from typing import Any, Dict | ||||||
|  |  | ||||||
| from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect | from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect | ||||||
|  | from pydantic import ValidationError | ||||||
| # fastapi does not make this available through it, so import directly from starlette | # fastapi does not make this available through it, so import directly from starlette | ||||||
| from starlette.websockets import WebSocketState | from starlette.websockets import WebSocketState | ||||||
|  |  | ||||||
| @@ -9,9 +10,8 @@ from freqtrade.enums import RPCMessageType, RPCRequestType | |||||||
| from freqtrade.rpc.api_server.api_auth import get_ws_token | from freqtrade.rpc.api_server.api_auth import get_ws_token | ||||||
| from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc | from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc | ||||||
| from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | ||||||
| from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, | from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, | ||||||
|                                                 WSMessageSchema, WSRequestSchema, |                                                  WSRequestSchema, WSWhitelistMessage) | ||||||
|                                                 WSWhitelistMessage) |  | ||||||
| from freqtrade.rpc.rpc import RPC | from freqtrade.rpc.rpc import RPC | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -102,13 +102,11 @@ async def message_endpoint( | |||||||
|     Message WebSocket endpoint, facilitates sending RPC messages |     Message WebSocket endpoint, facilitates sending RPC messages | ||||||
|     """ |     """ | ||||||
|     try: |     try: | ||||||
|         if is_websocket_alive(ws): |  | ||||||
|         # TODO: |         # TODO: | ||||||
|         # Return a channel ID, pass that instead of ws to the rest of the methods |         # Return a channel ID, pass that instead of ws to the rest of the methods | ||||||
|         channel = await channel_manager.on_connect(ws) |         channel = await channel_manager.on_connect(ws) | ||||||
|  |  | ||||||
|             if not channel: |         if await is_websocket_alive(ws): | ||||||
|                 return |  | ||||||
|  |  | ||||||
|             logger.info(f"Consumer connected - {channel}") |             logger.info(f"Consumer connected - {channel}") | ||||||
|  |  | ||||||
| @@ -131,6 +129,9 @@ async def message_endpoint( | |||||||
|                 # RuntimeError('Cannot call "send" once a closed message has been sent') |                 # RuntimeError('Cannot call "send" once a closed message has been sent') | ||||||
|                 await channel_manager.on_disconnect(ws) |                 await channel_manager.on_disconnect(ws) | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             ws.close() | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error(f"Failed to serve - {ws.client}") |         logger.error(f"Failed to serve - {ws.client}") | ||||||
|         # Log tracebacks to keep track of what errors are happening |         # Log tracebacks to keep track of what errors are happening | ||||||
|   | |||||||
| @@ -2,15 +2,12 @@ from datetime import datetime | |||||||
| from typing import Any, Dict, List, Optional | from typing import Any, Dict, List, Optional | ||||||
| 
 | 
 | ||||||
| from pandas import DataFrame | from pandas import DataFrame | ||||||
| from pydantic import BaseModel, ValidationError | from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from freqtrade.constants import PairWithTimeframe | from freqtrade.constants import PairWithTimeframe | ||||||
| from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType | from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError') |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class BaseArbitraryModel(BaseModel): | class BaseArbitraryModel(BaseModel): | ||||||
|     class Config: |     class Config: | ||||||
|         arbitrary_types_allowed = True |         arbitrary_types_allowed = True | ||||||
| @@ -11,15 +11,16 @@ from threading import Thread | |||||||
| from typing import TYPE_CHECKING, Any, Dict, List | from typing import TYPE_CHECKING, Any, Dict, List | ||||||
|  |  | ||||||
| import websockets | import websockets | ||||||
|  | from pydantic import ValidationError | ||||||
|  |  | ||||||
| from freqtrade.data.dataprovider import DataProvider | from freqtrade.data.dataprovider import DataProvider | ||||||
| from freqtrade.enums import RPCMessageType | from freqtrade.enums import RPCMessageType | ||||||
| from freqtrade.misc import remove_entry_exit_signals | from freqtrade.misc import remove_entry_exit_signals | ||||||
| from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | ||||||
| from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, | from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSAnalyzedDFRequest, | ||||||
|                                                 WSAnalyzedDFRequest, WSMessageSchema, |                                                  WSMessageSchema, WSRequestSchema, | ||||||
|                                                 WSRequestSchema, WSSubscribeRequest, |                                                  WSSubscribeRequest, WSWhitelistMessage, | ||||||
|                                                 WSWhitelistMessage, WSWhitelistRequest) |                                                  WSWhitelistRequest) | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|   | |||||||
| @@ -56,8 +56,8 @@ def botclient(default_conf, mocker): | |||||||
|         apiserver.add_rpc_handler(rpc) |         apiserver.add_rpc_handler(rpc) | ||||||
|         yield ftbot, TestClient(apiserver.app) |         yield ftbot, TestClient(apiserver.app) | ||||||
|         # Cleanup ... ? |         # Cleanup ... ? | ||||||
|         apiserver.cleanup() |  | ||||||
|     finally: |     finally: | ||||||
|  |         apiserver.cleanup() | ||||||
|         ApiServer.shutdown() |         ApiServer.shutdown() | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -171,7 +171,7 @@ def test_api_ws_auth(botclient): | |||||||
|     url = f"/api/v1/message/ws?token={good_token}" |     url = f"/api/v1/message/ws?token={good_token}" | ||||||
|  |  | ||||||
|     with client.websocket_connect(url) as websocket: |     with client.websocket_connect(url) as websocket: | ||||||
|         websocket.send(1) |         pass | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_api_unauthorized(botclient): | def test_api_unauthorized(botclient): | ||||||
| @@ -1685,3 +1685,23 @@ def test_health(botclient): | |||||||
|     ret = rc.json() |     ret = rc.json() | ||||||
|     assert ret['last_process_ts'] == 0 |     assert ret['last_process_ts'] == 0 | ||||||
|     assert ret['last_process'] == '1970-01-01T00:00:00+00:00' |     assert ret['last_process'] == '1970-01-01T00:00:00+00:00' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_api_ws_subscribe(botclient, mocker): | ||||||
|  |     ftbot, client = botclient | ||||||
|  |     ws_url = f"/api/v1/message/ws?token={_TEST_WS_TOKEN}" | ||||||
|  |  | ||||||
|  |     sub_mock = mocker.patch( | ||||||
|  |         'freqtrade.rpc.api_server.ws.channel.WebSocketChannel.set_subscriptions', MagicMock()) | ||||||
|  |  | ||||||
|  |     with client.websocket_connect(ws_url) as ws: | ||||||
|  |         ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) | ||||||
|  |  | ||||||
|  |     # Check call count is now 1 as we sent a valid subscribe request | ||||||
|  |     assert sub_mock.call_count == 1 | ||||||
|  |  | ||||||
|  |     with client.websocket_connect(ws_url) as ws: | ||||||
|  |         ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) | ||||||
|  |  | ||||||
|  |     # Call count hasn't changed as the subscribe request was invalid | ||||||
|  |     assert sub_mock.call_count == 1 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user