removed sleep calls, better channel sending

This commit is contained in:
Timothy Pogue 2022-10-22 19:02:05 -06:00
parent 2b6d00dde4
commit 3d7a311caa
5 changed files with 31 additions and 22 deletions

View File

@ -1,4 +1,3 @@
import asyncio
import logging import logging
from typing import Any, Dict from typing import Any, Dict
@ -11,6 +10,7 @@ from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage) WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC from freqtrade.rpc.rpc import RPC
@ -37,7 +37,8 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
async def _process_consumer_request( async def _process_consumer_request(
request: Dict[str, Any], request: Dict[str, Any],
channel: WebSocketChannel, channel: WebSocketChannel,
rpc: RPC rpc: RPC,
channel_manager: ChannelManager
): ):
""" """
Validate and handle a request from a websocket consumer Validate and handle a request from a websocket consumer
@ -72,9 +73,9 @@ async def _process_consumer_request(
whitelist = rpc._ws_request_whitelist() whitelist = rpc._ws_request_whitelist()
# Format response # Format response
response = WSWhitelistMessage(data=whitelist) response = WSWhitelistMessage(data=whitelist).dict(exclude_none=True)
# Send it back # Send it back
await channel.send(response.dict(exclude_none=True)) await channel_manager.send_direct(channel, response)
elif type == RPCRequestType.ANALYZED_DF: elif type == RPCRequestType.ANALYZED_DF:
limit = None limit = None
@ -88,10 +89,8 @@ async def _process_consumer_request(
# For every dataframe, send as a separate message # For every dataframe, send as a separate message
for _, message in analyzed_df.items(): for _, message in analyzed_df.items():
response = WSAnalyzedDFMessage(data=message) response = WSAnalyzedDFMessage(data=message).dict(exclude_none=True)
await channel.send(response.dict(exclude_none=True)) await channel_manager.send_direct(channel, response)
# Throttle the messages to 50/s
await asyncio.sleep(0.02)
@router.websocket("/message/ws") @router.websocket("/message/ws")
@ -116,7 +115,7 @@ async def message_endpoint(
request = await channel.recv() request = await channel.recv()
# Process the request here # Process the request here
await _process_consumer_request(request, channel, rpc) await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException): except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects # Handle client disconnects

View File

@ -198,10 +198,6 @@ class ApiServer(RPCHandler):
logger.debug(f"Found message of type: {message.get('type')}") logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it # Broadcast it
await self._ws_channel_manager.broadcast(message) await self._ws_channel_manager.broadcast(message)
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
await asyncio.sleep(0.001)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass

View File

@ -25,6 +25,7 @@ class WebSocketChannel:
websocket: WebSocketType, websocket: WebSocketType,
channel_id: Optional[str] = None, channel_id: Optional[str] = None,
drain_timeout: int = 3, drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
): ):
@ -36,6 +37,7 @@ class WebSocketChannel:
self._serializer_cls = serializer_cls self._serializer_cls = serializer_cls
self.drain_timeout = drain_timeout self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = [] self._subscriptions: List[str] = []
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
@ -50,6 +52,10 @@ class WebSocketChannel:
def __repr__(self): def __repr__(self):
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
@property
def raw(self):
return self._websocket.raw
@property @property
def remote_addr(self): def remote_addr(self):
return self._websocket.remote_addr return self._websocket.remote_addr
@ -131,7 +137,7 @@ class WebSocketChannel:
# Could cause problems with queue size if too low, and # Could cause problems with queue size if too low, and
# problems with network traffik if too high. # problems with network traffik if too high.
# 0.01 = 100/s # 0.01 = 100/s
await asyncio.sleep(0.01) await asyncio.sleep(self.throttle)
except RuntimeError: except RuntimeError:
# The connection was closed, just exit the task # The connection was closed, just exit the task
return return
@ -171,6 +177,7 @@ class ChannelManager:
with self._lock: with self._lock:
channel = self.channels.get(websocket) channel = self.channels.get(websocket)
if channel: if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed(): if not channel.is_closed():
await channel.close() await channel.close()
@ -181,9 +188,8 @@ class ChannelManager:
Disconnect all Channels Disconnect all Channels
""" """
with self._lock: with self._lock:
for websocket, channel in self.channels.copy().items(): for websocket in self.channels.copy().keys():
if not channel.is_closed(): await self.on_disconnect(websocket)
await channel.close()
self.channels = dict() self.channels = dict()
@ -195,11 +201,9 @@ class ChannelManager:
""" """
with self._lock: with self._lock:
message_type = data.get('type') message_type = data.get('type')
for websocket, channel in self.channels.copy().items(): for channel in self.channels.copy().values():
if channel.subscribed_to(message_type): if channel.subscribed_to(message_type):
if not await channel.send(data): await self.send_direct(channel, data)
logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket)
async def send_direct(self, channel, data): async def send_direct(self, channel, data):
""" """
@ -208,7 +212,8 @@ class ChannelManager:
:param direct_channel: The WebSocketChannel object to send data through :param direct_channel: The WebSocketChannel object to send data through
:param data: The data to send :param data: The data to send
""" """
await channel.send(data) if not await channel.send(data):
await self.on_disconnect(channel.raw)
def has_channels(self): def has_channels(self):
""" """

View File

@ -15,6 +15,10 @@ class WebSocketProxy:
def __init__(self, websocket: WebSocketType): def __init__(self, websocket: WebSocketType):
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
@property
def raw(self):
return self._websocket
@property @property
def remote_addr(self) -> Tuple[Any, ...]: def remote_addr(self) -> Tuple[Any, ...]:
if isinstance(self._websocket, WebSocket): if isinstance(self._websocket, WebSocket):

View File

@ -270,6 +270,11 @@ class ExternalMessageConsumer:
logger.debug(f"Connection to {channel} still alive...") logger.debug(f"Connection to {channel} still alive...")
continue continue
except (websockets.exceptions.ConnectionClosed):
# Just eat the error and continue reconnecting
logger.warning(f"Disconnection in {channel} - retrying in {self.sleep_time}s")
await asyncio.sleep(self.sleep_time)
break
except Exception as e: except Exception as e:
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s") logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
logger.debug(e, exc_info=e) logger.debug(e, exc_info=e)