add websocket request/message schemas
This commit is contained in:
parent
8bfaf0a998
commit
5934495dda
@ -121,7 +121,8 @@ class DataProvider:
|
|||||||
'type': RPCMessageType.ANALYZED_DF,
|
'type': RPCMessageType.ANALYZED_DF,
|
||||||
'data': {
|
'data': {
|
||||||
'key': pair_key,
|
'key': pair_key,
|
||||||
'value': (dataframe, datetime.now(timezone.utc))
|
'df': dataframe,
|
||||||
|
'la': datetime.now(timezone.utc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -130,7 +131,7 @@ class DataProvider:
|
|||||||
self,
|
self,
|
||||||
pair: str,
|
pair: str,
|
||||||
dataframe: DataFrame,
|
dataframe: DataFrame,
|
||||||
last_analyzed: Optional[str] = None,
|
last_analyzed: Optional[datetime] = None,
|
||||||
timeframe: Optional[str] = None,
|
timeframe: Optional[str] = None,
|
||||||
candle_type: Optional[CandleType] = None,
|
candle_type: Optional[CandleType] = None,
|
||||||
producer_name: str = "default"
|
producer_name: str = "default"
|
||||||
@ -150,10 +151,7 @@ class DataProvider:
|
|||||||
if producer_name not in self.__producer_pairs_df:
|
if producer_name not in self.__producer_pairs_df:
|
||||||
self.__producer_pairs_df[producer_name] = {}
|
self.__producer_pairs_df[producer_name] = {}
|
||||||
|
|
||||||
if not last_analyzed:
|
_last_analyzed = datetime.now(timezone.utc) if not last_analyzed else last_analyzed
|
||||||
_last_analyzed = datetime.now(timezone.utc)
|
|
||||||
else:
|
|
||||||
_last_analyzed = datetime.fromisoformat(last_analyzed)
|
|
||||||
|
|
||||||
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed)
|
self.__producer_pairs_df[producer_name][pair_key] = (dataframe, _last_analyzed)
|
||||||
logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.")
|
logger.debug(f"External DataFrame for {pair_key} from {producer_name} added.")
|
||||||
|
@ -8,6 +8,8 @@ from starlette.websockets import WebSocketState
|
|||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
|
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
||||||
|
WSRequestSchema, WSWhitelistMessage)
|
||||||
from freqtrade.rpc.rpc import RPC
|
from freqtrade.rpc.rpc import RPC
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +20,9 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
async def is_websocket_alive(ws: WebSocket) -> bool:
|
async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a FastAPI Websocket is still open
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
ws.application_state == WebSocketState.CONNECTED and
|
ws.application_state == WebSocketState.CONNECTED and
|
||||||
ws.client_state == WebSocketState.CONNECTED
|
ws.client_state == WebSocketState.CONNECTED
|
||||||
@ -31,7 +36,17 @@ async def _process_consumer_request(
|
|||||||
channel: WebSocketChannel,
|
channel: WebSocketChannel,
|
||||||
rpc: RPC
|
rpc: RPC
|
||||||
):
|
):
|
||||||
type, data = request.get('type'), request.get('data')
|
"""
|
||||||
|
Validate and handle a request from a websocket consumer
|
||||||
|
"""
|
||||||
|
# Validate the request, makes sure it matches the schema
|
||||||
|
try:
|
||||||
|
websocket_request = WSRequestSchema.parse_obj(request)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error(f"Invalid request from {channel}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
type, data = websocket_request.type, websocket_request.data
|
||||||
|
|
||||||
logger.debug(f"Request of type {type} from {channel}")
|
logger.debug(f"Request of type {type} from {channel}")
|
||||||
|
|
||||||
@ -41,35 +56,35 @@ async def _process_consumer_request(
|
|||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(data, list):
|
|
||||||
logger.error(f"Improper subscribe request from channel: {channel} - {request}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# If all topics passed are a valid RPCMessageType, set subscriptions on channel
|
# If all topics passed are a valid RPCMessageType, set subscriptions on channel
|
||||||
if all([any(x.value == topic for x in RPCMessageType) for topic in data]):
|
if all([any(x.value == topic for x in RPCMessageType) for topic in data]):
|
||||||
|
|
||||||
logger.debug(f"{channel} subscribed to topics: {data}")
|
|
||||||
channel.set_subscriptions(data)
|
channel.set_subscriptions(data)
|
||||||
|
|
||||||
|
# We don't send a response for subscriptions
|
||||||
|
|
||||||
elif type == RPCRequestType.WHITELIST:
|
elif type == RPCRequestType.WHITELIST:
|
||||||
# They requested the whitelist
|
# Get whitelist
|
||||||
whitelist = rpc._ws_request_whitelist()
|
whitelist = rpc._ws_request_whitelist()
|
||||||
|
|
||||||
await channel.send({"type": RPCMessageType.WHITELIST, "data": whitelist})
|
# Format response
|
||||||
|
response = WSWhitelistMessage(data=whitelist)
|
||||||
|
# Send it back
|
||||||
|
await channel.send(response.dict(exclude_none=True))
|
||||||
|
|
||||||
elif type == RPCRequestType.ANALYZED_DF:
|
elif type == RPCRequestType.ANALYZED_DF:
|
||||||
limit = None
|
limit = None
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||||
limit = max(data.get('limit', 500), 1500)
|
limit = max(data.get('limit', 1500), 1500)
|
||||||
|
|
||||||
# They requested the full historical analyzed dataframes
|
# They requested the full historical analyzed dataframes
|
||||||
analyzed_df = rpc._ws_request_analyzed_df(limit)
|
analyzed_df = rpc._ws_request_analyzed_df(limit)
|
||||||
|
|
||||||
# For every dataframe, send as a separate message
|
# For every dataframe, send as a separate message
|
||||||
for _, message in analyzed_df.items():
|
for _, message in analyzed_df.items():
|
||||||
await channel.send({"type": RPCMessageType.ANALYZED_DF, "data": message})
|
response = WSAnalyzedDFMessage(data=message)
|
||||||
|
await channel.send(response.dict(exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/message/ws")
|
@router.websocket("/message/ws")
|
||||||
@ -78,6 +93,9 @@ async def message_endpoint(
|
|||||||
rpc: RPC = Depends(get_rpc),
|
rpc: RPC = Depends(get_rpc),
|
||||||
channel_manager=Depends(get_channel_manager),
|
channel_manager=Depends(get_channel_manager),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Message WebSocket endpoint, facilitates sending RPC messages
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if is_websocket_alive(ws):
|
if is_websocket_alive(ws):
|
||||||
# TODO:
|
# TODO:
|
||||||
|
78
freqtrade/rpc/api_server/ws/schema.py
Normal file
78
freqtrade/rpc/api_server/ws/schema.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pandas import DataFrame
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from freqtrade.constants import PairWithTimeframe
|
||||||
|
from freqtrade.enums.rpcmessagetype import RPCMessageType, RPCRequestType
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('WSRequestSchema', 'WSMessageSchema', 'ValidationError')
|
||||||
|
|
||||||
|
|
||||||
|
class BaseArbitraryModel(BaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class WSRequestSchema(BaseArbitraryModel):
|
||||||
|
type: RPCRequestType
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WSMessageSchema(BaseArbitraryModel):
|
||||||
|
type: RPCMessageType
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = 'allow'
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------ REQUEST SCHEMAS ----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class WSSubscribeRequest(WSRequestSchema):
|
||||||
|
type: RPCRequestType = RPCRequestType.SUBSCRIBE
|
||||||
|
data: List[RPCMessageType]
|
||||||
|
|
||||||
|
|
||||||
|
class WSWhitelistRequest(WSRequestSchema):
|
||||||
|
type: RPCRequestType = RPCRequestType.WHITELIST
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WSAnalyzedDFRequest(WSRequestSchema):
|
||||||
|
type: RPCRequestType = RPCRequestType.ANALYZED_DF
|
||||||
|
data: Dict[str, Any] = {"limit": 1500}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------ MESSAGE SCHEMAS ----------------------------
|
||||||
|
|
||||||
|
class WSWhitelistMessage(WSMessageSchema):
|
||||||
|
type: RPCMessageType = RPCMessageType.WHITELIST
|
||||||
|
data: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class WSAnalyzedDFMessage(WSMessageSchema):
|
||||||
|
class AnalyzedDFData(BaseArbitraryModel):
|
||||||
|
key: PairWithTimeframe
|
||||||
|
df: DataFrame
|
||||||
|
la: datetime
|
||||||
|
|
||||||
|
type: RPCMessageType = RPCMessageType.ANALYZED_DF
|
||||||
|
data: AnalyzedDFData
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
message = WSAnalyzedDFMessage(
|
||||||
|
data={
|
||||||
|
"key": ("1", "5m", "spot"),
|
||||||
|
"df": DataFrame(),
|
||||||
|
"la": datetime.now()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(message)
|
@ -8,14 +8,18 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from freqtrade.data.dataprovider import DataProvider
|
from freqtrade.data.dataprovider import DataProvider
|
||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType
|
||||||
from freqtrade.misc import remove_entry_exit_signals
|
from freqtrade.misc import remove_entry_exit_signals
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
|
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
||||||
|
WSAnalyzedDFRequest, WSMessageSchema,
|
||||||
|
WSRequestSchema, WSSubscribeRequest,
|
||||||
|
WSWhitelistMessage, WSWhitelistRequest)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -67,15 +71,10 @@ class ExternalMessageConsumer:
|
|||||||
self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF]
|
self.topics = [RPCMessageType.WHITELIST, RPCMessageType.ANALYZED_DF]
|
||||||
|
|
||||||
# Allow setting data for each initial request
|
# Allow setting data for each initial request
|
||||||
self._initial_requests: List[Dict[str, Any]] = [
|
self._initial_requests: List[WSRequestSchema] = [
|
||||||
{
|
WSSubscribeRequest(data=self.topics),
|
||||||
"type": RPCRequestType.WHITELIST,
|
WSWhitelistRequest(),
|
||||||
"data": None
|
WSAnalyzedDFRequest()
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": RPCRequestType.ANALYZED_DF,
|
|
||||||
"data": {"limit": self.initial_candle_limit}
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Specify which function to use for which RPCMessageType
|
# Specify which function to use for which RPCMessageType
|
||||||
@ -174,16 +173,10 @@ class ExternalMessageConsumer:
|
|||||||
|
|
||||||
logger.info(f"Producer connection success - {channel}")
|
logger.info(f"Producer connection success - {channel}")
|
||||||
|
|
||||||
# Tell the producer we only want these topics
|
|
||||||
# Should always be the first thing we send
|
|
||||||
await channel.send(
|
|
||||||
self.compose_consumer_request(RPCRequestType.SUBSCRIBE, self.topics)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now request the initial data from this Producer
|
# Now request the initial data from this Producer
|
||||||
for request in self._initial_requests:
|
for request in self._initial_requests:
|
||||||
await channel.send(
|
await channel.send(
|
||||||
self.compose_consumer_request(request['type'], request['data'])
|
request.dict(exclude_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now receive data, if none is within the time limit, ping
|
# Now receive data, if none is within the time limit, ping
|
||||||
@ -253,71 +246,63 @@ class ExternalMessageConsumer:
|
|||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
def compose_consumer_request(
|
|
||||||
self,
|
|
||||||
type_: RPCRequestType,
|
|
||||||
data: Optional[Any] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Create a request for sending to a producer
|
|
||||||
|
|
||||||
:param type_: The RPCRequestType
|
|
||||||
:param data: The data to send
|
|
||||||
:returns: Dict[str, Any]
|
|
||||||
"""
|
|
||||||
return {'type': type_, 'data': data}
|
|
||||||
|
|
||||||
def handle_producer_message(self, producer: Dict[str, Any], message: Dict[str, Any]):
|
def handle_producer_message(self, producer: Dict[str, Any], message: Dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Handles external messages from a Producer
|
Handles external messages from a Producer
|
||||||
"""
|
"""
|
||||||
producer_name = producer.get('name', 'default')
|
producer_name = producer.get('name', 'default')
|
||||||
# Should we have a default message type?
|
|
||||||
message_type = message.get('type', RPCMessageType.STATUS)
|
try:
|
||||||
message_data = message.get('data')
|
producer_message = WSMessageSchema.parse_obj(message)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error(f"Invalid message from {producer_name}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
# We shouldn't get empty messages
|
# We shouldn't get empty messages
|
||||||
if message_data is None:
|
if producer_message.data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Received message of type {message_type} from `{producer_name}`")
|
logger.info(f"Received message of type {producer_message.type} from `{producer_name}`")
|
||||||
|
|
||||||
message_handler = self._message_handlers.get(message_type)
|
message_handler = self._message_handlers.get(producer_message.type)
|
||||||
|
|
||||||
if not message_handler:
|
if not message_handler:
|
||||||
logger.info(f"Received unhandled message: {message_data}, ignoring...")
|
logger.info(f"Received unhandled message: {producer_message.data}, ignoring...")
|
||||||
return
|
return
|
||||||
|
|
||||||
message_handler(producer_name, message_data)
|
message_handler(producer_name, producer_message)
|
||||||
|
|
||||||
def _consume_whitelist_message(self, producer_name: str, message_data: Any):
|
def _consume_whitelist_message(self, producer_name: str, message: Any):
|
||||||
# We expect List[str]
|
try:
|
||||||
if not isinstance(message_data, list):
|
# Validate the message
|
||||||
|
message = WSWhitelistMessage.parse_obj(message)
|
||||||
|
except ValidationError:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add the pairlist data to the DataProvider
|
# Add the pairlist data to the DataProvider
|
||||||
self._dp._set_producer_pairs(message_data, producer_name=producer_name)
|
self._dp._set_producer_pairs(message.data, producer_name=producer_name)
|
||||||
|
|
||||||
logger.debug(f"Consumed message from {producer_name} of type RPCMessageType.WHITELIST")
|
logger.debug(f"Consumed message from {producer_name} of type `RPCMessageType.WHITELIST`")
|
||||||
|
|
||||||
def _consume_analyzed_df_message(self, producer_name: str, message_data: Any):
|
def _consume_analyzed_df_message(self, producer_name: str, message: Any):
|
||||||
# We expect a Dict[str, Any]
|
try:
|
||||||
if not isinstance(message_data, dict):
|
message = WSAnalyzedDFMessage.parse_obj(message)
|
||||||
|
except ValidationError:
|
||||||
return
|
return
|
||||||
|
|
||||||
key, value = message_data.get('key'), message_data.get('value')
|
key = message.data.key
|
||||||
|
df = message.data.df
|
||||||
|
la = message.data.la
|
||||||
|
|
||||||
if key and value:
|
|
||||||
pair, timeframe, candle_type = key
|
pair, timeframe, candle_type = key
|
||||||
dataframe, last_analyzed = value
|
|
||||||
|
|
||||||
# If set, remove the Entry and Exit signals from the Producer
|
# If set, remove the Entry and Exit signals from the Producer
|
||||||
if self._emc_config.get('remove_entry_exit_signals', False):
|
if self._emc_config.get('remove_entry_exit_signals', False):
|
||||||
dataframe = remove_entry_exit_signals(dataframe)
|
df = remove_entry_exit_signals(df)
|
||||||
|
|
||||||
# Add the dataframe to the dataprovider
|
# Add the dataframe to the dataprovider
|
||||||
self._dp._add_external_df(pair, dataframe,
|
self._dp._add_external_df(pair, df,
|
||||||
last_analyzed=last_analyzed,
|
last_analyzed=la,
|
||||||
timeframe=timeframe,
|
timeframe=timeframe,
|
||||||
candle_type=candle_type,
|
candle_type=candle_type,
|
||||||
producer_name=producer_name)
|
producer_name=producer_name)
|
||||||
|
@ -1068,8 +1068,12 @@ class RPC:
|
|||||||
|
|
||||||
for pair in pairlist:
|
for pair in pairlist:
|
||||||
dataframe, last_analyzed = self.__rpc_analysed_dataframe_raw(pair, timeframe, limit)
|
dataframe, last_analyzed = self.__rpc_analysed_dataframe_raw(pair, timeframe, limit)
|
||||||
_data[pair] = {"key": (pair, timeframe, candle_type),
|
|
||||||
"value": (dataframe, last_analyzed)}
|
_data[pair] = {
|
||||||
|
"key": (pair, timeframe, candle_type),
|
||||||
|
"df": dataframe,
|
||||||
|
"la": last_analyzed
|
||||||
|
}
|
||||||
|
|
||||||
return _data
|
return _data
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user