better error handling, true async sending, more readable api

This commit is contained in:
Timothy Pogue 2022-11-18 13:32:27 -07:00
parent ba493eb7a7
commit 0cb6f71c02
4 changed files with 88 additions and 71 deletions

View File

@ -1,16 +1,14 @@
import asyncio
import logging
from typing import Any, Dict
from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocket
from pydantic import ValidationError
from websockets.exceptions import ConnectionClosed
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage)
@ -23,45 +21,20 @@ logger = logging.getLogger(__name__)
router = APIRouter()
class WebSocketChannelClosed(Exception):
"""
General WebSocket exception to signal closing the channel
"""
pass
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
"""
Iterate over the messages from the channel and process the request
"""
try:
async for message in channel:
await _process_consumer_request(message, channel, rpc)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
"""
Iterate over messages in the message stream and send them
"""
try:
async for message in message_stream:
await channel.send(message)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def _process_consumer_request(
@ -103,15 +76,11 @@ async def _process_consumer_request(
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
if data:
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = max(data.get('limit', 1500), 1500)
limit = min(data.get('limit', 1500), 1500) if data else None
# For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit):
@ -127,17 +96,8 @@ async def message_endpoint(
rpc: RPC = Depends(get_rpc),
message_stream: MessageStream = Depends(get_message_stream)
):
async with WebSocketChannel(websocket).connect() as channel:
try:
logger.info(f"Channel connected - {channel}")
channel_tasks = asyncio.gather(
async with create_channel(websocket) as channel:
await channel.run_channel_tasks(
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)
await channel_tasks
except WebSocketChannelClosed:
pass
finally:
logger.info(f"Channel disconnected - {channel}")
channel_tasks.cancel()

View File

@ -94,6 +94,7 @@ class ApiServer(RPCHandler):
del ApiServer._rpc
if self._server and not self._standalone:
logger.info("Stopping API Server")
# self._server.force_exit, self._server.should_exit = True, True
self._server.cleanup()
@classmethod

View File

@ -29,6 +29,7 @@ class WebSocketChannel:
# Internal event to signify a closed websocket
self._closed = asyncio.Event()
self._send_timeout_high_limit = 2
# The subscribed message types
self._subscriptions: List[str] = []
@ -36,6 +37,9 @@ class WebSocketChannel:
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket)
# The async tasks created for the channel
self._channel_tasks: List[asyncio.Task] = []
def __repr__(self):
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
@ -51,7 +55,14 @@ class WebSocketChannel:
"""
Send a message on the wrapped websocket
"""
await self._wrapped_ws.send(message)
# Without this sleep, messages would send to one channel
# first then another after the first one finished.
# With the sleep call, it gives control to the event
# loop to schedule other channel send methods.
await asyncio.sleep(0)
return await self._wrapped_ws.send(message)
async def recv(self):
"""
@ -77,7 +88,6 @@ class WebSocketChannel:
"""
self._closed.set()
self._relay_task.cancel()
try:
await self._websocket.close()
@ -106,23 +116,68 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def run_channel_tasks(self, *tasks, **kwargs):
"""
Create and await on the channel tasks unless an exception
was raised, then cancel them all.
:params *tasks: All coros or tasks to be run concurrently
:param **kwargs: Any extra kwargs to pass to gather
"""
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks and bubble up
# the error that was caught here
await self.cancel_channel_tasks()
raise
async def cancel_channel_tasks(self):
"""
Cancel and wait on all channel tasks
"""
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await asyncio.wait(self._channel_tasks)
except asyncio.CancelledError:
pass
self._channel_tasks = []
async def __aiter__(self):
"""
Generator for received messages
"""
while True:
try:
# We can not catch any errors here as websocket.recv is
# the first to catch any disconnects and bubble it up
# so the connection is garbage collected right away
while not self.is_closed():
yield await self.recv()
except Exception:
break
@asynccontextmanager
async def connect(self):
async def create_channel(websocket: WebSocketType, **kwargs):
"""
Context manager for safely opening and closing the websocket connection
Context manager for safely opening and closing a WebSocketChannel
"""
channel = WebSocketChannel(websocket, **kwargs)
try:
await self.accept()
yield self
await channel.accept()
logger.info(f"Connected to channel - {channel}")
yield channel
except Exception:
pass
finally:
await self.close()
await channel.close()
logger.info(f"Disconnected from channel - {channel}")

View File

@ -17,7 +17,8 @@ class MessageStream:
async def subscribe(self):
waiter = self._waiter
while True:
message, waiter = await waiter
# Shield the future from being cancelled by a task waiting on it
message, waiter = await asyncio.shield(waiter)
yield message
__aiter__ = subscribe