better error handling, true async sending, more readable api
This commit is contained in:
parent
ba493eb7a7
commit
0cb6f71c02
@ -1,16 +1,14 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.websockets import WebSocket, WebSocketDisconnect
|
from fastapi.websockets import WebSocket
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from websockets.exceptions import ConnectionClosed
|
|
||||||
|
|
||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||||
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
|
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
|
||||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||||
WSRequestSchema, WSWhitelistMessage)
|
WSRequestSchema, WSWhitelistMessage)
|
||||||
@ -23,45 +21,20 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class WebSocketChannelClosed(Exception):
|
|
||||||
"""
|
|
||||||
General WebSocket exception to signal closing the channel
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
|
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
|
||||||
"""
|
"""
|
||||||
Iterate over the messages from the channel and process the request
|
Iterate over the messages from the channel and process the request
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
async for message in channel:
|
async for message in channel:
|
||||||
await _process_consumer_request(message, channel, rpc)
|
await _process_consumer_request(message, channel, rpc)
|
||||||
except (
|
|
||||||
RuntimeError,
|
|
||||||
WebSocketDisconnect,
|
|
||||||
ConnectionClosed
|
|
||||||
):
|
|
||||||
raise WebSocketChannelClosed
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
|
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
|
||||||
"""
|
"""
|
||||||
Iterate over messages in the message stream and send them
|
Iterate over messages in the message stream and send them
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
async for message in message_stream:
|
async for message in message_stream:
|
||||||
await channel.send(message)
|
await channel.send(message)
|
||||||
except (
|
|
||||||
RuntimeError,
|
|
||||||
WebSocketDisconnect,
|
|
||||||
ConnectionClosed
|
|
||||||
):
|
|
||||||
raise WebSocketChannelClosed
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_consumer_request(
|
async def _process_consumer_request(
|
||||||
@ -103,15 +76,11 @@ async def _process_consumer_request(
|
|||||||
|
|
||||||
# Format response
|
# Format response
|
||||||
response = WSWhitelistMessage(data=whitelist)
|
response = WSWhitelistMessage(data=whitelist)
|
||||||
# Send it back
|
|
||||||
await channel.send(response.dict(exclude_none=True))
|
await channel.send(response.dict(exclude_none=True))
|
||||||
|
|
||||||
elif type == RPCRequestType.ANALYZED_DF:
|
elif type == RPCRequestType.ANALYZED_DF:
|
||||||
limit = None
|
|
||||||
|
|
||||||
if data:
|
|
||||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||||
limit = max(data.get('limit', 1500), 1500)
|
limit = min(data.get('limit', 1500), 1500) if data else None
|
||||||
|
|
||||||
# For every pair in the generator, send a separate message
|
# For every pair in the generator, send a separate message
|
||||||
for message in rpc._ws_request_analyzed_df(limit):
|
for message in rpc._ws_request_analyzed_df(limit):
|
||||||
@ -127,17 +96,8 @@ async def message_endpoint(
|
|||||||
rpc: RPC = Depends(get_rpc),
|
rpc: RPC = Depends(get_rpc),
|
||||||
message_stream: MessageStream = Depends(get_message_stream)
|
message_stream: MessageStream = Depends(get_message_stream)
|
||||||
):
|
):
|
||||||
async with WebSocketChannel(websocket).connect() as channel:
|
async with create_channel(websocket) as channel:
|
||||||
try:
|
await channel.run_channel_tasks(
|
||||||
logger.info(f"Channel connected - {channel}")
|
|
||||||
|
|
||||||
channel_tasks = asyncio.gather(
|
|
||||||
channel_reader(channel, rpc),
|
channel_reader(channel, rpc),
|
||||||
channel_broadcaster(channel, message_stream)
|
channel_broadcaster(channel, message_stream)
|
||||||
)
|
)
|
||||||
await channel_tasks
|
|
||||||
except WebSocketChannelClosed:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
logger.info(f"Channel disconnected - {channel}")
|
|
||||||
channel_tasks.cancel()
|
|
||||||
|
@ -94,6 +94,7 @@ class ApiServer(RPCHandler):
|
|||||||
del ApiServer._rpc
|
del ApiServer._rpc
|
||||||
if self._server and not self._standalone:
|
if self._server and not self._standalone:
|
||||||
logger.info("Stopping API Server")
|
logger.info("Stopping API Server")
|
||||||
|
# self._server.force_exit, self._server.should_exit = True, True
|
||||||
self._server.cleanup()
|
self._server.cleanup()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -29,6 +29,7 @@ class WebSocketChannel:
|
|||||||
|
|
||||||
# Internal event to signify a closed websocket
|
# Internal event to signify a closed websocket
|
||||||
self._closed = asyncio.Event()
|
self._closed = asyncio.Event()
|
||||||
|
self._send_timeout_high_limit = 2
|
||||||
|
|
||||||
# The subscribed message types
|
# The subscribed message types
|
||||||
self._subscriptions: List[str] = []
|
self._subscriptions: List[str] = []
|
||||||
@ -36,6 +37,9 @@ class WebSocketChannel:
|
|||||||
# Wrap the WebSocket in the Serializing class
|
# Wrap the WebSocket in the Serializing class
|
||||||
self._wrapped_ws = serializer_cls(self._websocket)
|
self._wrapped_ws = serializer_cls(self._websocket)
|
||||||
|
|
||||||
|
# The async tasks created for the channel
|
||||||
|
self._channel_tasks: List[asyncio.Task] = []
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
||||||
|
|
||||||
@ -51,7 +55,14 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
Send a message on the wrapped websocket
|
Send a message on the wrapped websocket
|
||||||
"""
|
"""
|
||||||
await self._wrapped_ws.send(message)
|
|
||||||
|
# Without this sleep, messages would send to one channel
|
||||||
|
# first then another after the first one finished.
|
||||||
|
# With the sleep call, it gives control to the event
|
||||||
|
# loop to schedule other channel send methods.
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
return await self._wrapped_ws.send(message)
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
"""
|
"""
|
||||||
@ -77,7 +88,6 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._closed.set()
|
self._closed.set()
|
||||||
self._relay_task.cancel()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._websocket.close()
|
await self._websocket.close()
|
||||||
@ -106,23 +116,68 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
return message_type in self._subscriptions
|
return message_type in self._subscriptions
|
||||||
|
|
||||||
|
async def run_channel_tasks(self, *tasks, **kwargs):
|
||||||
|
"""
|
||||||
|
Create and await on the channel tasks unless an exception
|
||||||
|
was raised, then cancel them all.
|
||||||
|
|
||||||
|
:params *tasks: All coros or tasks to be run concurrently
|
||||||
|
:param **kwargs: Any extra kwargs to pass to gather
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Wrap the coros into tasks if they aren't already
|
||||||
|
self._channel_tasks = [
|
||||||
|
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
|
||||||
|
for task in tasks
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*self._channel_tasks, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
# If an exception occurred, cancel the rest of the tasks and bubble up
|
||||||
|
# the error that was caught here
|
||||||
|
await self.cancel_channel_tasks()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cancel_channel_tasks(self):
|
||||||
|
"""
|
||||||
|
Cancel and wait on all channel tasks
|
||||||
|
"""
|
||||||
|
for task in self._channel_tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Wait for tasks to finish cancelling
|
||||||
|
try:
|
||||||
|
await asyncio.wait(self._channel_tasks)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._channel_tasks = []
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
"""
|
"""
|
||||||
Generator for received messages
|
Generator for received messages
|
||||||
"""
|
"""
|
||||||
while True:
|
# We can not catch any errors here as websocket.recv is
|
||||||
try:
|
# the first to catch any disconnects and bubble it up
|
||||||
|
# so the connection is garbage collected right away
|
||||||
|
while not self.is_closed():
|
||||||
yield await self.recv()
|
yield await self.recv()
|
||||||
except Exception:
|
|
||||||
break
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def connect(self):
|
async def create_channel(websocket: WebSocketType, **kwargs):
|
||||||
"""
|
"""
|
||||||
Context manager for safely opening and closing the websocket connection
|
Context manager for safely opening and closing a WebSocketChannel
|
||||||
"""
|
"""
|
||||||
|
channel = WebSocketChannel(websocket, **kwargs)
|
||||||
try:
|
try:
|
||||||
await self.accept()
|
await channel.accept()
|
||||||
yield self
|
logger.info(f"Connected to channel - {channel}")
|
||||||
|
|
||||||
|
yield channel
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
await self.close()
|
await channel.close()
|
||||||
|
logger.info(f"Disconnected from channel - {channel}")
|
||||||
|
@ -17,7 +17,8 @@ class MessageStream:
|
|||||||
async def subscribe(self):
|
async def subscribe(self):
|
||||||
waiter = self._waiter
|
waiter = self._waiter
|
||||||
while True:
|
while True:
|
||||||
message, waiter = await waiter
|
# Shield the future from being cancelled by a task waiting on it
|
||||||
|
message, waiter = await asyncio.shield(waiter)
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
__aiter__ = subscribe
|
__aiter__ = subscribe
|
||||||
|
Loading…
Reference in New Issue
Block a user