add websocket request/message schemas
This commit is contained in:
@@ -8,6 +8,8 @@ from starlette.websockets import WebSocketState
|
||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
||||
WSRequestSchema, WSWhitelistMessage)
|
||||
from freqtrade.rpc.rpc import RPC
|
||||
|
||||
|
||||
@@ -18,6 +20,9 @@ router = APIRouter()
|
||||
|
||||
|
||||
async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
"""
|
||||
Check if a FastAPI Websocket is still open
|
||||
"""
|
||||
if (
|
||||
ws.application_state == WebSocketState.CONNECTED and
|
||||
ws.client_state == WebSocketState.CONNECTED
|
||||
@@ -31,7 +36,17 @@ async def _process_consumer_request(
|
||||
channel: WebSocketChannel,
|
||||
rpc: RPC
|
||||
):
|
||||
type, data = request.get('type'), request.get('data')
|
||||
"""
|
||||
Validate and handle a request from a websocket consumer
|
||||
"""
|
||||
# Validate the request, makes sure it matches the schema
|
||||
try:
|
||||
websocket_request = WSRequestSchema.parse_obj(request)
|
||||
except ValidationError as e:
|
||||
logger.error(f"Invalid request from {channel}: {e}")
|
||||
return
|
||||
|
||||
type, data = websocket_request.type, websocket_request.data
|
||||
|
||||
logger.debug(f"Request of type {type} from {channel}")
|
||||
|
||||
@@ -41,35 +56,35 @@ async def _process_consumer_request(
|
||||
if not data:
|
||||
return
|
||||
|
||||
if not isinstance(data, list):
|
||||
logger.error(f"Improper subscribe request from channel: {channel} - {request}")
|
||||
return
|
||||
|
||||
# If all topics passed are a valid RPCMessageType, set subscriptions on channel
|
||||
if all([any(x.value == topic for x in RPCMessageType) for topic in data]):
|
||||
|
||||
logger.debug(f"{channel} subscribed to topics: {data}")
|
||||
channel.set_subscriptions(data)
|
||||
|
||||
# We don't send a response for subscriptions
|
||||
|
||||
elif type == RPCRequestType.WHITELIST:
|
||||
# They requested the whitelist
|
||||
# Get whitelist
|
||||
whitelist = rpc._ws_request_whitelist()
|
||||
|
||||
await channel.send({"type": RPCMessageType.WHITELIST, "data": whitelist})
|
||||
# Format response
|
||||
response = WSWhitelistMessage(data=whitelist)
|
||||
# Send it back
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
elif type == RPCRequestType.ANALYZED_DF:
|
||||
limit = None
|
||||
|
||||
if data:
|
||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||
limit = max(data.get('limit', 500), 1500)
|
||||
limit = max(data.get('limit', 1500), 1500)
|
||||
|
||||
# They requested the full historical analyzed dataframes
|
||||
analyzed_df = rpc._ws_request_analyzed_df(limit)
|
||||
|
||||
# For every dataframe, send as a separate message
|
||||
for _, message in analyzed_df.items():
|
||||
await channel.send({"type": RPCMessageType.ANALYZED_DF, "data": message})
|
||||
response = WSAnalyzedDFMessage(data=message)
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
|
||||
@router.websocket("/message/ws")
|
||||
@@ -78,6 +93,9 @@ async def message_endpoint(
|
||||
rpc: RPC = Depends(get_rpc),
|
||||
channel_manager=Depends(get_channel_manager),
|
||||
):
|
||||
"""
|
||||
Message WebSocket endpoint, facilitates sending RPC messages
|
||||
"""
|
||||
try:
|
||||
if is_websocket_alive(ws):
|
||||
# TODO:
|
||||
|
78
freqtrade/rpc/api_server/ws/schema.py
Normal file
78
freqtrade/rpc/api_server/ws/schema.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pandas import DataFrame
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from freqtrade.constants import PairWithTimeframe
|
||||
from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType
|
||||
|
||||
|
||||
__all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError')
|
||||
|
||||
|
||||
class BaseArbitraryModel(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class WSRequestSchema(BaseArbitraryModel):
|
||||
type: RPCRequestType
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class WSMessageSchema(BaseArbitraryModel):
|
||||
type: RPCMessageType
|
||||
data: Optional[Any] = None
|
||||
|
||||
class Config:
|
||||
extra = 'allow'
|
||||
|
||||
|
||||
# ------------------------------ REQUEST SCHEMAS ----------------------------
|
||||
|
||||
|
||||
class WSSubscribeRequest(WSRequestSchema):
|
||||
type: RPCRequestType = RPCRequestType.SUBSCRIBE
|
||||
data: List[RPCMessageType]
|
||||
|
||||
|
||||
class WSWhitelistRequest(WSRequestSchema):
|
||||
type: RPCRequestType = RPCRequestType.WHITELIST
|
||||
data: None = None
|
||||
|
||||
|
||||
class WSAnalyzedDFRequest(WSRequestSchema):
|
||||
type: RPCRequestType = RPCRequestType.ANALYZED_DF
|
||||
data: Dict[str, Any] = {"limit": 1500}
|
||||
|
||||
|
||||
# ------------------------------ MESSAGE SCHEMAS ----------------------------
|
||||
|
||||
class WSWhitelistMessage(WSMessageSchema):
|
||||
type: RPCMessageType = RPCMessageType.WHITELIST
|
||||
data: List[str]
|
||||
|
||||
|
||||
class WSAnalyzedDFMessage(WSMessageSchema):
|
||||
class AnalyzedDFData(BaseArbitraryModel):
|
||||
key: PairWithTimeframe
|
||||
df: DataFrame
|
||||
la: datetime
|
||||
|
||||
type: RPCMessageType = RPCMessageType.ANALYZED_DF
|
||||
data: AnalyzedDFData
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
message = WSAnalyzedDFMessage(
|
||||
data={
|
||||
"key": ("1", "5m", "spot"),
|
||||
"df": DataFrame(),
|
||||
"la": datetime.now()
|
||||
}
|
||||
)
|
||||
|
||||
print(message)
|
Reference in New Issue
Block a user