Merge pull request #7621 from wizrds/fix/channel-api

Improved WebSocketChannel API
This commit is contained in:
Matthias 2022-10-26 06:31:42 +02:00 committed by GitHub
commit 110db8b241
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 41 deletions

View File

@ -1,4 +1,3 @@
import asyncio
import logging import logging
from typing import Any, Dict from typing import Any, Dict
@ -11,6 +10,7 @@ from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage) WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC from freqtrade.rpc.rpc import RPC
@ -37,7 +37,8 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
async def _process_consumer_request( async def _process_consumer_request(
request: Dict[str, Any], request: Dict[str, Any],
channel: WebSocketChannel, channel: WebSocketChannel,
rpc: RPC rpc: RPC,
channel_manager: ChannelManager
): ):
""" """
Validate and handle a request from a websocket consumer Validate and handle a request from a websocket consumer
@ -74,7 +75,7 @@ async def _process_consumer_request(
# Format response # Format response
response = WSWhitelistMessage(data=whitelist) response = WSWhitelistMessage(data=whitelist)
# Send it back # Send it back
await channel.send(response.dict(exclude_none=True)) await channel_manager.send_direct(channel, response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF: elif type == RPCRequestType.ANALYZED_DF:
limit = None limit = None
@ -89,9 +90,7 @@ async def _process_consumer_request(
# For every dataframe, send as a separate message # For every dataframe, send as a separate message
for _, message in analyzed_df.items(): for _, message in analyzed_df.items():
response = WSAnalyzedDFMessage(data=message) response = WSAnalyzedDFMessage(data=message)
await channel.send(response.dict(exclude_none=True)) await channel_manager.send_direct(channel, response.dict(exclude_none=True))
# Throttle the messages to 50/s
await asyncio.sleep(0.02)
@router.websocket("/message/ws") @router.websocket("/message/ws")
@ -116,7 +115,7 @@ async def message_endpoint(
request = await channel.recv() request = await channel.recv()
# Process the request here # Process the request here
await _process_consumer_request(request, channel, rpc) await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException): except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects # Handle client disconnects

View File

@ -16,6 +16,7 @@ from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager from freqtrade.rpc.api_server.ws import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -127,7 +128,7 @@ class ApiServer(RPCHandler):
cls._has_rpc = False cls._has_rpc = False
cls._rpc = None cls._rpc = None
def send_msg(self, msg: Dict[str, str]) -> None: def send_msg(self, msg: Dict[str, Any]) -> None:
if self._ws_queue: if self._ws_queue:
sync_q = self._ws_queue.sync_q sync_q = self._ws_queue.sync_q
sync_q.put(msg) sync_q.put(msg)
@ -194,14 +195,10 @@ class ApiServer(RPCHandler):
while True: while True:
logger.debug("Getting queue messages...") logger.debug("Getting queue messages...")
# Get data from queue # Get data from queue
message = await async_queue.get() message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}") logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it # Broadcast it
await self._ws_channel_manager.broadcast(message) await self._ws_channel_manager.broadcast(message)
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
await asyncio.sleep(0.001)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from threading import RLock from threading import RLock
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket from fastapi import WebSocket as FastAPIWebSocket
@ -10,6 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer) WebSocketSerializer)
from freqtrade.rpc.api_server.ws.types import WebSocketType from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,6 +25,8 @@ class WebSocketChannel:
self, self,
websocket: WebSocketType, websocket: WebSocketType,
channel_id: Optional[str] = None, channel_id: Optional[str] = None,
drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
): ):
@ -34,7 +37,11 @@ class WebSocketChannel:
# The Serializing class for the WebSocket object # The Serializing class for the WebSocket object
self._serializer_cls = serializer_cls self._serializer_cls = serializer_cls
self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = [] self._subscriptions: List[str] = []
# 32 is the size of the receiving queue in websockets package
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay()) self._relay_task = asyncio.create_task(self.relay())
@ -47,6 +54,10 @@ class WebSocketChannel:
def __repr__(self): def __repr__(self):
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
@property
def raw_websocket(self):
return self._websocket.raw_websocket
@property @property
def remote_addr(self): def remote_addr(self):
return self._websocket.remote_addr return self._websocket.remote_addr
@ -57,11 +68,19 @@ class WebSocketChannel:
""" """
await self._wrapped_ws.send(data) await self._wrapped_ws.send(data)
async def send(self, data): async def send(self, data) -> bool:
""" """
Add the data to the queue to be sent Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
""" """
self.queue.put_nowait(data) try:
await asyncio.wait_for(
self.queue.put(data),
timeout=self.drain_timeout
)
return True
except asyncio.TimeoutError:
return False
async def recv(self): async def recv(self):
""" """
@ -119,8 +138,8 @@ class WebSocketChannel:
# Limit messages per sec. # Limit messages per sec.
# Could cause problems with queue size if too low, and # Could cause problems with queue size if too low, and
# problems with network traffik if too high. # problems with network traffik if too high.
# 0.001 = 1000/s # 0.01 = 100/s
await asyncio.sleep(0.001) await asyncio.sleep(self.throttle)
except RuntimeError: except RuntimeError:
# The connection was closed, just exit the task # The connection was closed, just exit the task
return return
@ -160,6 +179,7 @@ class ChannelManager:
with self._lock: with self._lock:
channel = self.channels.get(websocket) channel = self.channels.get(websocket)
if channel: if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed(): if not channel.is_closed():
await channel.close() await channel.close()
@ -170,36 +190,30 @@ class ChannelManager:
Disconnect all Channels Disconnect all Channels
""" """
with self._lock: with self._lock:
for websocket, channel in self.channels.copy().items(): for websocket in self.channels.copy().keys():
if not channel.is_closed(): await self.on_disconnect(websocket)
await channel.close()
self.channels = dict() async def broadcast(self, message: WSMessageSchemaType):
async def broadcast(self, data):
""" """
Broadcast data on all Channels Broadcast a message on all Channels
:param data: The data to send :param message: The message to send
""" """
with self._lock: with self._lock:
message_type = data.get('type') for channel in self.channels.copy().values():
for websocket, channel in self.channels.copy().items(): if channel.subscribed_to(message.get('type')):
if channel.subscribed_to(message_type): await self.send_direct(channel, message)
if not channel.queue.full():
await channel.send(data)
else:
logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket)
async def send_direct(self, channel, data): async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
""" """
Send data directly through direct_channel only Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send data through :param direct_channel: The WebSocketChannel object to send the message through
:param data: The data to send :param message: The message to send
""" """
await channel.send(data) if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
def has_channels(self): def has_channels(self):
""" """

View File

@ -15,6 +15,10 @@ class WebSocketProxy:
def __init__(self, websocket: WebSocketType): def __init__(self, websocket: WebSocketType):
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
@property
def raw_websocket(self):
return self._websocket
@property @property
def remote_addr(self) -> Tuple[Any, ...]: def remote_addr(self) -> Tuple[Any, ...]:
if isinstance(self._websocket, WebSocket): if isinstance(self._websocket, WebSocket):

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, TypedDict
from pandas import DataFrame from pandas import DataFrame
from pydantic import BaseModel from pydantic import BaseModel
@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel):
data: Optional[Any] = None data: Optional[Any] = None
class WSMessageSchemaType(TypedDict):
# Type for typing to avoid doing pydantic typechecks.
type: RPCMessageType
data: Optional[Dict[str, Any]]
class WSMessageSchema(BaseArbitraryModel): class WSMessageSchema(BaseArbitraryModel):
type: RPCMessageType type: RPCMessageType
data: Optional[Any] = None data: Optional[Any] = None

View File

@ -270,6 +270,11 @@ class ExternalMessageConsumer:
logger.debug(f"Connection to {channel} still alive...") logger.debug(f"Connection to {channel} still alive...")
continue continue
except (websockets.exceptions.ConnectionClosed):
# Just eat the error and continue reconnecting
logger.warning(f"Disconnection in {channel} - retrying in {self.sleep_time}s")
await asyncio.sleep(self.sleep_time)
break
except Exception as e: except Exception as e:
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s") logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
logger.debug(e, exc_info=e) logger.debug(e, exc_info=e)