moved ws_schemas, first ws tests
This commit is contained in:
parent
b9e7af1ce2
commit
2b9c8550b0
@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||||
|
from pydantic import ValidationError
|
||||||
# fastapi does not make this available through it, so import directly from starlette
|
# fastapi does not make this available through it, so import directly from starlette
|
||||||
from starlette.websockets import WebSocketState
|
from starlette.websockets import WebSocketState
|
||||||
|
|
||||||
@ -9,9 +10,8 @@ from freqtrade.enums import RPCMessageType, RPCRequestType
|
|||||||
from freqtrade.rpc.api_server.api_auth import get_ws_token
|
from freqtrade.rpc.api_server.api_auth import get_ws_token
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||||
WSMessageSchema, WSRequestSchema,
|
WSRequestSchema, WSWhitelistMessage)
|
||||||
WSWhitelistMessage)
|
|
||||||
from freqtrade.rpc.rpc import RPC
|
from freqtrade.rpc.rpc import RPC
|
||||||
|
|
||||||
|
|
||||||
@ -102,13 +102,11 @@ async def message_endpoint(
|
|||||||
Message WebSocket endpoint, facilitates sending RPC messages
|
Message WebSocket endpoint, facilitates sending RPC messages
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if is_websocket_alive(ws):
|
# TODO:
|
||||||
# TODO:
|
# Return a channel ID, pass that instead of ws to the rest of the methods
|
||||||
# Return a channel ID, pass that instead of ws to the rest of the methods
|
channel = await channel_manager.on_connect(ws)
|
||||||
channel = await channel_manager.on_connect(ws)
|
|
||||||
|
|
||||||
if not channel:
|
if await is_websocket_alive(ws):
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Consumer connected - {channel}")
|
logger.info(f"Consumer connected - {channel}")
|
||||||
|
|
||||||
@ -131,6 +129,9 @@ async def message_endpoint(
|
|||||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||||
await channel_manager.on_disconnect(ws)
|
await channel_manager.on_disconnect(ws)
|
||||||
|
|
||||||
|
else:
|
||||||
|
ws.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to serve - {ws.client}")
|
logger.error(f"Failed to serve - {ws.client}")
|
||||||
# Log tracebacks to keep track of what errors are happening
|
# Log tracebacks to keep track of what errors are happening
|
||||||
|
@ -2,15 +2,12 @@ from datetime import datetime
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from freqtrade.constants import PairWithTimeframe
|
from freqtrade.constants import PairWithTimeframe
|
||||||
from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType
|
from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError')
|
|
||||||
|
|
||||||
|
|
||||||
class BaseArbitraryModel(BaseModel):
|
class BaseArbitraryModel(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
@ -11,15 +11,16 @@ from threading import Thread
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from freqtrade.data.dataprovider import DataProvider
|
from freqtrade.data.dataprovider import DataProvider
|
||||||
from freqtrade.enums import RPCMessageType
|
from freqtrade.enums import RPCMessageType
|
||||||
from freqtrade.misc import remove_entry_exit_signals
|
from freqtrade.misc import remove_entry_exit_signals
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSAnalyzedDFRequest,
|
||||||
WSAnalyzedDFRequest, WSMessageSchema,
|
WSMessageSchema, WSRequestSchema,
|
||||||
WSRequestSchema, WSSubscribeRequest,
|
WSSubscribeRequest, WSWhitelistMessage,
|
||||||
WSWhitelistMessage, WSWhitelistRequest)
|
WSWhitelistRequest)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -56,8 +56,8 @@ def botclient(default_conf, mocker):
|
|||||||
apiserver.add_rpc_handler(rpc)
|
apiserver.add_rpc_handler(rpc)
|
||||||
yield ftbot, TestClient(apiserver.app)
|
yield ftbot, TestClient(apiserver.app)
|
||||||
# Cleanup ... ?
|
# Cleanup ... ?
|
||||||
apiserver.cleanup()
|
|
||||||
finally:
|
finally:
|
||||||
|
apiserver.cleanup()
|
||||||
ApiServer.shutdown()
|
ApiServer.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ def test_api_ws_auth(botclient):
|
|||||||
url = f"/api/v1/message/ws?token={good_token}"
|
url = f"/api/v1/message/ws?token={good_token}"
|
||||||
|
|
||||||
with client.websocket_connect(url) as websocket:
|
with client.websocket_connect(url) as websocket:
|
||||||
websocket.send(1)
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_api_unauthorized(botclient):
|
def test_api_unauthorized(botclient):
|
||||||
@ -1685,3 +1685,23 @@ def test_health(botclient):
|
|||||||
ret = rc.json()
|
ret = rc.json()
|
||||||
assert ret['last_process_ts'] == 0
|
assert ret['last_process_ts'] == 0
|
||||||
assert ret['last_process'] == '1970-01-01T00:00:00+00:00'
|
assert ret['last_process'] == '1970-01-01T00:00:00+00:00'
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_ws_subscribe(botclient, mocker):
|
||||||
|
ftbot, client = botclient
|
||||||
|
ws_url = f"/api/v1/message/ws?token={_TEST_WS_TOKEN}"
|
||||||
|
|
||||||
|
sub_mock = mocker.patch(
|
||||||
|
'freqtrade.rpc.api_server.ws.channel.WebSocketChannel.set_subscriptions', MagicMock())
|
||||||
|
|
||||||
|
with client.websocket_connect(ws_url) as ws:
|
||||||
|
ws.send_json({'type': 'subscribe', 'data': ['whitelist']})
|
||||||
|
|
||||||
|
# Check call count is now 1 as we sent a valid subscribe request
|
||||||
|
assert sub_mock.call_count == 1
|
||||||
|
|
||||||
|
with client.websocket_connect(ws_url) as ws:
|
||||||
|
ws.send_json({'type': 'subscribe', 'data': 'whitelist'})
|
||||||
|
|
||||||
|
# Call count hasn't changed as the subscribe request was invalid
|
||||||
|
assert sub_mock.call_count == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user