Don't use pydantic to type-verify outgoing messages
This commit is contained in:
parent
32600a113f
commit
3fa50077c9
@ -75,7 +75,7 @@ async def _process_consumer_request(
|
|||||||
# Format response
|
# Format response
|
||||||
response = WSWhitelistMessage(data=whitelist)
|
response = WSWhitelistMessage(data=whitelist)
|
||||||
# Send it back
|
# Send it back
|
||||||
await channel_manager.send_direct(channel, response)
|
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||||
|
|
||||||
elif type == RPCRequestType.ANALYZED_DF:
|
elif type == RPCRequestType.ANALYZED_DF:
|
||||||
limit = None
|
limit = None
|
||||||
@ -90,7 +90,7 @@ async def _process_consumer_request(
|
|||||||
# For every dataframe, send as a separate message
|
# For every dataframe, send as a separate message
|
||||||
for _, message in analyzed_df.items():
|
for _, message in analyzed_df.items():
|
||||||
response = WSAnalyzedDFMessage(data=message)
|
response = WSAnalyzedDFMessage(data=message)
|
||||||
await channel_manager.send_direct(channel, response)
|
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/message/ws")
|
@router.websocket("/message/ws")
|
||||||
|
@ -16,7 +16,7 @@ from freqtrade.constants import Config
|
|||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
|
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
|
||||||
from freqtrade.rpc.api_server.ws import ChannelManager
|
from freqtrade.rpc.api_server.ws import ChannelManager
|
||||||
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema
|
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||||
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
|
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ class ApiServer(RPCHandler):
|
|||||||
def send_msg(self, msg: Dict[str, Any]) -> None:
|
def send_msg(self, msg: Dict[str, Any]) -> None:
|
||||||
if self._ws_queue:
|
if self._ws_queue:
|
||||||
sync_q = self._ws_queue.sync_q
|
sync_q = self._ws_queue.sync_q
|
||||||
sync_q.put(WSMessageSchema(**msg))
|
sync_q.put(msg)
|
||||||
|
|
||||||
def handle_rpc_exception(self, request, exc):
|
def handle_rpc_exception(self, request, exc):
|
||||||
logger.exception(f"API Error calling: {exc}")
|
logger.exception(f"API Error calling: {exc}")
|
||||||
@ -195,8 +195,8 @@ class ApiServer(RPCHandler):
|
|||||||
while True:
|
while True:
|
||||||
logger.debug("Getting queue messages...")
|
logger.debug("Getting queue messages...")
|
||||||
# Get data from queue
|
# Get data from queue
|
||||||
message: WSMessageSchema = await async_queue.get()
|
message: WSMessageSchemaType = await async_queue.get()
|
||||||
logger.debug(f"Found message of type: {message.type}")
|
logger.debug(f"Found message of type: {message.get('type')}")
|
||||||
# Broadcast it
|
# Broadcast it
|
||||||
await self._ws_channel_manager.broadcast(message)
|
await self._ws_channel_manager.broadcast(message)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
@ -10,7 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
|||||||
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
|
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
|
||||||
WebSocketSerializer)
|
WebSocketSerializer)
|
||||||
from freqtrade.rpc.api_server.ws.types import WebSocketType
|
from freqtrade.rpc.api_server.ws.types import WebSocketType
|
||||||
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchema
|
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -193,7 +193,7 @@ class ChannelManager:
|
|||||||
for websocket in self.channels.copy().keys():
|
for websocket in self.channels.copy().keys():
|
||||||
await self.on_disconnect(websocket)
|
await self.on_disconnect(websocket)
|
||||||
|
|
||||||
async def broadcast(self, message: WSMessageSchema):
|
async def broadcast(self, message: WSMessageSchemaType):
|
||||||
"""
|
"""
|
||||||
Broadcast a message on all Channels
|
Broadcast a message on all Channels
|
||||||
|
|
||||||
@ -201,17 +201,18 @@ class ChannelManager:
|
|||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for channel in self.channels.copy().values():
|
for channel in self.channels.copy().values():
|
||||||
if channel.subscribed_to(message.type):
|
if channel.subscribed_to(message.get('type')):
|
||||||
await self.send_direct(channel, message)
|
await self.send_direct(channel, message)
|
||||||
|
|
||||||
async def send_direct(self, channel: WebSocketChannel, message: WSMessageSchema):
|
async def send_direct(
|
||||||
|
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
Send a message directly through direct_channel only
|
Send a message directly through direct_channel only
|
||||||
|
|
||||||
:param direct_channel: The WebSocketChannel object to send the message through
|
:param direct_channel: The WebSocketChannel object to send the message through
|
||||||
:param message: The message to send
|
:param message: The message to send
|
||||||
"""
|
"""
|
||||||
if not await channel.send(message.dict(exclude_none=True)):
|
if not await channel.send(message):
|
||||||
await self.on_disconnect(channel.raw_websocket)
|
await self.on_disconnect(channel.raw_websocket)
|
||||||
|
|
||||||
def has_channels(self):
|
def has_channels(self):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel):
|
|||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WSMessageSchemaType(TypedDict):
|
||||||
|
# Type for typing to avoid doing pydantic typechecks.
|
||||||
|
type: RPCMessageType
|
||||||
|
data: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class WSMessageSchema(BaseArbitraryModel):
|
class WSMessageSchema(BaseArbitraryModel):
|
||||||
type: RPCMessageType
|
type: RPCMessageType
|
||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
|
Loading…
Reference in New Issue
Block a user