better error handling, true async sending, more readable api

This commit is contained in:
Timothy Pogue 2022-11-18 13:32:27 -07:00
parent ba493eb7a7
commit 0cb6f71c02
4 changed files with 88 additions and 71 deletions

View File

@ -1,16 +1,14 @@
import asyncio
import logging import logging
from typing import Any, Dict from typing import Any, Dict
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket, WebSocketDisconnect from fastapi.websockets import WebSocket
from pydantic import ValidationError from pydantic import ValidationError
from websockets.exceptions import ConnectionClosed
from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage) WSRequestSchema, WSWhitelistMessage)
@ -23,45 +21,20 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
class WebSocketChannelClosed(Exception):
"""
General WebSocket exception to signal closing the channel
"""
pass
async def channel_reader(channel: WebSocketChannel, rpc: RPC): async def channel_reader(channel: WebSocketChannel, rpc: RPC):
""" """
Iterate over the messages from the channel and process the request Iterate over the messages from the channel and process the request
""" """
try: async for message in channel:
async for message in channel: await _process_consumer_request(message, channel, rpc)
await _process_consumer_request(message, channel, rpc)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
""" """
Iterate over messages in the message stream and send them Iterate over messages in the message stream and send them
""" """
try: async for message in message_stream:
async for message in message_stream: await channel.send(message)
await channel.send(message)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def _process_consumer_request( async def _process_consumer_request(
@ -103,15 +76,11 @@ async def _process_consumer_request(
# Format response # Format response
response = WSWhitelistMessage(data=whitelist) response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel.send(response.dict(exclude_none=True)) await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF: elif type == RPCRequestType.ANALYZED_DF:
limit = None # Limit the amount of candles per dataframe to 'limit' or 1500
limit = min(data.get('limit', 1500), 1500) if data else None
if data:
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = max(data.get('limit', 1500), 1500)
# For every pair in the generator, send a separate message # For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit): for message in rpc._ws_request_analyzed_df(limit):
@ -127,17 +96,8 @@ async def message_endpoint(
rpc: RPC = Depends(get_rpc), rpc: RPC = Depends(get_rpc),
message_stream: MessageStream = Depends(get_message_stream) message_stream: MessageStream = Depends(get_message_stream)
): ):
async with WebSocketChannel(websocket).connect() as channel: async with create_channel(websocket) as channel:
try: await channel.run_channel_tasks(
logger.info(f"Channel connected - {channel}") channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
channel_tasks = asyncio.gather( )
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)
await channel_tasks
except WebSocketChannelClosed:
pass
finally:
logger.info(f"Channel disconnected - {channel}")
channel_tasks.cancel()

View File

@ -94,6 +94,7 @@ class ApiServer(RPCHandler):
del ApiServer._rpc del ApiServer._rpc
if self._server and not self._standalone: if self._server and not self._standalone:
logger.info("Stopping API Server") logger.info("Stopping API Server")
# self._server.force_exit, self._server.should_exit = True, True
self._server.cleanup() self._server.cleanup()
@classmethod @classmethod

View File

@ -29,6 +29,7 @@ class WebSocketChannel:
# Internal event to signify a closed websocket # Internal event to signify a closed websocket
self._closed = asyncio.Event() self._closed = asyncio.Event()
self._send_timeout_high_limit = 2
# The subscribed message types # The subscribed message types
self._subscriptions: List[str] = [] self._subscriptions: List[str] = []
@ -36,6 +37,9 @@ class WebSocketChannel:
# Wrap the WebSocket in the Serializing class # Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket) self._wrapped_ws = serializer_cls(self._websocket)
# The async tasks created for the channel
self._channel_tasks: List[asyncio.Task] = []
def __repr__(self): def __repr__(self):
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
@ -51,7 +55,14 @@ class WebSocketChannel:
""" """
Send a message on the wrapped websocket Send a message on the wrapped websocket
""" """
await self._wrapped_ws.send(message)
# Without this sleep, messages would send to one channel
# first then another after the first one finished.
# With the sleep call, it gives control to the event
# loop to schedule other channel send methods.
await asyncio.sleep(0)
return await self._wrapped_ws.send(message)
async def recv(self): async def recv(self):
""" """
@ -77,7 +88,6 @@ class WebSocketChannel:
""" """
self._closed.set() self._closed.set()
self._relay_task.cancel()
try: try:
await self._websocket.close() await self._websocket.close()
@ -106,23 +116,68 @@ class WebSocketChannel:
""" """
return message_type in self._subscriptions return message_type in self._subscriptions
async def run_channel_tasks(self, *tasks, **kwargs):
"""
Create and await on the channel tasks unless an exception
was raised, then cancel them all.
:params *tasks: All coros or tasks to be run concurrently
:param **kwargs: Any extra kwargs to pass to gather
"""
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks and bubble up
# the error that was caught here
await self.cancel_channel_tasks()
raise
async def cancel_channel_tasks(self):
"""
Cancel and wait on all channel tasks
"""
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await asyncio.wait(self._channel_tasks)
except asyncio.CancelledError:
pass
self._channel_tasks = []
async def __aiter__(self): async def __aiter__(self):
""" """
Generator for received messages Generator for received messages
""" """
while True: # We can not catch any errors here as websocket.recv is
try: # the first to catch any disconnects and bubble it up
yield await self.recv() # so the connection is garbage collected right away
except Exception: while not self.is_closed():
break yield await self.recv()
@asynccontextmanager
async def connect(self): @asynccontextmanager
""" async def create_channel(websocket: WebSocketType, **kwargs):
Context manager for safely opening and closing the websocket connection """
""" Context manager for safely opening and closing a WebSocketChannel
try: """
await self.accept() channel = WebSocketChannel(websocket, **kwargs)
yield self try:
finally: await channel.accept()
await self.close() logger.info(f"Connected to channel - {channel}")
yield channel
except Exception:
pass
finally:
await channel.close()
logger.info(f"Disconnected from channel - {channel}")

View File

@ -17,7 +17,8 @@ class MessageStream:
async def subscribe(self): async def subscribe(self):
waiter = self._waiter waiter = self._waiter
while True: while True:
message, waiter = await waiter # Shield the future from being cancelled by a task waiting on it
message, waiter = await asyncio.shield(waiter)
yield message yield message
__aiter__ = subscribe __aiter__ = subscribe