Don't use pydantic to type-verify outgoing messages

This commit is contained in:
Matthias 2022-10-25 19:36:40 +02:00
parent 32600a113f
commit 3fa50077c9
4 changed files with 20 additions and 13 deletions

View File

@ -75,7 +75,7 @@ async def _process_consumer_request(
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel_manager.send_direct(channel, response)
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
@ -90,7 +90,7 @@ async def _process_consumer_request(
# For every dataframe, send as a separate message
for _, message in analyzed_df.items():
response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response)
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
@router.websocket("/message/ws")

View File

@ -16,7 +16,7 @@ from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -131,7 +131,7 @@ class ApiServer(RPCHandler):
def send_msg(self, msg: Dict[str, Any]) -> None:
if self._ws_queue:
sync_q = self._ws_queue.sync_q
sync_q.put(WSMessageSchema(**msg))
sync_q.put(msg)
def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}")
@ -195,8 +195,8 @@ class ApiServer(RPCHandler):
while True:
logger.debug("Getting queue messages...")
# Get data from queue
message: WSMessageSchema = await async_queue.get()
logger.debug(f"Found message of type: {message.type}")
message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it
await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError:

View File

@ -1,7 +1,7 @@
import asyncio
import logging
from threading import RLock
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
@ -10,7 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer)
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__)
@ -193,7 +193,7 @@ class ChannelManager:
for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket)
async def broadcast(self, message: WSMessageSchema):
async def broadcast(self, message: WSMessageSchemaType):
"""
Broadcast a message on all Channels
@ -201,17 +201,18 @@ class ChannelManager:
"""
with self._lock:
for channel in self.channels.copy().values():
if channel.subscribed_to(message.type):
if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message)
async def send_direct(self, channel: WebSocketChannel, message: WSMessageSchema):
async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
"""
Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send the message through
:param message: The message to send
"""
if not await channel.send(message.dict(exclude_none=True)):
if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
def has_channels(self):

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict
from pandas import DataFrame
from pydantic import BaseModel
@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel):
data: Optional[Any] = None
class WSMessageSchemaType(TypedDict):
# Type for typing to avoid doing pydantic typechecks.
type: RPCMessageType
data: Optional[Dict[str, Any]]
class WSMessageSchema(BaseArbitraryModel):
type: RPCMessageType
data: Optional[Any] = None