diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index cafbaefcb..c33f9c730 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -75,7 +75,7 @@ async def _process_consumer_request( # Format response response = WSWhitelistMessage(data=whitelist) # Send it back - await channel_manager.send_direct(channel, response) + await channel_manager.send_direct(channel, response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: limit = None @@ -90,7 +90,7 @@ async def _process_consumer_request( # For every dataframe, send as a separate message for _, message in analyzed_df.items(): response = WSAnalyzedDFMessage(data=message) - await channel_manager.send_direct(channel, response) + await channel_manager.send_direct(channel, response.dict(exclude_none=True)) @router.websocket("/message/ws") diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 22a05f07b..1d0192a89 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -16,7 +16,7 @@ from freqtrade.constants import Config from freqtrade.exceptions import OperationalException from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.ws import ChannelManager -from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema +from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler @@ -131,7 +131,7 @@ class ApiServer(RPCHandler): def send_msg(self, msg: Dict[str, Any]) -> None: if self._ws_queue: sync_q = self._ws_queue.sync_q - sync_q.put(WSMessageSchema(**msg)) + sync_q.put(msg) def handle_rpc_exception(self, request, exc): logger.exception(f"API Error calling: {exc}") @@ -195,8 +195,8 @@ class ApiServer(RPCHandler): while True: logger.debug("Getting queue messages...") # Get data from queue - message: WSMessageSchema = await async_queue.get() - logger.debug(f"Found message of type: {message.type}") + message: WSMessageSchemaType = await async_queue.get() + logger.debug(f"Found message of type: {message.get('type')}") # Broadcast it await self._ws_channel_manager.broadcast(message) except asyncio.CancelledError: diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 942a3df70..34f03f0c4 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,7 +1,7 @@ import asyncio import logging from threading import RLock -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 from fastapi import WebSocket as FastAPIWebSocket @@ -10,7 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) from freqtrade.rpc.api_server.ws.types import WebSocketType -from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema +from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType logger = logging.getLogger(__name__) @@ -193,7 +193,7 @@ class ChannelManager: for websocket in self.channels.copy().keys(): await self.on_disconnect(websocket) - async def broadcast(self, message: WSMessageSchema): + async def broadcast(self, message: WSMessageSchemaType): """ Broadcast a message on all Channels @@ -201,17 +201,18 @@ class ChannelManager: """ with self._lock: for channel in self.channels.copy().values(): - if channel.subscribed_to(message.type): + if channel.subscribed_to(message.get('type')): await self.send_direct(channel, message) - async def send_direct(self, channel: WebSocketChannel, message: WSMessageSchema): + async def send_direct( + self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): """ Send a message directly through direct_channel only :param direct_channel: The WebSocketChannel object to send the message through :param message: The message to send """ - if not await channel.send(message.dict(exclude_none=True)): + if not await channel.send(message): await self.on_disconnect(channel.raw_websocket) def has_channels(self): diff --git a/freqtrade/rpc/api_server/ws_schemas.py b/freqtrade/rpc/api_server/ws_schemas.py index 255226d84..877232213 100644 --- a/freqtrade/rpc/api_server/ws_schemas.py +++ b/freqtrade/rpc/api_server/ws_schemas.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypedDict from pandas import DataFrame from pydantic import BaseModel @@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel): data: Optional[Any] = None +class WSMessageSchemaType(TypedDict): + # Type for typing to avoid doing pydantic typechecks. + type: RPCMessageType + data: Optional[Dict[str, Any]] + + class WSMessageSchema(BaseArbitraryModel): type: RPCMessageType data: Optional[Any] = None