Merge pull request #7771 from wizrds/feat/refactor-ws

Refactor WebSocket API for performance
This commit is contained in:
Matthias 2022-11-27 15:49:34 +01:00 committed by GitHub
commit 3fc367f536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 258 additions and 302 deletions

View File

@ -81,8 +81,6 @@ async def validate_ws_token(
except HTTPException:
pass
# No checks passed, deny the connection
logger.debug("Denying websocket request.")
# If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION)

View File

@ -1,16 +1,16 @@
import logging
import time
from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket
from pydantic import ValidationError
from websockets.exceptions import WebSocketException
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC
@ -22,23 +22,35 @@ logger = logging.getLogger(__name__)
router = APIRouter()
async def is_websocket_alive(ws: WebSocket) -> bool:
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
"""
Check if a FastAPI Websocket is still open
Iterate over the messages from the channel and process the request
"""
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
):
return True
return False
async for message in channel:
await _process_consumer_request(message, channel, rpc)
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
"""
Iterate over messages in the message stream and send them
"""
async for message, ts in message_stream:
if channel.subscribed_to(message.get('type')):
# Log a warning if this channel is behind
# on the message stream by a lot
if (time.time() - ts) > 60:
logger.warning(f"Channel {channel} is behind MessageStream by 1 minute,"
" this can cause a memory leak if you see this message"
" often, consider reducing pair list size or amount of"
" consumers.")
await channel.send(message, timeout=True)
async def _process_consumer_request(
request: Dict[str, Any],
channel: WebSocketChannel,
rpc: RPC,
channel_manager: ChannelManager
rpc: RPC
):
"""
Validate and handle a request from a websocket consumer
@ -74,65 +86,29 @@ async def _process_consumer_request(
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
if data:
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = max(data.get('limit', 1500), 1500)
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = min(data.get('limit', 1500), 1500) if data else None
# For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit):
# Format response
response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
websocket: WebSocket,
token: str = Depends(validate_ws_token),
rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager),
token: str = Depends(validate_ws_token)
message_stream: MessageStream = Depends(get_message_stream)
):
"""
Message WebSocket endpoint, facilitates sending RPC messages
"""
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}")
# Keep connection open until explicitly closed, and process requests
try:
while not channel.is_closed():
request = await channel.recv()
# Process the request here
await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
except RuntimeError:
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
except RuntimeError:
# WebSocket was closed
# Do nothing
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
if channel:
await channel_manager.on_disconnect(ws)
if token:
async with create_channel(websocket) as channel:
await channel.run_channel_tasks(
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)

View File

@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
return ApiServer._exchange
def get_channel_manager():
return ApiServer._ws_channel_manager
def get_message_stream():
return ApiServer._message_stream
def is_webserver_mode(config=Depends(get_config)):

View File

@ -1,22 +1,17 @@
import asyncio
import logging
from ipaddress import IPv4Address
from threading import Thread
from typing import Any, Dict, Optional
import orjson
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Look into alternatives
from janus import Queue as ThreadedQueue
from starlette.responses import JSONResponse
from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -50,10 +45,8 @@ class ApiServer(RPCHandler):
_config: Config = {}
# Exchange - only available in webserver mode.
_exchange = None
# websocket message queue stuff
_ws_channel_manager: ChannelManager
_ws_thread = None
_ws_loop: Optional[asyncio.AbstractEventLoop] = None
# websocket message stuff
_message_stream: Optional[MessageStream] = None
def __new__(cls, *args, **kwargs):
"""
@ -71,15 +64,11 @@ class ApiServer(RPCHandler):
return
self._standalone: bool = standalone
self._server = None
self._ws_queue: Optional[ThreadedQueue] = None
self._ws_background_task = None
ApiServer.__initialized = True
api_config = self._config['api_server']
ApiServer._ws_channel_manager = ChannelManager()
self.app = FastAPI(title="Freqtrade API",
docs_url='/docs' if api_config.get('enable_openapi', False) else None,
redoc_url=None,
@ -105,21 +94,9 @@ class ApiServer(RPCHandler):
del ApiServer._rpc
if self._server and not self._standalone:
logger.info("Stopping API Server")
# self._server.force_exit, self._server.should_exit = True, True
self._server.cleanup()
if self._ws_thread and self._ws_loop:
logger.info("Stopping API Server background tasks")
if self._ws_background_task:
# Cancel the queue task
self._ws_background_task.cancel()
self._ws_thread.join()
self._ws_thread = None
self._ws_loop = None
self._ws_background_task = None
@classmethod
def shutdown(cls):
cls.__initialized = False
@ -129,9 +106,11 @@ class ApiServer(RPCHandler):
cls._rpc = None
def send_msg(self, msg: Dict[str, Any]) -> None:
if self._ws_queue:
sync_q = self._ws_queue.sync_q
sync_q.put(msg)
"""
Publish the message to the message stream
"""
if ApiServer._message_stream:
ApiServer._message_stream.publish(msg)
def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}")
@ -170,54 +149,30 @@ class ApiServer(RPCHandler):
)
app.add_exception_handler(RPCException, self.handle_rpc_exception)
app.add_event_handler(
event_type="startup",
func=self._api_startup_event
)
app.add_event_handler(
event_type="shutdown",
func=self._api_shutdown_event
)
def start_message_queue(self):
if self._ws_thread:
return
async def _api_startup_event(self):
"""
Creates the MessageStream class on startup
so it has access to the same event loop
as uvicorn
"""
if not ApiServer._message_stream:
ApiServer._message_stream = MessageStream()
# Create a new loop, as it'll be just for the background thread
self._ws_loop = asyncio.new_event_loop()
# Start the thread
self._ws_thread = Thread(target=self._ws_loop.run_forever)
self._ws_thread.start()
# Finally, submit the coro to the thread
self._ws_background_task = asyncio.run_coroutine_threadsafe(
self._broadcast_queue_data(), loop=self._ws_loop)
async def _broadcast_queue_data(self) -> None:
# Instantiate the queue in this coroutine so it's attached to our loop
self._ws_queue = ThreadedQueue()
async_queue = self._ws_queue.async_q
try:
while True:
logger.debug("Getting queue messages...")
if (qsize := async_queue.qsize()) > 20:
# If the queue becomes too big for too long, this may indicate a problem.
logger.warning(f"Queue size now {qsize}")
# Get data from queue
message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}")
async_queue.task_done()
# Broadcast it
await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError:
pass
# For testing, shouldn't happen when stable
except Exception as e:
logger.exception(f"Exception happened in background task: {e}")
finally:
# Disconnect channels and stop the loop on cancel
await self._ws_channel_manager.disconnect_all()
if self._ws_loop:
self._ws_loop.stop()
# Avoid adding more items to the queue if they aren't
# going to get broadcasted.
self._ws_queue = None
async def _api_shutdown_event(self):
"""
Removes the MessageStream class on shutdown
"""
if ApiServer._message_stream:
ApiServer._message_stream = None
def start_api(self):
"""
@ -257,7 +212,6 @@ class ApiServer(RPCHandler):
if self._standalone:
self._server.run()
else:
self.start_message_queue()
self._server.run_in_thread()
except Exception:
logger.exception("Api server failed to start.")

View File

@ -3,4 +3,5 @@
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream

View File

@ -1,11 +1,13 @@
import asyncio
import logging
import time
from threading import RLock
from typing import Any, Dict, List, Optional, Type, Union
from collections import deque
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
from fastapi import WebSocketDisconnect
from websockets.exceptions import ConnectionClosed
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
@ -21,31 +23,27 @@ class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
"""
def __init__(
self,
websocket: WebSocketType,
channel_id: Optional[str] = None,
drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
):
self.channel_id = channel_id if channel_id else uuid4().hex[:8]
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = []
# 32 is the size of the receiving queue in websockets package
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = asyncio.Event()
# The async tasks created for the channel
self._channel_tasks: List[asyncio.Task] = []
# Deque for average send times
self._send_times: Deque[float] = deque([], maxlen=10)
# High limit defaults to 3 to start
self._send_high_limit = 3
# The subscribed message types
self._subscriptions: List[str] = []
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket)
@ -61,43 +59,58 @@ class WebSocketChannel:
def remote_addr(self):
return self._websocket.remote_addr
async def _send(self, data):
"""
Send data on the wrapped websocket
"""
await self._wrapped_ws.send(data)
@property
def avg_send_time(self):
return sum(self._send_times) / len(self._send_times)
async def send(self, data) -> bool:
def _calc_send_limit(self):
"""
Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
Calculate the send high limit for this channel
"""
# This block only runs if the queue is full, it will wait
# until self.drain_timeout for the relay to drain the outgoing queue
# We can't use asyncio.wait_for here because the queue may have been created with a
# different eventloop
if not self.is_closed():
start = time.time()
while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False
# Only update if we have enough data
if len(self._send_times) == self._send_times.maxlen:
# At least 1s or twice the average of send times, with a
# maximum of 3 seconds per message
self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)
# If for some reason the queue is still full, just return False
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
async def send(
self,
message: Union[WSMessageSchemaType, Dict[str, Any]],
timeout: bool = False
):
"""
Send a message on the wrapped websocket. If the sending
takes too long, it will raise a TimeoutError and
disconnect the connection.
# If we got here everything is ok
return True
else:
return False
:param message: The message to send
:param timeout: Enforce send high limit, defaults to False
"""
try:
_ = time.time()
# If the send times out, it will raise
# a TimeoutError and bubble up to the
# message_endpoint to close the connection
await asyncio.wait_for(
self._wrapped_ws.send(message),
timeout=self._send_high_limit if timeout else None
)
total_time = time.time() - _
self._send_times.append(total_time)
self._calc_send_limit()
except asyncio.TimeoutError:
logger.info(f"Connection for {self} timed out, disconnecting")
raise
# Explicitly give control back to event loop as
# websockets.send does not
await asyncio.sleep(0.01)
async def recv(self):
"""
Receive data on the wrapped websocket
Receive a message on the wrapped websocket
"""
return await self._wrapped_ws.recv()
@ -107,17 +120,27 @@ class WebSocketChannel:
"""
return await self._websocket.ping()
async def accept(self):
"""
Accept the underlying websocket connection,
if the connection has been closed before we can
accept, just close the channel.
"""
try:
return await self._websocket.accept()
except RuntimeError:
await self.close()
async def close(self):
"""
Close the WebSocketChannel
"""
self._closed.set()
self._relay_task.cancel()
try:
await self.raw_websocket.close()
except Exception:
await self._websocket.close()
except RuntimeError:
pass
def is_closed(self) -> bool:
@ -142,99 +165,76 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def relay(self):
async def run_channel_tasks(self, *tasks, **kwargs):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
Create and await on the channel tasks unless an exception
was raised, then cancel them all.
:params *tasks: All coros or tasks to be run concurrently
:param **kwargs: Any extra kwargs to pass to gather
"""
while not self._closed.is_set():
message = await self.queue.get()
if not self.is_closed():
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
await self._send(message)
self.queue.task_done()
return await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks
await self.cancel_channel_tasks()
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
# 0.01 = 100/s
await asyncio.sleep(self.throttle)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
async def cancel_channel_tasks(self):
"""
Wrap websocket connection into Channel and add to list
:param websocket: The WebSocket object to attach to the Channel
Cancel and wait on all channel tasks
"""
if isinstance(websocket, FastAPIWebSocket):
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await websocket.accept()
except RuntimeError:
# The connection was closed before we could accept it
return
await task
except (
asyncio.CancelledError,
asyncio.TimeoutError,
WebSocketDisconnect,
ConnectionClosed,
RuntimeError
):
pass
except Exception as e:
logger.info(f"Encountered unknown exception: {e}", exc_info=e)
ws_channel = WebSocketChannel(websocket)
self._channel_tasks = []
with self._lock:
self.channels[websocket] = ws_channel
return ws_channel
async def on_disconnect(self, websocket: WebSocketType):
async def __aiter__(self):
"""
Call close on the channel if it's not, and remove from channel list
Generator for received messages
"""
# We can not catch any errors here as websocket.recv is
# the first to catch any disconnects and bubble it up
# so the connection is garbage collected right away
while not self.is_closed():
yield await self.recv()
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed():
await channel.close()
del self.channels[websocket]
@asynccontextmanager
async def create_channel(
websocket: WebSocketType,
**kwargs
) -> AsyncIterator[WebSocketChannel]:
"""
Context manager for safely opening and closing a WebSocketChannel
"""
channel = WebSocketChannel(websocket, **kwargs)
try:
await channel.accept()
logger.info(f"Connected to channel - {channel}")
async def disconnect_all(self):
"""
Disconnect all Channels
"""
with self._lock:
for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket)
async def broadcast(self, message: WSMessageSchemaType):
"""
Broadcast a message on all Channels
:param message: The message to send
"""
with self._lock:
for channel in self.channels.copy().values():
if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message)
async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
"""
Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send the message through
:param message: The message to send
"""
if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0
yield channel
finally:
await channel.close()
logger.info(f"Disconnected from channel - {channel}")

View File

@ -0,0 +1,31 @@
import asyncio
import time
class MessageStream:
"""
A message stream for consumers to subscribe to,
and for producers to publish to.
"""
def __init__(self):
self._loop = asyncio.get_running_loop()
self._waiter = self._loop.create_future()
def publish(self, message):
"""
Publish a message to this MessageStream
:param message: The message to publish
"""
waiter, self._waiter = self._waiter, self._loop.create_future()
waiter.set_result((message, time.time(), self._waiter))
async def __aiter__(self):
"""
Iterate over the messages in the message stream
"""
waiter = self._waiter
while True:
# Shield the future from being cancelled by a task waiting on it
message, ts, waiter = await asyncio.shield(waiter)
yield message, ts

View File

@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
import orjson
import rapidjson
@ -7,6 +8,7 @@ from pandas import DataFrame
from freqtrade.misc import dataframe_to_json, json_to_dataframe
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__)
@ -24,17 +26,13 @@ class WebSocketSerializer(ABC):
def _deserialize(self, data):
raise NotImplementedError()
async def send(self, data: bytes):
async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes:
data = await self._websocket.recv()
return self._deserialize(data)
async def close(self, code: int = 1000):
await self._websocket.close(code)
class HybridJSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data) -> str:

View File

@ -57,7 +57,10 @@ def botclient(default_conf, mocker):
try:
apiserver = ApiServer(default_conf)
apiserver.add_rpc_handler(rpc)
yield ftbot, TestClient(apiserver.app)
# We need to use the TestClient as a context manager to
# handle lifespan events correctly
with TestClient(apiserver.app) as client:
yield ftbot, client
# Cleanup ... ?
finally:
if apiserver:
@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog):
apiserver.cleanup()
assert apiserver._server.cleanup.call_count == 1
assert log_has("Stopping API Server", caplog)
assert log_has("Stopping API Server background tasks", caplog)
ApiServer.shutdown()
@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker):
with client.websocket_connect(ws_url) as ws:
ws.send_json({'type': 'subscribe', 'data': ['whitelist']})
time.sleep(1)
# Check call count is now 1 as we sent a valid subscribe request
assert sub_mock.call_count == 1
with client.websocket_connect(ws_url) as ws:
ws.send_json({'type': 'subscribe', 'data': 'whitelist'})
time.sleep(1)
# Call count hasn't changed as the subscribe request was invalid
assert sub_mock.call_count == 1
@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog):
mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api')
apiserver = ApiServer(default_conf)
apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf)))
apiserver.start_message_queue()
# Give the queue thread time to start
time.sleep(0.2)
# Test message_queue coro receives the message
test_message = {"type": "status", "data": "test"}
apiserver.send_msg(test_message)
time.sleep(0.1) # Not sure how else to wait for the coro to receive the data
assert log_has("Found message of type: status", caplog)
# Start test client context manager to run lifespan events
with TestClient(apiserver.app):
# Test message is published on the Message Stream
test_message = {"type": "status", "data": "test"}
first_waiter = apiserver._message_stream._waiter
apiserver.send_msg(test_message)
assert first_waiter.result()[0] == test_message
# Test if exception logged when error occurs in sending
mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast',
side_effect=Exception)
apiserver.send_msg(test_message)
time.sleep(0.1) # Not sure how else to wait for the coro to receive the data
assert log_has_re(r"Exception happened in background task.*", caplog)
second_waiter = apiserver._message_stream._waiter
apiserver.send_msg(test_message)
assert first_waiter != second_waiter
finally:
apiserver.cleanup()
ApiServer.shutdown()