Merge pull request #7771 from wizrds/feat/refactor-ws

Refactor WebSocket API for performance
Matthias 2022-11-27 15:49:34 +01:00 committed by GitHub
commit 3fc367f536
9 changed files with 258 additions and 302 deletions

freqtrade/rpc/api_server/api_auth.py

@@ -81,8 +81,6 @@ async def validate_ws_token(
     except HTTPException:
         pass

-    # No checks passed, deny the connection
-    logger.debug("Denying websocket request.")
     # If it doesn't match, close the websocket connection
     await ws.close(code=status.WS_1008_POLICY_VIOLATION)

freqtrade/rpc/api_server/api_ws.py

@@ -1,16 +1,16 @@
 import logging
+import time
 from typing import Any, Dict

-from fastapi import APIRouter, Depends, WebSocketDisconnect
-from fastapi.websockets import WebSocket, WebSocketState
+from fastapi import APIRouter, Depends
+from fastapi.websockets import WebSocket
 from pydantic import ValidationError
-from websockets.exceptions import WebSocketException

 from freqtrade.enums import RPCMessageType, RPCRequestType
 from freqtrade.rpc.api_server.api_auth import validate_ws_token
-from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
-from freqtrade.rpc.api_server.ws import WebSocketChannel
-from freqtrade.rpc.api_server.ws.channel import ChannelManager
+from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
+from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream
 from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
                                                  WSRequestSchema, WSWhitelistMessage)
 from freqtrade.rpc.rpc import RPC
@@ -22,23 +22,35 @@ logger = logging.getLogger(__name__)
 router = APIRouter()


-async def is_websocket_alive(ws: WebSocket) -> bool:
+async def channel_reader(channel: WebSocketChannel, rpc: RPC):
     """
-    Check if a FastAPI Websocket is still open
+    Iterate over the messages from the channel and process the request
     """
-    if (
-        ws.application_state == WebSocketState.CONNECTED and
-        ws.client_state == WebSocketState.CONNECTED
-    ):
-        return True
-    return False
+    async for message in channel:
+        await _process_consumer_request(message, channel, rpc)
+
+
+async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
+    """
+    Iterate over messages in the message stream and send them
+    """
+    async for message, ts in message_stream:
+        if channel.subscribed_to(message.get('type')):
+            # Log a warning if this channel is behind
+            # on the message stream by a lot
+            if (time.time() - ts) > 60:
+                logger.warning(f"Channel {channel} is behind MessageStream by 1 minute,"
+                               " this can cause a memory leak if you see this message"
+                               " often, consider reducing pair list size or amount of"
+                               " consumers.")
+
+            await channel.send(message, timeout=True)


 async def _process_consumer_request(
     request: Dict[str, Any],
     channel: WebSocketChannel,
-    rpc: RPC,
-    channel_manager: ChannelManager
+    rpc: RPC
 ):
     """
     Validate and handle a request from a websocket consumer
@@ -74,65 +86,29 @@ async def _process_consumer_request(
         # Format response
         response = WSWhitelistMessage(data=whitelist)
-
-        # Send it back
-        await channel_manager.send_direct(channel, response.dict(exclude_none=True))
+        await channel.send(response.dict(exclude_none=True))

     elif type == RPCRequestType.ANALYZED_DF:
-        limit = None
-
-        if data:
-            # Limit the amount of candles per dataframe to 'limit' or 1500
-            limit = max(data.get('limit', 1500), 1500)
+        # Limit the amount of candles per dataframe to 'limit' or 1500
+        limit = min(data.get('limit', 1500), 1500) if data else None

         # For every pair in the generator, send a separate message
         for message in rpc._ws_request_analyzed_df(limit):
-            # Format response
             response = WSAnalyzedDFMessage(data=message)
-            await channel_manager.send_direct(channel, response.dict(exclude_none=True))
+            await channel.send(response.dict(exclude_none=True))


 @router.websocket("/message/ws")
 async def message_endpoint(
-    ws: WebSocket,
+    websocket: WebSocket,
+    token: str = Depends(validate_ws_token),
     rpc: RPC = Depends(get_rpc),
-    channel_manager=Depends(get_channel_manager),
-    token: str = Depends(validate_ws_token)
+    message_stream: MessageStream = Depends(get_message_stream)
 ):
-    """
-    Message WebSocket endpoint, facilitates sending RPC messages
-    """
-    try:
-        channel = await channel_manager.on_connect(ws)
-        if await is_websocket_alive(ws):
-
-            logger.info(f"Consumer connected - {channel}")
-
-            # Keep connection open until explicitly closed, and process requests
-            try:
-                while not channel.is_closed():
-                    request = await channel.recv()
-
-                    # Process the request here
-                    await _process_consumer_request(request, channel, rpc, channel_manager)
-
-            except (WebSocketDisconnect, WebSocketException):
-                # Handle client disconnects
-                logger.info(f"Consumer disconnected - {channel}")
-
-            except RuntimeError:
-                # Handle cases like -
-                # RuntimeError('Cannot call "send" once a closed message has been sent')
-                pass
-            except Exception as e:
-                logger.info(f"Consumer connection failed - {channel}: {e}")
-                logger.debug(e, exc_info=e)
-
-    except RuntimeError:
-        # WebSocket was closed
-        # Do nothing
-        pass
-    except Exception as e:
-        logger.error(f"Failed to serve - {ws.client}")
-        # Log tracebacks to keep track of what errors are happening
-        logger.exception(e)
-    finally:
-        if channel:
-            await channel_manager.on_disconnect(ws)
+    if token:
+        async with create_channel(websocket) as channel:
+            await channel.run_channel_tasks(
+                channel_reader(channel, rpc),
+                channel_broadcaster(channel, message_stream)
+            )
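Note: the consumer-facing protocol is unchanged by this refactor; clients still connect to /api/v1/message/ws with a token and exchange the same JSON messages. Below is a minimal consumer sketch using the third-party websockets library; the host, port and token value are assumptions for illustration only:

    import asyncio
    import json

    import websockets  # third-party: pip install websockets


    async def consume():
        # Assumes a local bot with api_server enabled and ws_token set to "secret"
        uri = "ws://127.0.0.1:8080/api/v1/message/ws?token=secret"
        async with websockets.connect(uri) as ws:
            # Subscribe to the message types we care about
            await ws.send(json.dumps({"type": "subscribe", "data": ["whitelist", "analyzed_df"]}))
            # Request the current whitelist once
            await ws.send(json.dumps({"type": "whitelist", "data": None}))
            while True:
                message = json.loads(await ws.recv())
                print(message.get("type"), message.get("data"))


    asyncio.run(consume())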

freqtrade/rpc/api_server/deps.py

@@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
     return ApiServer._exchange


-def get_channel_manager():
-    return ApiServer._ws_channel_manager
+def get_message_stream():
+    return ApiServer._message_stream


 def is_webserver_mode(config=Depends(get_config)):

freqtrade/rpc/api_server/webserver.py

@@ -1,22 +1,17 @@
-import asyncio
 import logging
 from ipaddress import IPv4Address
-from threading import Thread
 from typing import Any, Dict, Optional

 import orjson
 import uvicorn
 from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-# Look into alternatives
-from janus import Queue as ThreadedQueue
 from starlette.responses import JSONResponse

 from freqtrade.constants import Config
 from freqtrade.exceptions import OperationalException
 from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
-from freqtrade.rpc.api_server.ws import ChannelManager
-from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream
 from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@@ -50,10 +45,8 @@ class ApiServer(RPCHandler):
     _config: Config = {}
     # Exchange - only available in webserver mode.
     _exchange = None
-    # websocket message queue stuff
-    _ws_channel_manager: ChannelManager
-    _ws_thread = None
-    _ws_loop: Optional[asyncio.AbstractEventLoop] = None
+    # websocket message stuff
+    _message_stream: Optional[MessageStream] = None

     def __new__(cls, *args, **kwargs):
         """
@@ -71,15 +64,11 @@ class ApiServer(RPCHandler):
             return
         self._standalone: bool = standalone
         self._server = None
-        self._ws_queue: Optional[ThreadedQueue] = None
-        self._ws_background_task = None

         ApiServer.__initialized = True

         api_config = self._config['api_server']

-        ApiServer._ws_channel_manager = ChannelManager()
-
         self.app = FastAPI(title="Freqtrade API",
                            docs_url='/docs' if api_config.get('enable_openapi', False) else None,
                            redoc_url=None,
@@ -105,21 +94,9 @@ class ApiServer(RPCHandler):
         del ApiServer._rpc
         if self._server and not self._standalone:
             logger.info("Stopping API Server")
-            # self._server.force_exit, self._server.should_exit = True, True
             self._server.cleanup()

-        if self._ws_thread and self._ws_loop:
-            logger.info("Stopping API Server background tasks")
-
-            if self._ws_background_task:
-                # Cancel the queue task
-                self._ws_background_task.cancel()
-
-            self._ws_thread.join()
-
-        self._ws_thread = None
-        self._ws_loop = None
-        self._ws_background_task = None
-
     @classmethod
     def shutdown(cls):
         cls.__initialized = False
@@ -129,9 +106,11 @@ class ApiServer(RPCHandler):
         cls._rpc = None

     def send_msg(self, msg: Dict[str, Any]) -> None:
-        if self._ws_queue:
-            sync_q = self._ws_queue.sync_q
-            sync_q.put(msg)
+        """
+        Publish the message to the message stream
+        """
+        if ApiServer._message_stream:
+            ApiServer._message_stream.publish(msg)

     def handle_rpc_exception(self, request, exc):
         logger.exception(f"API Error calling: {exc}")
@@ -170,54 +149,30 @@ class ApiServer(RPCHandler):
         )

         app.add_exception_handler(RPCException, self.handle_rpc_exception)

-    def start_message_queue(self):
-        if self._ws_thread:
-            return
-
-        # Create a new loop, as it'll be just for the background thread
-        self._ws_loop = asyncio.new_event_loop()
-
-        # Start the thread
-        self._ws_thread = Thread(target=self._ws_loop.run_forever)
-        self._ws_thread.start()
-
-        # Finally, submit the coro to the thread
-        self._ws_background_task = asyncio.run_coroutine_threadsafe(
-            self._broadcast_queue_data(), loop=self._ws_loop)
-
-    async def _broadcast_queue_data(self) -> None:
-        # Instantiate the queue in this coroutine so it's attached to our loop
-        self._ws_queue = ThreadedQueue()
-        async_queue = self._ws_queue.async_q
-
-        try:
-            while True:
-                logger.debug("Getting queue messages...")
-                if (qsize := async_queue.qsize()) > 20:
-                    # If the queue becomes too big for too long, this may indicate a problem.
-                    logger.warning(f"Queue size now {qsize}")
-                # Get data from queue
-                message: WSMessageSchemaType = await async_queue.get()
-                logger.debug(f"Found message of type: {message.get('type')}")
-                async_queue.task_done()
-                # Broadcast it
-                await self._ws_channel_manager.broadcast(message)
-        except asyncio.CancelledError:
-            pass
-        # For testing, shouldn't happen when stable
-        except Exception as e:
-            logger.exception(f"Exception happened in background task: {e}")
-        finally:
-            # Disconnect channels and stop the loop on cancel
-            await self._ws_channel_manager.disconnect_all()
-
-            if self._ws_loop:
-                self._ws_loop.stop()
-            # Avoid adding more items to the queue if they aren't
-            # going to get broadcasted.
-            self._ws_queue = None
+        app.add_event_handler(
+            event_type="startup",
+            func=self._api_startup_event
+        )
+        app.add_event_handler(
+            event_type="shutdown",
+            func=self._api_shutdown_event
+        )
+
+    async def _api_startup_event(self):
+        """
+        Creates the MessageStream class on startup
+        so it has access to the same event loop
+        as uvicorn
+        """
+        if not ApiServer._message_stream:
+            ApiServer._message_stream = MessageStream()
+
+    async def _api_shutdown_event(self):
+        """
+        Removes the MessageStream class on shutdown
+        """
+        if ApiServer._message_stream:
+            ApiServer._message_stream = None

     def start_api(self):
         """
@@ -257,7 +212,6 @@ class ApiServer(RPCHandler):
             if self._standalone:
                 self._server.run()
             else:
-                self.start_message_queue()
                 self._server.run_in_thread()
         except Exception:
             logger.exception("Api server failed to start.")
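The MessageStream is created in the startup handler rather than in ApiServer.__init__ so that asyncio.get_running_loop() inside it resolves to uvicorn's event loop. A minimal sketch of the same pattern in a standalone FastAPI app (hypothetical demo, not freqtrade code):

    from fastapi import FastAPI

    from freqtrade.rpc.api_server.ws.message_stream import MessageStream

    app = FastAPI()
    stream = None


    @app.on_event("startup")
    async def create_stream():
        # Runs inside uvicorn's event loop, so the stream's futures belong to it
        global stream
        stream = MessageStream()


    @app.on_event("shutdown")
    async def drop_stream():
        global stream
        stream = None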

freqtrade/rpc/api_server/ws/__init__.py

@@ -3,4 +3,5 @@
 from freqtrade.rpc.api_server.ws.types import WebSocketType
 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
 from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
-from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
+from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
+from freqtrade.rpc.api_server.ws.message_stream import MessageStream

freqtrade/rpc/api_server/ws/channel.py

@@ -1,11 +1,13 @@
 import asyncio
 import logging
 import time
-from threading import RLock
-from typing import Any, Dict, List, Optional, Type, Union
+from collections import deque
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
 from uuid import uuid4

-from fastapi import WebSocket as FastAPIWebSocket
+from fastapi import WebSocketDisconnect
+from websockets.exceptions import ConnectionClosed

 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
 from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
@@ -21,31 +23,27 @@ class WebSocketChannel:
     """
     Object to help facilitate managing a websocket connection
     """

     def __init__(
         self,
         websocket: WebSocketType,
         channel_id: Optional[str] = None,
-        drain_timeout: int = 3,
-        throttle: float = 0.01,
         serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
     ):
         self.channel_id = channel_id if channel_id else uuid4().hex[:8]

-        # The WebSocket object
         self._websocket = WebSocketProxy(websocket)

-        self.drain_timeout = drain_timeout
-        self.throttle = throttle
-
-        self._subscriptions: List[str] = []
-        # 32 is the size of the receiving queue in websockets package
-        self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
-        self._relay_task = asyncio.create_task(self.relay())
-
         # Internal event to signify a closed websocket
         self._closed = asyncio.Event()
+        # The async tasks created for the channel
+        self._channel_tasks: List[asyncio.Task] = []
+
+        # Deque for average send times
+        self._send_times: Deque[float] = deque([], maxlen=10)
+        # High limit defaults to 3 to start
+        self._send_high_limit = 3
+
+        # The subscribed message types
+        self._subscriptions: List[str] = []

         # Wrap the WebSocket in the Serializing class
         self._wrapped_ws = serializer_cls(self._websocket)
@@ -61,43 +59,58 @@ class WebSocketChannel:
     def remote_addr(self):
         return self._websocket.remote_addr

-    async def _send(self, data):
-        """
-        Send data on the wrapped websocket
-        """
-        await self._wrapped_ws.send(data)
-
-    async def send(self, data) -> bool:
-        """
-        Add the data to the queue to be sent.
-        :returns: True if data added to queue, False otherwise
-        """
-        # This block only runs if the queue is full, it will wait
-        # until self.drain_timeout for the relay to drain the outgoing queue
-        # We can't use asyncio.wait_for here because the queue may have been created with a
-        # different eventloop
-        if not self.is_closed():
-            start = time.time()
-            while self.queue.full():
-                await asyncio.sleep(1)
-                if (time.time() - start) > self.drain_timeout:
-                    return False
-
-            # If for some reason the queue is still full, just return False
-            try:
-                self.queue.put_nowait(data)
-            except asyncio.QueueFull:
-                return False
-
-            # If we got here everything is ok
-            return True
-        else:
-            return False
+    @property
+    def avg_send_time(self):
+        return sum(self._send_times) / len(self._send_times)
+
+    def _calc_send_limit(self):
+        """
+        Calculate the send high limit for this channel
+        """
+        # Only update if we have enough data
+        if len(self._send_times) == self._send_times.maxlen:
+            # At least 1s or twice the average of send times, with a
+            # maximum of 3 seconds per message
+            self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)
+
+    async def send(
+        self,
+        message: Union[WSMessageSchemaType, Dict[str, Any]],
+        timeout: bool = False
+    ):
+        """
+        Send a message on the wrapped websocket. If the sending
+        takes too long, it will raise a TimeoutError and
+        disconnect the connection.
+
+        :param message: The message to send
+        :param timeout: Enforce send high limit, defaults to False
+        """
+        try:
+            _ = time.time()
+            # If the send times out, it will raise
+            # a TimeoutError and bubble up to the
+            # message_endpoint to close the connection
+            await asyncio.wait_for(
+                self._wrapped_ws.send(message),
+                timeout=self._send_high_limit if timeout else None
+            )
+            total_time = time.time() - _
+            self._send_times.append(total_time)
+
+            self._calc_send_limit()
+        except asyncio.TimeoutError:
+            logger.info(f"Connection for {self} timed out, disconnecting")
+            raise
+
+        # Explicitly give control back to event loop as
+        # websockets.send does not
+        await asyncio.sleep(0.01)

     async def recv(self):
         """
-        Receive data on the wrapped websocket
+        Receive a message on the wrapped websocket
         """
         return await self._wrapped_ws.recv()
@@ -107,17 +120,27 @@ class WebSocketChannel:
         """
         return await self._websocket.ping()

+    async def accept(self):
+        """
+        Accept the underlying websocket connection,
+        if the connection has been closed before we can
+        accept, just close the channel.
+        """
+        try:
+            return await self._websocket.accept()
+        except RuntimeError:
+            await self.close()
+
     async def close(self):
         """
         Close the WebSocketChannel
         """
         self._closed.set()
-        self._relay_task.cancel()

         try:
-            await self.raw_websocket.close()
-        except Exception:
+            await self._websocket.close()
+        except RuntimeError:
             pass

     def is_closed(self) -> bool:
@@ -142,99 +165,76 @@ class WebSocketChannel:
         """
         return message_type in self._subscriptions

-    async def relay(self):
+    async def run_channel_tasks(self, *tasks, **kwargs):
         """
-        Relay messages from the channel's queue and send them out. This is started
-        as a task.
+        Create and await on the channel tasks unless an exception
+        was raised, then cancel them all.
+
+        :params *tasks: All coros or tasks to be run concurrently
+        :param **kwargs: Any extra kwargs to pass to gather
         """
-        while not self._closed.is_set():
-            message = await self.queue.get()
-            try:
-                await self._send(message)
-                self.queue.task_done()
-
-                # Limit messages per sec.
-                # Could cause problems with queue size if too low, and
-                # problems with network traffik if too high.
-                # 0.01 = 100/s
-                await asyncio.sleep(self.throttle)
-            except RuntimeError:
-                # The connection was closed, just exit the task
-                return
-
-
-class ChannelManager:
-    def __init__(self):
-        self.channels = dict()
-        self._lock = RLock()  # Re-entrant Lock
-
-    async def on_connect(self, websocket: WebSocketType):
-        """
-        Wrap websocket connection into Channel and add to list
-
-        :param websocket: The WebSocket object to attach to the Channel
-        """
-        if isinstance(websocket, FastAPIWebSocket):
-            try:
-                await websocket.accept()
-            except RuntimeError:
-                # The connection was closed before we could accept it
-                return
-
-        ws_channel = WebSocketChannel(websocket)
-
-        with self._lock:
-            self.channels[websocket] = ws_channel
-
-        return ws_channel
-
-    async def on_disconnect(self, websocket: WebSocketType):
-        """
-        Call close on the channel if it's not, and remove from channel list
-
-        :param websocket: The WebSocket objet attached to the Channel
-        """
-        with self._lock:
-            channel = self.channels.get(websocket)
-            if channel:
-                logger.info(f"Disconnecting channel {channel}")
-                if not channel.is_closed():
-                    await channel.close()
-
-                del self.channels[websocket]
-
-    async def disconnect_all(self):
-        """
-        Disconnect all Channels
-        """
-        with self._lock:
-            for websocket in self.channels.copy().keys():
-                await self.on_disconnect(websocket)
-
-    async def broadcast(self, message: WSMessageSchemaType):
-        """
-        Broadcast a message on all Channels
-
-        :param message: The message to send
-        """
-        with self._lock:
-            for channel in self.channels.copy().values():
-                if channel.subscribed_to(message.get('type')):
-                    await self.send_direct(channel, message)
-
-    async def send_direct(
-            self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
-        """
-        Send a message directly through direct_channel only
-
-        :param direct_channel: The WebSocketChannel object to send the message through
-        :param message: The message to send
-        """
-        if not await channel.send(message):
-            await self.on_disconnect(channel.raw_websocket)
-
-    def has_channels(self):
-        """
-        Flag for more than 0 channels
-        """
-        return len(self.channels) > 0
+        if not self.is_closed():
+            # Wrap the coros into tasks if they aren't already
+            self._channel_tasks = [
+                task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
+                for task in tasks
+            ]
+
+            try:
+                return await asyncio.gather(*self._channel_tasks, **kwargs)
+            except Exception:
+                # If an exception occurred, cancel the rest of the tasks
+                await self.cancel_channel_tasks()
+
+    async def cancel_channel_tasks(self):
+        """
+        Cancel and wait on all channel tasks
+        """
+        for task in self._channel_tasks:
+            task.cancel()
+
+            # Wait for tasks to finish cancelling
+            try:
+                await task
+            except (
+                asyncio.CancelledError,
+                asyncio.TimeoutError,
+                WebSocketDisconnect,
+                ConnectionClosed,
+                RuntimeError
+            ):
+                pass
+            except Exception as e:
+                logger.info(f"Encountered unknown exception: {e}", exc_info=e)
+
+        self._channel_tasks = []
+
+    async def __aiter__(self):
+        """
+        Generator for received messages
+        """
+        # We can not catch any errors here as websocket.recv is
+        # the first to catch any disconnects and bubble it up
+        # so the connection is garbage collected right away
+        while not self.is_closed():
+            yield await self.recv()
+
+
+@asynccontextmanager
+async def create_channel(
+    websocket: WebSocketType,
+    **kwargs
+) -> AsyncIterator[WebSocketChannel]:
+    """
+    Context manager for safely opening and closing a WebSocketChannel
+    """
+    channel = WebSocketChannel(websocket, **kwargs)
+    try:
+        await channel.accept()
+        logger.info(f"Connected to channel - {channel}")
+
+        yield channel
+    finally:
+        await channel.close()
+
+        logger.info(f"Disconnected from channel - {channel}")
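For illustration, the per-message send timeout (_send_high_limit) is only recalculated once ten send times have been collected, and it is clamped between 1 and 3 seconds. A small standalone sketch with hypothetical timings:

    from collections import deque

    # Hypothetical send times (seconds) collected for one channel
    send_times = deque([0.05, 0.07, 0.06, 0.9, 1.2, 0.08, 0.05, 0.06, 0.07, 0.05], maxlen=10)

    avg_send_time = sum(send_times) / len(send_times)     # 0.259s
    send_high_limit = min(max(avg_send_time * 2, 1), 3)   # 0.518 clamped up to 1.0s
    print(avg_send_time, send_high_limit)

A consistently slow consumer thus raises its own timeout (up to 3s), and when even that is exceeded the TimeoutError from send() bubbles up and the channel is disconnected instead of buffering messages indefinitely.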

freqtrade/rpc/api_server/ws/message_stream.py

@@ -0,0 +1,31 @@
+import asyncio
+import time
+
+
+class MessageStream:
+    """
+    A message stream for consumers to subscribe to,
+    and for producers to publish to.
+    """
+    def __init__(self):
+        self._loop = asyncio.get_running_loop()
+        self._waiter = self._loop.create_future()
+
+    def publish(self, message):
+        """
+        Publish a message to this MessageStream
+
+        :param message: The message to publish
+        """
+        waiter, self._waiter = self._waiter, self._loop.create_future()
+        waiter.set_result((message, time.time(), self._waiter))
+
+    async def __aiter__(self):
+        """
+        Iterate over the messages in the message stream
+        """
+        waiter = self._waiter
+        while True:
+            # Shield the future from being cancelled by a task waiting on it
+            message, ts, waiter = await asyncio.shield(waiter)
+            yield message, ts
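A minimal sketch of how the new MessageStream fans a published message out to every subscribed consumer (standalone demo, assuming the class above is importable from its new module path):

    import asyncio

    from freqtrade.rpc.api_server.ws.message_stream import MessageStream


    async def main():
        stream = MessageStream()  # must be instantiated inside a running event loop

        async def consumer(name: str):
            # Each consumer sees every message published after it starts waiting
            async for message, ts in stream:
                print(f"{name} received {message} (published at {ts})")

        tasks = [asyncio.create_task(consumer(f"consumer-{i}")) for i in range(2)]
        await asyncio.sleep(0)  # let both consumers start waiting on the stream

        stream.publish({"type": "status", "data": "test"})
        await asyncio.sleep(0.1)

        for task in tasks:
            task.cancel()


    asyncio.run(main())

Each publish() resolves the current waiter future with the message plus a fresh future, so consumers form a linked chain of futures instead of each holding a queue copy of every message.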

freqtrade/rpc/api_server/ws/serializer.py

@@ -1,5 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
+from typing import Any, Dict, Union

 import orjson
 import rapidjson
@@ -7,6 +8,7 @@ from pandas import DataFrame
 from freqtrade.misc import dataframe_to_json, json_to_dataframe
 from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
+from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType


 logger = logging.getLogger(__name__)
@@ -24,17 +26,13 @@
     def _deserialize(self, data):
         raise NotImplementedError()

-    async def send(self, data: bytes):
+    async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
         await self._websocket.send(self._serialize(data))

     async def recv(self) -> bytes:
         data = await self._websocket.recv()

         return self._deserialize(data)

-    async def close(self, code: int = 1000):
-        await self._websocket.close(code)
-

 class HybridJSONWebSocketSerializer(WebSocketSerializer):
     def _serialize(self, data) -> str:

tests/rpc/test_rpc_apiserver.py

@@ -57,7 +57,10 @@ def botclient(default_conf, mocker):
     try:
         apiserver = ApiServer(default_conf)
         apiserver.add_rpc_handler(rpc)
-        yield ftbot, TestClient(apiserver.app)
+        # We need to use the TestClient as a context manager to
+        # handle lifespan events correctly
+        with TestClient(apiserver.app) as client:
+            yield ftbot, client
         # Cleanup ... ?
     finally:
         if apiserver:
@@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog):
     apiserver.cleanup()
     assert apiserver._server.cleanup.call_count == 1
     assert log_has("Stopping API Server", caplog)
-    assert log_has("Stopping API Server background tasks", caplog)
     ApiServer.shutdown()
@@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker):
     with client.websocket_connect(ws_url) as ws:
         ws.send_json({'type': 'subscribe', 'data': ['whitelist']})
+        time.sleep(1)

     # Check call count is now 1 as we sent a valid subscribe request
     assert sub_mock.call_count == 1

     with client.websocket_connect(ws_url) as ws:
         ws.send_json({'type': 'subscribe', 'data': 'whitelist'})
+        time.sleep(1)

     # Call count hasn't changed as the subscribe request was invalid
     assert sub_mock.call_count == 1
@@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog):
         mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api')
         apiserver = ApiServer(default_conf)
         apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf)))
-        apiserver.start_message_queue()
-        # Give the queue thread time to start
-        time.sleep(0.2)

-        # Test message_queue coro receives the message
-        test_message = {"type": "status", "data": "test"}
-        apiserver.send_msg(test_message)
-        time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data
-        assert log_has("Found message of type: status", caplog)
-
-        # Test if exception logged when error occurs in sending
-        mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast',
-                     side_effect=Exception)
-
-        apiserver.send_msg(test_message)
-        time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data
-        assert log_has_re(r"Exception happened in background task.*", caplog)
+        # Start test client context manager to run lifespan events
+        with TestClient(apiserver.app):
+            # Test message is published on the Message Stream
+            test_message = {"type": "status", "data": "test"}
+            first_waiter = apiserver._message_stream._waiter
+            apiserver.send_msg(test_message)
+            assert first_waiter.result()[0] == test_message
+
+            second_waiter = apiserver._message_stream._waiter
+            apiserver.send_msg(test_message)
+            assert first_waiter != second_waiter
+
     finally:
-        apiserver.cleanup()
         ApiServer.shutdown()