Don't use pydantic to type-verify outgoing messages

This commit is contained in:
Matthias 2022-10-25 19:36:40 +02:00
parent 32600a113f
commit 3fa50077c9
4 changed files with 20 additions and 13 deletions

View File

@ -75,7 +75,7 @@ async def _process_consumer_request(
# Format response # Format response
response = WSWhitelistMessage(data=whitelist) response = WSWhitelistMessage(data=whitelist)
# Send it back # Send it back
await channel_manager.send_direct(channel, response) await channel_manager.send_direct(channel, response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF: elif type == RPCRequestType.ANALYZED_DF:
limit = None limit = None
@ -90,7 +90,7 @@ async def _process_consumer_request(
# For every dataframe, send as a separate message # For every dataframe, send as a separate message
for _, message in analyzed_df.items(): for _, message in analyzed_df.items():
response = WSAnalyzedDFMessage(data=message) response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response) await channel_manager.send_direct(channel, response.dict(exclude_none=True))
@router.websocket("/message/ws") @router.websocket("/message/ws")

View File

@ -16,7 +16,7 @@ from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager from freqtrade.rpc.api_server.ws import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -131,7 +131,7 @@ class ApiServer(RPCHandler):
def send_msg(self, msg: Dict[str, Any]) -> None: def send_msg(self, msg: Dict[str, Any]) -> None:
if self._ws_queue: if self._ws_queue:
sync_q = self._ws_queue.sync_q sync_q = self._ws_queue.sync_q
sync_q.put(WSMessageSchema(**msg)) sync_q.put(msg)
def handle_rpc_exception(self, request, exc): def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}") logger.exception(f"API Error calling: {exc}")
@ -195,8 +195,8 @@ class ApiServer(RPCHandler):
while True: while True:
logger.debug("Getting queue messages...") logger.debug("Getting queue messages...")
# Get data from queue # Get data from queue
message: WSMessageSchema = await async_queue.get() message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.type}") logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it # Broadcast it
await self._ws_channel_manager.broadcast(message) await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from threading import RLock from threading import RLock
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket from fastapi import WebSocket as FastAPIWebSocket
@ -10,7 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer) WebSocketSerializer)
from freqtrade.rpc.api_server.ws.types import WebSocketType from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -193,7 +193,7 @@ class ChannelManager:
for websocket in self.channels.copy().keys(): for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket) await self.on_disconnect(websocket)
async def broadcast(self, message: WSMessageSchema): async def broadcast(self, message: WSMessageSchemaType):
""" """
Broadcast a message on all Channels Broadcast a message on all Channels
@ -201,17 +201,18 @@ class ChannelManager:
""" """
with self._lock: with self._lock:
for channel in self.channels.copy().values(): for channel in self.channels.copy().values():
if channel.subscribed_to(message.type): if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message) await self.send_direct(channel, message)
async def send_direct(self, channel: WebSocketChannel, message: WSMessageSchema): async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
""" """
Send a message directly through direct_channel only Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send the message through :param direct_channel: The WebSocketChannel object to send the message through
:param message: The message to send :param message: The message to send
""" """
if not await channel.send(message.dict(exclude_none=True)): if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket) await self.on_disconnect(channel.raw_websocket)
def has_channels(self): def has_channels(self):

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, TypedDict
from pandas import DataFrame from pandas import DataFrame
from pydantic import BaseModel from pydantic import BaseModel
@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel):
data: Optional[Any] = None data: Optional[Any] = None
class WSMessageSchemaType(TypedDict):
# Type for typing to avoid doing pydantic typechecks.
type: RPCMessageType
data: Optional[Dict[str, Any]]
class WSMessageSchema(BaseArbitraryModel): class WSMessageSchema(BaseArbitraryModel):
type: RPCMessageType type: RPCMessageType
data: Optional[Any] = None data: Optional[Any] = None