Merge pull request #7771 from wizrds/feat/refactor-ws
Refactor WebSocket API for performance
This commit is contained in:
		| @@ -81,8 +81,6 @@ async def validate_ws_token( | ||||
|     except HTTPException: | ||||
|         pass | ||||
|  | ||||
|     # No checks passed, deny the connection | ||||
|     logger.debug("Denying websocket request.") | ||||
|     # If it doesn't match, close the websocket connection | ||||
|     await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||
|  | ||||
|   | ||||
| @@ -1,16 +1,16 @@ | ||||
| import logging | ||||
| import time | ||||
| from typing import Any, Dict | ||||
|  | ||||
| from fastapi import APIRouter, Depends, WebSocketDisconnect | ||||
| from fastapi.websockets import WebSocket, WebSocketState | ||||
| from fastapi import APIRouter, Depends | ||||
| from fastapi.websockets import WebSocket | ||||
| from pydantic import ValidationError | ||||
| from websockets.exceptions import WebSocketException | ||||
|  | ||||
| from freqtrade.enums import RPCMessageType, RPCRequestType | ||||
| from freqtrade.rpc.api_server.api_auth import validate_ws_token | ||||
| from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc | ||||
| from freqtrade.rpc.api_server.ws import WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.channel import ChannelManager | ||||
| from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc | ||||
| from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
| from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, | ||||
|                                                  WSRequestSchema, WSWhitelistMessage) | ||||
| from freqtrade.rpc.rpc import RPC | ||||
| @@ -22,23 +22,35 @@ logger = logging.getLogger(__name__) | ||||
| router = APIRouter() | ||||
|  | ||||
|  | ||||
| async def is_websocket_alive(ws: WebSocket) -> bool: | ||||
| async def channel_reader(channel: WebSocketChannel, rpc: RPC): | ||||
|     """ | ||||
|     Check if a FastAPI Websocket is still open | ||||
|     Iterate over the messages from the channel and process the request | ||||
|     """ | ||||
|     if ( | ||||
|         ws.application_state == WebSocketState.CONNECTED and | ||||
|         ws.client_state == WebSocketState.CONNECTED | ||||
|     ): | ||||
|         return True | ||||
|     return False | ||||
|     async for message in channel: | ||||
|         await _process_consumer_request(message, channel, rpc) | ||||
|  | ||||
|  | ||||
| async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): | ||||
|     """ | ||||
|     Iterate over messages in the message stream and send them | ||||
|     """ | ||||
|     async for message, ts in message_stream: | ||||
|         if channel.subscribed_to(message.get('type')): | ||||
|             # Log a warning if this channel is behind | ||||
|             # on the message stream by a lot | ||||
|             if (time.time() - ts) > 60: | ||||
|                 logger.warning(f"Channel {channel} is behind MessageStream by 1 minute," | ||||
|                                " this can cause a memory leak if you see this message" | ||||
|                                " often, consider reducing pair list size or amount of" | ||||
|                                " consumers.") | ||||
|  | ||||
|             await channel.send(message, timeout=True) | ||||
|  | ||||
|  | ||||
| async def _process_consumer_request( | ||||
|     request: Dict[str, Any], | ||||
|     channel: WebSocketChannel, | ||||
|     rpc: RPC, | ||||
|     channel_manager: ChannelManager | ||||
|     rpc: RPC | ||||
| ): | ||||
|     """ | ||||
|     Validate and handle a request from a websocket consumer | ||||
| @@ -74,65 +86,29 @@ async def _process_consumer_request( | ||||
|  | ||||
|         # Format response | ||||
|         response = WSWhitelistMessage(data=whitelist) | ||||
|         # Send it back | ||||
|         await channel_manager.send_direct(channel, response.dict(exclude_none=True)) | ||||
|         await channel.send(response.dict(exclude_none=True)) | ||||
|  | ||||
|     elif type == RPCRequestType.ANALYZED_DF: | ||||
|         limit = None | ||||
|  | ||||
|         if data: | ||||
|             # Limit the amount of candles per dataframe to 'limit' or 1500 | ||||
|             limit = max(data.get('limit', 1500), 1500) | ||||
|         # Limit the amount of candles per dataframe to 'limit' or 1500 | ||||
|         limit = min(data.get('limit', 1500), 1500) if data else None | ||||
|  | ||||
|         # For every pair in the generator, send a separate message | ||||
|         for message in rpc._ws_request_analyzed_df(limit): | ||||
|             # Format response | ||||
|             response = WSAnalyzedDFMessage(data=message) | ||||
|             await channel_manager.send_direct(channel, response.dict(exclude_none=True)) | ||||
|             await channel.send(response.dict(exclude_none=True)) | ||||
|  | ||||
|  | ||||
| @router.websocket("/message/ws") | ||||
| async def message_endpoint( | ||||
|     ws: WebSocket, | ||||
|     websocket: WebSocket, | ||||
|     token: str = Depends(validate_ws_token), | ||||
|     rpc: RPC = Depends(get_rpc), | ||||
|     channel_manager=Depends(get_channel_manager), | ||||
|     token: str = Depends(validate_ws_token) | ||||
|     message_stream: MessageStream = Depends(get_message_stream) | ||||
| ): | ||||
|     """ | ||||
|     Message WebSocket endpoint, facilitates sending RPC messages | ||||
|     """ | ||||
|     try: | ||||
|         channel = await channel_manager.on_connect(ws) | ||||
|         if await is_websocket_alive(ws): | ||||
|  | ||||
|             logger.info(f"Consumer connected - {channel}") | ||||
|  | ||||
|             # Keep connection open until explicitly closed, and process requests | ||||
|             try: | ||||
|                 while not channel.is_closed(): | ||||
|                     request = await channel.recv() | ||||
|  | ||||
|                     # Process the request here | ||||
|                     await _process_consumer_request(request, channel, rpc, channel_manager) | ||||
|  | ||||
|             except (WebSocketDisconnect, WebSocketException): | ||||
|                 # Handle client disconnects | ||||
|                 logger.info(f"Consumer disconnected - {channel}") | ||||
|             except RuntimeError: | ||||
|                 # Handle cases like - | ||||
|                 # RuntimeError('Cannot call "send" once a closed message has been sent') | ||||
|                 pass | ||||
|             except Exception as e: | ||||
|                 logger.info(f"Consumer connection failed - {channel}: {e}") | ||||
|                 logger.debug(e, exc_info=e) | ||||
|  | ||||
|     except RuntimeError: | ||||
|         # WebSocket was closed | ||||
|         # Do nothing | ||||
|         pass | ||||
|     except Exception as e: | ||||
|         logger.error(f"Failed to serve - {ws.client}") | ||||
|         # Log tracebacks to keep track of what errors are happening | ||||
|         logger.exception(e) | ||||
|     finally: | ||||
|         if channel: | ||||
|             await channel_manager.on_disconnect(ws) | ||||
|     if token: | ||||
|         async with create_channel(websocket) as channel: | ||||
|             await channel.run_channel_tasks( | ||||
|                 channel_reader(channel, rpc), | ||||
|                 channel_broadcaster(channel, message_stream) | ||||
|             ) | ||||
|   | ||||
| @@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)): | ||||
|     return ApiServer._exchange | ||||
|  | ||||
|  | ||||
| def get_channel_manager(): | ||||
|     return ApiServer._ws_channel_manager | ||||
| def get_message_stream(): | ||||
|     return ApiServer._message_stream | ||||
|  | ||||
|  | ||||
| def is_webserver_mode(config=Depends(get_config)): | ||||
|   | ||||
| @@ -1,22 +1,17 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| from ipaddress import IPv4Address | ||||
| from threading import Thread | ||||
| from typing import Any, Dict, Optional | ||||
|  | ||||
| import orjson | ||||
| import uvicorn | ||||
| from fastapi import Depends, FastAPI | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| # Look into alternatives | ||||
| from janus import Queue as ThreadedQueue | ||||
| from starlette.responses import JSONResponse | ||||
|  | ||||
| from freqtrade.constants import Config | ||||
| from freqtrade.exceptions import OperationalException | ||||
| from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer | ||||
| from freqtrade.rpc.api_server.ws import ChannelManager | ||||
| from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
| from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler | ||||
|  | ||||
|  | ||||
| @@ -50,10 +45,8 @@ class ApiServer(RPCHandler): | ||||
|     _config: Config = {} | ||||
|     # Exchange - only available in webserver mode. | ||||
|     _exchange = None | ||||
|     # websocket message queue stuff | ||||
|     _ws_channel_manager: ChannelManager | ||||
|     _ws_thread = None | ||||
|     _ws_loop: Optional[asyncio.AbstractEventLoop] = None | ||||
|     # websocket message stuff | ||||
|     _message_stream: Optional[MessageStream] = None | ||||
|  | ||||
|     def __new__(cls, *args, **kwargs): | ||||
|         """ | ||||
| @@ -71,15 +64,11 @@ class ApiServer(RPCHandler): | ||||
|             return | ||||
|         self._standalone: bool = standalone | ||||
|         self._server = None | ||||
|         self._ws_queue: Optional[ThreadedQueue] = None | ||||
|         self._ws_background_task = None | ||||
|  | ||||
|         ApiServer.__initialized = True | ||||
|  | ||||
|         api_config = self._config['api_server'] | ||||
|  | ||||
|         ApiServer._ws_channel_manager = ChannelManager() | ||||
|  | ||||
|         self.app = FastAPI(title="Freqtrade API", | ||||
|                            docs_url='/docs' if api_config.get('enable_openapi', False) else None, | ||||
|                            redoc_url=None, | ||||
| @@ -105,21 +94,9 @@ class ApiServer(RPCHandler): | ||||
|         del ApiServer._rpc | ||||
|         if self._server and not self._standalone: | ||||
|             logger.info("Stopping API Server") | ||||
|             # self._server.force_exit, self._server.should_exit = True, True | ||||
|             self._server.cleanup() | ||||
|  | ||||
|         if self._ws_thread and self._ws_loop: | ||||
|             logger.info("Stopping API Server background tasks") | ||||
|  | ||||
|             if self._ws_background_task: | ||||
|                 # Cancel the queue task | ||||
|                 self._ws_background_task.cancel() | ||||
|  | ||||
|             self._ws_thread.join() | ||||
|  | ||||
|         self._ws_thread = None | ||||
|         self._ws_loop = None | ||||
|         self._ws_background_task = None | ||||
|  | ||||
|     @classmethod | ||||
|     def shutdown(cls): | ||||
|         cls.__initialized = False | ||||
| @@ -129,9 +106,11 @@ class ApiServer(RPCHandler): | ||||
|         cls._rpc = None | ||||
|  | ||||
|     def send_msg(self, msg: Dict[str, Any]) -> None: | ||||
|         if self._ws_queue: | ||||
|             sync_q = self._ws_queue.sync_q | ||||
|             sync_q.put(msg) | ||||
|         """ | ||||
|         Publish the message to the message stream | ||||
|         """ | ||||
|         if ApiServer._message_stream: | ||||
|             ApiServer._message_stream.publish(msg) | ||||
|  | ||||
|     def handle_rpc_exception(self, request, exc): | ||||
|         logger.exception(f"API Error calling: {exc}") | ||||
| @@ -170,54 +149,30 @@ class ApiServer(RPCHandler): | ||||
|         ) | ||||
|  | ||||
|         app.add_exception_handler(RPCException, self.handle_rpc_exception) | ||||
|         app.add_event_handler( | ||||
|             event_type="startup", | ||||
|             func=self._api_startup_event | ||||
|         ) | ||||
|         app.add_event_handler( | ||||
|             event_type="shutdown", | ||||
|             func=self._api_shutdown_event | ||||
|         ) | ||||
|  | ||||
|     def start_message_queue(self): | ||||
|         if self._ws_thread: | ||||
|             return | ||||
|     async def _api_startup_event(self): | ||||
|         """ | ||||
|         Creates the MessageStream class on startup | ||||
|         so it has access to the same event loop | ||||
|         as uvicorn | ||||
|         """ | ||||
|         if not ApiServer._message_stream: | ||||
|             ApiServer._message_stream = MessageStream() | ||||
|  | ||||
|         # Create a new loop, as it'll be just for the background thread | ||||
|         self._ws_loop = asyncio.new_event_loop() | ||||
|  | ||||
|         # Start the thread | ||||
|         self._ws_thread = Thread(target=self._ws_loop.run_forever) | ||||
|         self._ws_thread.start() | ||||
|  | ||||
|         # Finally, submit the coro to the thread | ||||
|         self._ws_background_task = asyncio.run_coroutine_threadsafe( | ||||
|             self._broadcast_queue_data(), loop=self._ws_loop) | ||||
|  | ||||
|     async def _broadcast_queue_data(self) -> None: | ||||
|         # Instantiate the queue in this coroutine so it's attached to our loop | ||||
|         self._ws_queue = ThreadedQueue() | ||||
|         async_queue = self._ws_queue.async_q | ||||
|  | ||||
|         try: | ||||
|             while True: | ||||
|                 logger.debug("Getting queue messages...") | ||||
|                 if (qsize := async_queue.qsize()) > 20: | ||||
|                     # If the queue becomes too big for too long, this may indicate a problem. | ||||
|                     logger.warning(f"Queue size now {qsize}") | ||||
|                 # Get data from queue | ||||
|                 message: WSMessageSchemaType = await async_queue.get() | ||||
|                 logger.debug(f"Found message of type: {message.get('type')}") | ||||
|                 async_queue.task_done() | ||||
|                 # Broadcast it | ||||
|                 await self._ws_channel_manager.broadcast(message) | ||||
|         except asyncio.CancelledError: | ||||
|             pass | ||||
|  | ||||
|         # For testing, shouldn't happen when stable | ||||
|         except Exception as e: | ||||
|             logger.exception(f"Exception happened in background task: {e}") | ||||
|  | ||||
|         finally: | ||||
|             # Disconnect channels and stop the loop on cancel | ||||
|             await self._ws_channel_manager.disconnect_all() | ||||
|             if self._ws_loop: | ||||
|                 self._ws_loop.stop() | ||||
|             # Avoid adding more items to the queue if they aren't | ||||
|             # going to get broadcasted. | ||||
|             self._ws_queue = None | ||||
|     async def _api_shutdown_event(self): | ||||
|         """ | ||||
|         Removes the MessageStream class on shutdown | ||||
|         """ | ||||
|         if ApiServer._message_stream: | ||||
|             ApiServer._message_stream = None | ||||
|  | ||||
|     def start_api(self): | ||||
|         """ | ||||
| @@ -257,7 +212,6 @@ class ApiServer(RPCHandler): | ||||
|             if self._standalone: | ||||
|                 self._server.run() | ||||
|             else: | ||||
|                 self.start_message_queue() | ||||
|                 self._server.run_in_thread() | ||||
|         except Exception: | ||||
|             logger.exception("Api server failed to start.") | ||||
|   | ||||
| @@ -3,4 +3,5 @@ | ||||
| from freqtrade.rpc.api_server.ws.types import WebSocketType | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer | ||||
| from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
|   | ||||
| @@ -1,11 +1,13 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| import time | ||||
| from threading import RLock | ||||
| from typing import Any, Dict, List, Optional, Type, Union | ||||
| from collections import deque | ||||
| from contextlib import asynccontextmanager | ||||
| from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from fastapi import WebSocket as FastAPIWebSocket | ||||
| from fastapi import WebSocketDisconnect | ||||
| from websockets.exceptions import ConnectionClosed | ||||
|  | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, | ||||
| @@ -21,31 +23,27 @@ class WebSocketChannel: | ||||
|     """ | ||||
|     Object to help facilitate managing a websocket connection | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         websocket: WebSocketType, | ||||
|         channel_id: Optional[str] = None, | ||||
|         drain_timeout: int = 3, | ||||
|         throttle: float = 0.01, | ||||
|         serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer | ||||
|     ): | ||||
|  | ||||
|         self.channel_id = channel_id if channel_id else uuid4().hex[:8] | ||||
|  | ||||
|         # The WebSocket object | ||||
|         self._websocket = WebSocketProxy(websocket) | ||||
|  | ||||
|         self.drain_timeout = drain_timeout | ||||
|         self.throttle = throttle | ||||
|  | ||||
|         self._subscriptions: List[str] = [] | ||||
|         # 32 is the size of the receiving queue in websockets package | ||||
|         self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) | ||||
|         self._relay_task = asyncio.create_task(self.relay()) | ||||
|  | ||||
|         # Internal event to signify a closed websocket | ||||
|         self._closed = asyncio.Event() | ||||
|         # The async tasks created for the channel | ||||
|         self._channel_tasks: List[asyncio.Task] = [] | ||||
|  | ||||
|         # Deque for average send times | ||||
|         self._send_times: Deque[float] = deque([], maxlen=10) | ||||
|         # High limit defaults to 3 to start | ||||
|         self._send_high_limit = 3 | ||||
|  | ||||
|         # The subscribed message types | ||||
|         self._subscriptions: List[str] = [] | ||||
|  | ||||
|         # Wrap the WebSocket in the Serializing class | ||||
|         self._wrapped_ws = serializer_cls(self._websocket) | ||||
| @@ -61,43 +59,58 @@ class WebSocketChannel: | ||||
|     def remote_addr(self): | ||||
|         return self._websocket.remote_addr | ||||
|  | ||||
|     async def _send(self, data): | ||||
|         """ | ||||
|         Send data on the wrapped websocket | ||||
|         """ | ||||
|         await self._wrapped_ws.send(data) | ||||
|     @property | ||||
|     def avg_send_time(self): | ||||
|         return sum(self._send_times) / len(self._send_times) | ||||
|  | ||||
|     async def send(self, data) -> bool: | ||||
|     def _calc_send_limit(self): | ||||
|         """ | ||||
|         Add the data to the queue to be sent. | ||||
|         :returns: True if data added to queue, False otherwise | ||||
|         Calculate the send high limit for this channel | ||||
|         """ | ||||
|  | ||||
|         # This block only runs if the queue is full, it will wait | ||||
|         # until self.drain_timeout for the relay to drain the outgoing queue | ||||
|         # We can't use asyncio.wait_for here because the queue may have been created with a | ||||
|         # different eventloop | ||||
|         if not self.is_closed(): | ||||
|             start = time.time() | ||||
|             while self.queue.full(): | ||||
|                 await asyncio.sleep(1) | ||||
|                 if (time.time() - start) > self.drain_timeout: | ||||
|                     return False | ||||
|         # Only update if we have enough data | ||||
|         if len(self._send_times) == self._send_times.maxlen: | ||||
|             # At least 1s or twice the average of send times, with a | ||||
|             # maximum of 3 seconds per message | ||||
|             self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3) | ||||
|  | ||||
|             # If for some reason the queue is still full, just return False | ||||
|             try: | ||||
|                 self.queue.put_nowait(data) | ||||
|             except asyncio.QueueFull: | ||||
|                 return False | ||||
|     async def send( | ||||
|         self, | ||||
|         message: Union[WSMessageSchemaType, Dict[str, Any]], | ||||
|         timeout: bool = False | ||||
|     ): | ||||
|         """ | ||||
|         Send a message on the wrapped websocket. If the sending | ||||
|         takes too long, it will raise a TimeoutError and | ||||
|         disconnect the connection. | ||||
|  | ||||
|             # If we got here everything is ok | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|         :param message: The message to send | ||||
|         :param timeout: Enforce send high limit, defaults to False | ||||
|         """ | ||||
|         try: | ||||
|             _ = time.time() | ||||
|             # If the send times out, it will raise | ||||
|             # a TimeoutError and bubble up to the | ||||
|             # message_endpoint to close the connection | ||||
|             await asyncio.wait_for( | ||||
|                 self._wrapped_ws.send(message), | ||||
|                 timeout=self._send_high_limit if timeout else None | ||||
|             ) | ||||
|             total_time = time.time() - _ | ||||
|             self._send_times.append(total_time) | ||||
|  | ||||
|             self._calc_send_limit() | ||||
|         except asyncio.TimeoutError: | ||||
|             logger.info(f"Connection for {self} timed out, disconnecting") | ||||
|             raise | ||||
|  | ||||
|         # Explicitly give control back to event loop as | ||||
|         # websockets.send does not | ||||
|         await asyncio.sleep(0.01) | ||||
|  | ||||
|     async def recv(self): | ||||
|         """ | ||||
|         Receive data on the wrapped websocket | ||||
|         Receive a message on the wrapped websocket | ||||
|         """ | ||||
|         return await self._wrapped_ws.recv() | ||||
|  | ||||
| @@ -107,17 +120,27 @@ class WebSocketChannel: | ||||
|         """ | ||||
|         return await self._websocket.ping() | ||||
|  | ||||
|     async def accept(self): | ||||
|         """ | ||||
|         Accept the underlying websocket connection, | ||||
|         if the connection has been closed before we can | ||||
|         accept, just close the channel. | ||||
|         """ | ||||
|         try: | ||||
|             return await self._websocket.accept() | ||||
|         except RuntimeError: | ||||
|             await self.close() | ||||
|  | ||||
|     async def close(self): | ||||
|         """ | ||||
|         Close the WebSocketChannel | ||||
|         """ | ||||
|  | ||||
|         self._closed.set() | ||||
|         self._relay_task.cancel() | ||||
|  | ||||
|         try: | ||||
|             await self.raw_websocket.close() | ||||
|         except Exception: | ||||
|             await self._websocket.close() | ||||
|         except RuntimeError: | ||||
|             pass | ||||
|  | ||||
|     def is_closed(self) -> bool: | ||||
| @@ -142,99 +165,76 @@ class WebSocketChannel: | ||||
|         """ | ||||
|         return message_type in self._subscriptions | ||||
|  | ||||
|     async def relay(self): | ||||
|     async def run_channel_tasks(self, *tasks, **kwargs): | ||||
|         """ | ||||
|         Relay messages from the channel's queue and send them out. This is started | ||||
|         as a task. | ||||
|         Create and await on the channel tasks unless an exception | ||||
|         was raised, then cancel them all. | ||||
|  | ||||
|         :params *tasks: All coros or tasks to be run concurrently | ||||
|         :param **kwargs: Any extra kwargs to pass to gather | ||||
|         """ | ||||
|         while not self._closed.is_set(): | ||||
|             message = await self.queue.get() | ||||
|  | ||||
|         if not self.is_closed(): | ||||
|             # Wrap the coros into tasks if they aren't already | ||||
|             self._channel_tasks = [ | ||||
|                 task if isinstance(task, asyncio.Task) else asyncio.create_task(task) | ||||
|                 for task in tasks | ||||
|             ] | ||||
|  | ||||
|             try: | ||||
|                 await self._send(message) | ||||
|                 self.queue.task_done() | ||||
|                 return await asyncio.gather(*self._channel_tasks, **kwargs) | ||||
|             except Exception: | ||||
|                 # If an exception occurred, cancel the rest of the tasks | ||||
|                 await self.cancel_channel_tasks() | ||||
|  | ||||
|                 # Limit messages per sec. | ||||
|                 # Could cause problems with queue size if too low, and | ||||
|                 # problems with network traffik if too high. | ||||
|                 # 0.01 = 100/s | ||||
|                 await asyncio.sleep(self.throttle) | ||||
|             except RuntimeError: | ||||
|                 # The connection was closed, just exit the task | ||||
|                 return | ||||
|  | ||||
|  | ||||
| class ChannelManager: | ||||
|     def __init__(self): | ||||
|         self.channels = dict() | ||||
|         self._lock = RLock()  # Re-entrant Lock | ||||
|  | ||||
|     async def on_connect(self, websocket: WebSocketType): | ||||
|     async def cancel_channel_tasks(self): | ||||
|         """ | ||||
|         Wrap websocket connection into Channel and add to list | ||||
|  | ||||
|         :param websocket: The WebSocket object to attach to the Channel | ||||
|         Cancel and wait on all channel tasks | ||||
|         """ | ||||
|         if isinstance(websocket, FastAPIWebSocket): | ||||
|         for task in self._channel_tasks: | ||||
|             task.cancel() | ||||
|  | ||||
|             # Wait for tasks to finish cancelling | ||||
|             try: | ||||
|                 await websocket.accept() | ||||
|             except RuntimeError: | ||||
|                 # The connection was closed before we could accept it | ||||
|                 return | ||||
|                 await task | ||||
|             except ( | ||||
|                 asyncio.CancelledError, | ||||
|                 asyncio.TimeoutError, | ||||
|                 WebSocketDisconnect, | ||||
|                 ConnectionClosed, | ||||
|                 RuntimeError | ||||
|             ): | ||||
|                 pass | ||||
|             except Exception as e: | ||||
|                 logger.info(f"Encountered unknown exception: {e}", exc_info=e) | ||||
|  | ||||
|         ws_channel = WebSocketChannel(websocket) | ||||
|         self._channel_tasks = [] | ||||
|  | ||||
|         with self._lock: | ||||
|             self.channels[websocket] = ws_channel | ||||
|  | ||||
|         return ws_channel | ||||
|  | ||||
|     async def on_disconnect(self, websocket: WebSocketType): | ||||
|     async def __aiter__(self): | ||||
|         """ | ||||
|         Call close on the channel if it's not, and remove from channel list | ||||
|         Generator for received messages | ||||
|         """ | ||||
|         # We can not catch any errors here as websocket.recv is | ||||
|         # the first to catch any disconnects and bubble it up | ||||
|         # so the connection is garbage collected right away | ||||
|         while not self.is_closed(): | ||||
|             yield await self.recv() | ||||
|  | ||||
|         :param websocket: The WebSocket objet attached to the Channel | ||||
|         """ | ||||
|         with self._lock: | ||||
|             channel = self.channels.get(websocket) | ||||
|             if channel: | ||||
|                 logger.info(f"Disconnecting channel {channel}") | ||||
|                 if not channel.is_closed(): | ||||
|                     await channel.close() | ||||
|  | ||||
|                 del self.channels[websocket] | ||||
| @asynccontextmanager | ||||
| async def create_channel( | ||||
|     websocket: WebSocketType, | ||||
|     **kwargs | ||||
| ) -> AsyncIterator[WebSocketChannel]: | ||||
|     """ | ||||
|     Context manager for safely opening and closing a WebSocketChannel | ||||
|     """ | ||||
|     channel = WebSocketChannel(websocket, **kwargs) | ||||
|     try: | ||||
|         await channel.accept() | ||||
|         logger.info(f"Connected to channel - {channel}") | ||||
|  | ||||
|     async def disconnect_all(self): | ||||
|         """ | ||||
|         Disconnect all Channels | ||||
|         """ | ||||
|         with self._lock: | ||||
|             for websocket in self.channels.copy().keys(): | ||||
|                 await self.on_disconnect(websocket) | ||||
|  | ||||
|     async def broadcast(self, message: WSMessageSchemaType): | ||||
|         """ | ||||
|         Broadcast a message on all Channels | ||||
|  | ||||
|         :param message: The message to send | ||||
|         """ | ||||
|         with self._lock: | ||||
|             for channel in self.channels.copy().values(): | ||||
|                 if channel.subscribed_to(message.get('type')): | ||||
|                     await self.send_direct(channel, message) | ||||
|  | ||||
|     async def send_direct( | ||||
|             self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
|         """ | ||||
|         Send a message directly through direct_channel only | ||||
|  | ||||
|         :param direct_channel: The WebSocketChannel object to send the message through | ||||
|         :param message: The message to send | ||||
|         """ | ||||
|         if not await channel.send(message): | ||||
|             await self.on_disconnect(channel.raw_websocket) | ||||
|  | ||||
|     def has_channels(self): | ||||
|         """ | ||||
|         Flag for more than 0 channels | ||||
|         """ | ||||
|         return len(self.channels) > 0 | ||||
|         yield channel | ||||
|     finally: | ||||
|         await channel.close() | ||||
|         logger.info(f"Disconnected from channel - {channel}") | ||||
|   | ||||
							
								
								
									
										31
									
								
								freqtrade/rpc/api_server/ws/message_stream.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								freqtrade/rpc/api_server/ws/message_stream.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| import asyncio | ||||
| import time | ||||
|  | ||||
|  | ||||
| class MessageStream: | ||||
|     """ | ||||
|     A message stream for consumers to subscribe to, | ||||
|     and for producers to publish to. | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         self._loop = asyncio.get_running_loop() | ||||
|         self._waiter = self._loop.create_future() | ||||
|  | ||||
|     def publish(self, message): | ||||
|         """ | ||||
|         Publish a message to this MessageStream | ||||
|  | ||||
|         :param message: The message to publish | ||||
|         """ | ||||
|         waiter, self._waiter = self._waiter, self._loop.create_future() | ||||
|         waiter.set_result((message, time.time(), self._waiter)) | ||||
|  | ||||
|     async def __aiter__(self): | ||||
|         """ | ||||
|         Iterate over the messages in the message stream | ||||
|         """ | ||||
|         waiter = self._waiter | ||||
|         while True: | ||||
|             # Shield the future from being cancelled by a task waiting on it | ||||
|             message, ts, waiter = await asyncio.shield(waiter) | ||||
|             yield message, ts | ||||
| @@ -1,5 +1,6 @@ | ||||
| import logging | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Any, Dict, Union | ||||
|  | ||||
| import orjson | ||||
| import rapidjson | ||||
| @@ -7,6 +8,7 @@ from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.misc import dataframe_to_json, json_to_dataframe | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| @@ -24,17 +26,13 @@ class WebSocketSerializer(ABC): | ||||
|     def _deserialize(self, data): | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def send(self, data: bytes): | ||||
|     async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
|         await self._websocket.send(self._serialize(data)) | ||||
|  | ||||
|     async def recv(self) -> bytes: | ||||
|         data = await self._websocket.recv() | ||||
|  | ||||
|         return self._deserialize(data) | ||||
|  | ||||
|     async def close(self, code: int = 1000): | ||||
|         await self._websocket.close(code) | ||||
|  | ||||
|  | ||||
| class HybridJSONWebSocketSerializer(WebSocketSerializer): | ||||
|     def _serialize(self, data) -> str: | ||||
|   | ||||
| @@ -57,7 +57,10 @@ def botclient(default_conf, mocker): | ||||
|     try: | ||||
|         apiserver = ApiServer(default_conf) | ||||
|         apiserver.add_rpc_handler(rpc) | ||||
|         yield ftbot, TestClient(apiserver.app) | ||||
|         # We need to use the TestClient as a context manager to | ||||
|         # handle lifespan events correctly | ||||
|         with TestClient(apiserver.app) as client: | ||||
|             yield ftbot, client | ||||
|         # Cleanup ... ? | ||||
|     finally: | ||||
|         if apiserver: | ||||
| @@ -438,7 +441,6 @@ def test_api_cleanup(default_conf, mocker, caplog): | ||||
|     apiserver.cleanup() | ||||
|     assert apiserver._server.cleanup.call_count == 1 | ||||
|     assert log_has("Stopping API Server", caplog) | ||||
|     assert log_has("Stopping API Server background tasks", caplog) | ||||
|     ApiServer.shutdown() | ||||
|  | ||||
|  | ||||
| @@ -1714,12 +1716,14 @@ def test_api_ws_subscribe(botclient, mocker): | ||||
|  | ||||
|     with client.websocket_connect(ws_url) as ws: | ||||
|         ws.send_json({'type': 'subscribe', 'data': ['whitelist']}) | ||||
|         time.sleep(1) | ||||
|  | ||||
|     # Check call count is now 1 as we sent a valid subscribe request | ||||
|     assert sub_mock.call_count == 1 | ||||
|  | ||||
|     with client.websocket_connect(ws_url) as ws: | ||||
|         ws.send_json({'type': 'subscribe', 'data': 'whitelist'}) | ||||
|         time.sleep(1) | ||||
|  | ||||
|     # Call count hasn't changed as the subscribe request was invalid | ||||
|     assert sub_mock.call_count == 1 | ||||
| @@ -1773,24 +1777,18 @@ def test_api_ws_send_msg(default_conf, mocker, caplog): | ||||
|         mocker.patch('freqtrade.rpc.api_server.ApiServer.start_api') | ||||
|         apiserver = ApiServer(default_conf) | ||||
|         apiserver.add_rpc_handler(RPC(get_patched_freqtradebot(mocker, default_conf))) | ||||
|         apiserver.start_message_queue() | ||||
|         # Give the queue thread time to start | ||||
|         time.sleep(0.2) | ||||
|  | ||||
|         # Test message_queue coro receives the message | ||||
|         test_message = {"type": "status", "data": "test"} | ||||
|         apiserver.send_msg(test_message) | ||||
|         time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data | ||||
|         assert log_has("Found message of type: status", caplog) | ||||
|         # Start test client context manager to run lifespan events | ||||
|         with TestClient(apiserver.app): | ||||
|             # Test message is published on the Message Stream | ||||
|             test_message = {"type": "status", "data": "test"} | ||||
|             first_waiter = apiserver._message_stream._waiter | ||||
|             apiserver.send_msg(test_message) | ||||
|             assert first_waiter.result()[0] == test_message | ||||
|  | ||||
|         # Test if exception logged when error occurs in sending | ||||
|         mocker.patch('freqtrade.rpc.api_server.ws.channel.ChannelManager.broadcast', | ||||
|                      side_effect=Exception) | ||||
|  | ||||
|         apiserver.send_msg(test_message) | ||||
|         time.sleep(0.1)  # Not sure how else to wait for the coro to receive the data | ||||
|         assert log_has_re(r"Exception happened in background task.*", caplog) | ||||
|             second_waiter = apiserver._message_stream._waiter | ||||
|             apiserver.send_msg(test_message) | ||||
|             assert first_waiter != second_waiter | ||||
|  | ||||
|     finally: | ||||
|         apiserver.cleanup() | ||||
|         ApiServer.shutdown() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user