better error handling, true async sending, more readable api
This commit is contained in:
parent
ba493eb7a7
commit
0cb6f71c02
@ -1,16 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.websockets import WebSocket, WebSocketDisconnect
|
||||
from fastapi.websockets import WebSocket
|
||||
from pydantic import ValidationError
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
|
||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
|
||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||
WSRequestSchema, WSWhitelistMessage)
|
||||
@ -23,45 +21,20 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class WebSocketChannelClosed(Exception):
|
||||
"""
|
||||
General WebSocket exception to signal closing the channel
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
|
||||
"""
|
||||
Iterate over the messages from the channel and process the request
|
||||
"""
|
||||
try:
|
||||
async for message in channel:
|
||||
await _process_consumer_request(message, channel, rpc)
|
||||
except (
|
||||
RuntimeError,
|
||||
WebSocketDisconnect,
|
||||
ConnectionClosed
|
||||
):
|
||||
raise WebSocketChannelClosed
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
|
||||
"""
|
||||
Iterate over messages in the message stream and send them
|
||||
"""
|
||||
try:
|
||||
async for message in message_stream:
|
||||
await channel.send(message)
|
||||
except (
|
||||
RuntimeError,
|
||||
WebSocketDisconnect,
|
||||
ConnectionClosed
|
||||
):
|
||||
raise WebSocketChannelClosed
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
async def _process_consumer_request(
|
||||
@ -103,15 +76,11 @@ async def _process_consumer_request(
|
||||
|
||||
# Format response
|
||||
response = WSWhitelistMessage(data=whitelist)
|
||||
# Send it back
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
elif type == RPCRequestType.ANALYZED_DF:
|
||||
limit = None
|
||||
|
||||
if data:
|
||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||
limit = max(data.get('limit', 1500), 1500)
|
||||
limit = min(data.get('limit', 1500), 1500) if data else None
|
||||
|
||||
# For every pair in the generator, send a separate message
|
||||
for message in rpc._ws_request_analyzed_df(limit):
|
||||
@ -127,17 +96,8 @@ async def message_endpoint(
|
||||
rpc: RPC = Depends(get_rpc),
|
||||
message_stream: MessageStream = Depends(get_message_stream)
|
||||
):
|
||||
async with WebSocketChannel(websocket).connect() as channel:
|
||||
try:
|
||||
logger.info(f"Channel connected - {channel}")
|
||||
|
||||
channel_tasks = asyncio.gather(
|
||||
async with create_channel(websocket) as channel:
|
||||
await channel.run_channel_tasks(
|
||||
channel_reader(channel, rpc),
|
||||
channel_broadcaster(channel, message_stream)
|
||||
)
|
||||
await channel_tasks
|
||||
except WebSocketChannelClosed:
|
||||
pass
|
||||
finally:
|
||||
logger.info(f"Channel disconnected - {channel}")
|
||||
channel_tasks.cancel()
|
||||
|
@ -94,6 +94,7 @@ class ApiServer(RPCHandler):
|
||||
del ApiServer._rpc
|
||||
if self._server and not self._standalone:
|
||||
logger.info("Stopping API Server")
|
||||
# self._server.force_exit, self._server.should_exit = True, True
|
||||
self._server.cleanup()
|
||||
|
||||
@classmethod
|
||||
|
@ -29,6 +29,7 @@ class WebSocketChannel:
|
||||
|
||||
# Internal event to signify a closed websocket
|
||||
self._closed = asyncio.Event()
|
||||
self._send_timeout_high_limit = 2
|
||||
|
||||
# The subscribed message types
|
||||
self._subscriptions: List[str] = []
|
||||
@ -36,6 +37,9 @@ class WebSocketChannel:
|
||||
# Wrap the WebSocket in the Serializing class
|
||||
self._wrapped_ws = serializer_cls(self._websocket)
|
||||
|
||||
# The async tasks created for the channel
|
||||
self._channel_tasks: List[asyncio.Task] = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
||||
|
||||
@ -51,7 +55,14 @@ class WebSocketChannel:
|
||||
"""
|
||||
Send a message on the wrapped websocket
|
||||
"""
|
||||
await self._wrapped_ws.send(message)
|
||||
|
||||
# Without this sleep, messages would send to one channel
|
||||
# first then another after the first one finished.
|
||||
# With the sleep call, it gives control to the event
|
||||
# loop to schedule other channel send methods.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return await self._wrapped_ws.send(message)
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
@ -77,7 +88,6 @@ class WebSocketChannel:
|
||||
"""
|
||||
|
||||
self._closed.set()
|
||||
self._relay_task.cancel()
|
||||
|
||||
try:
|
||||
await self._websocket.close()
|
||||
@ -106,23 +116,68 @@ class WebSocketChannel:
|
||||
"""
|
||||
return message_type in self._subscriptions
|
||||
|
||||
async def run_channel_tasks(self, *tasks, **kwargs):
|
||||
"""
|
||||
Create and await on the channel tasks unless an exception
|
||||
was raised, then cancel them all.
|
||||
|
||||
:params *tasks: All coros or tasks to be run concurrently
|
||||
:param **kwargs: Any extra kwargs to pass to gather
|
||||
"""
|
||||
|
||||
# Wrap the coros into tasks if they aren't already
|
||||
self._channel_tasks = [
|
||||
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*self._channel_tasks, **kwargs)
|
||||
except Exception:
|
||||
# If an exception occurred, cancel the rest of the tasks and bubble up
|
||||
# the error that was caught here
|
||||
await self.cancel_channel_tasks()
|
||||
raise
|
||||
|
||||
async def cancel_channel_tasks(self):
|
||||
"""
|
||||
Cancel and wait on all channel tasks
|
||||
"""
|
||||
for task in self._channel_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to finish cancelling
|
||||
try:
|
||||
await asyncio.wait(self._channel_tasks)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._channel_tasks = []
|
||||
|
||||
async def __aiter__(self):
|
||||
"""
|
||||
Generator for received messages
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# We can not catch any errors here as websocket.recv is
|
||||
# the first to catch any disconnects and bubble it up
|
||||
# so the connection is garbage collected right away
|
||||
while not self.is_closed():
|
||||
yield await self.recv()
|
||||
except Exception:
|
||||
break
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
async def create_channel(websocket: WebSocketType, **kwargs):
|
||||
"""
|
||||
Context manager for safely opening and closing the websocket connection
|
||||
Context manager for safely opening and closing a WebSocketChannel
|
||||
"""
|
||||
channel = WebSocketChannel(websocket, **kwargs)
|
||||
try:
|
||||
await self.accept()
|
||||
yield self
|
||||
await channel.accept()
|
||||
logger.info(f"Connected to channel - {channel}")
|
||||
|
||||
yield channel
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await self.close()
|
||||
await channel.close()
|
||||
logger.info(f"Disconnected from channel - {channel}")
|
||||
|
@ -17,7 +17,8 @@ class MessageStream:
|
||||
async def subscribe(self):
|
||||
waiter = self._waiter
|
||||
while True:
|
||||
message, waiter = await waiter
|
||||
# Shield the future from being cancelled by a task waiting on it
|
||||
message, waiter = await asyncio.shield(waiter)
|
||||
yield message
|
||||
|
||||
__aiter__ = subscribe
|
||||
|
Loading…
Reference in New Issue
Block a user