diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 785773b39..a9b88aadb 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,16 +1,17 @@ +import asyncio import logging from typing import Any, Dict -from fastapi import APIRouter, Depends, WebSocketDisconnect -from fastapi.websockets import WebSocket, WebSocketState +from fastapi import APIRouter, Depends +from fastapi.websockets import WebSocket, WebSocketDisconnect from pydantic import ValidationError -from websockets.exceptions import WebSocketException +from websockets.exceptions import ConnectionClosed from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token -from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc +from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.ws import WebSocketChannel -from freqtrade.rpc.api_server.ws.channel import ChannelManager +from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC @@ -22,23 +23,63 @@ logger = logging.getLogger(__name__) router = APIRouter() -async def is_websocket_alive(ws: WebSocket) -> bool: +# async def is_websocket_alive(ws: WebSocket) -> bool: +# """ +# Check if a FastAPI Websocket is still open +# """ +# if ( +# ws.application_state == WebSocketState.CONNECTED and +# ws.client_state == WebSocketState.CONNECTED +# ): +# return True +# return False + + +class WebSocketChannelClosed(Exception): """ - Check if a FastAPI Websocket is still open + General WebSocket exception to signal closing the channel """ - if ( - ws.application_state == WebSocketState.CONNECTED and - ws.client_state == WebSocketState.CONNECTED + pass + + +async def channel_reader(channel: WebSocketChannel, rpc: RPC): + """ + Iterate over the messages from the channel and process the request + """ + try: + async for message in channel: + await _process_consumer_request(message, channel, rpc) + except ( + RuntimeError, + WebSocketDisconnect, + ConnectionClosed ): - return True - return False + raise WebSocketChannelClosed + except asyncio.CancelledError: + return + + +async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): + """ + Iterate over messages in the message stream and send them + """ + try: + async for message in message_stream: + await channel.send(message) + except ( + RuntimeError, + WebSocketDisconnect, + ConnectionClosed + ): + raise WebSocketChannelClosed + except asyncio.CancelledError: + return async def _process_consumer_request( request: Dict[str, Any], channel: WebSocketChannel, - rpc: RPC, - channel_manager: ChannelManager + rpc: RPC ): """ Validate and handle a request from a websocket consumer @@ -75,7 +116,7 @@ async def _process_consumer_request( # Format response response = WSWhitelistMessage(data=whitelist) # Send it back - await channel_manager.send_direct(channel, response.dict(exclude_none=True)) + await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: limit = None @@ -86,53 +127,76 @@ async def _process_consumer_request( # For every pair in the generator, send a separate message for message in rpc._ws_request_analyzed_df(limit): + # Format response response = WSAnalyzedDFMessage(data=message) - await channel_manager.send_direct(channel, response.dict(exclude_none=True)) + await channel.send(response.dict(exclude_none=True)) @router.websocket("/message/ws") async def message_endpoint( - ws: WebSocket, + websocket: WebSocket, + token: str = Depends(validate_ws_token), rpc: RPC = Depends(get_rpc), - channel_manager=Depends(get_channel_manager), - token: str = Depends(validate_ws_token) + message_stream: MessageStream = Depends(get_message_stream) ): - """ - Message WebSocket endpoint, facilitates sending RPC messages - """ - try: - channel = await channel_manager.on_connect(ws) - if await is_websocket_alive(ws): + async with WebSocketChannel(websocket).connect() as channel: + try: + logger.info(f"Channel connected - {channel}") - logger.info(f"Consumer connected - {channel}") + channel_tasks = asyncio.gather( + channel_reader(channel, rpc), + channel_broadcaster(channel, message_stream) + ) + await channel_tasks - # Keep connection open until explicitly closed, and process requests - try: - while not channel.is_closed(): - request = await channel.recv() + finally: + logger.info(f"Channel disconnected - {channel}") + channel_tasks.cancel() - # Process the request here - await _process_consumer_request(request, channel, rpc, channel_manager) - except (WebSocketDisconnect, WebSocketException): - # Handle client disconnects - logger.info(f"Consumer disconnected - {channel}") - except RuntimeError: - # Handle cases like - - # RuntimeError('Cannot call "send" once a closed message has been sent') - pass - except Exception as e: - logger.info(f"Consumer connection failed - {channel}: {e}") - logger.debug(e, exc_info=e) +# @router.websocket("/message/ws") +# async def message_endpoint( +# ws: WebSocket, +# rpc: RPC = Depends(get_rpc), +# channel_manager=Depends(get_channel_manager), +# token: str = Depends(validate_ws_token) +# ): +# """ +# Message WebSocket endpoint, facilitates sending RPC messages +# """ +# try: +# channel = await channel_manager.on_connect(ws) +# if await is_websocket_alive(ws): - except RuntimeError: - # WebSocket was closed - # Do nothing - pass - except Exception as e: - logger.error(f"Failed to serve - {ws.client}") - # Log tracebacks to keep track of what errors are happening - logger.exception(e) - finally: - if channel: - await channel_manager.on_disconnect(ws) +# logger.info(f"Consumer connected - {channel}") + +# # Keep connection open until explicitly closed, and process requests +# try: +# while not channel.is_closed(): +# request = await channel.recv() + +# # Process the request here +# await _process_consumer_request(request, channel, rpc, channel_manager) + +# except (WebSocketDisconnect, WebSocketException): +# # Handle client disconnects +# logger.info(f"Consumer disconnected - {channel}") +# except RuntimeError: +# # Handle cases like - +# # RuntimeError('Cannot call "send" once a closed message has been sent') +# pass +# except Exception as e: +# logger.info(f"Consumer connection failed - {channel}: {e}") +# logger.debug(e, exc_info=e) + +# except RuntimeError: +# # WebSocket was closed +# # Do nothing +# pass +# except Exception as e: +# logger.error(f"Failed to serve - {ws.client}") +# # Log tracebacks to keep track of what errors are happening +# logger.exception(e) +# finally: +# if channel: +# await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/deps.py b/freqtrade/rpc/api_server/deps.py index abd3db036..aed97367b 100644 --- a/freqtrade/rpc/api_server/deps.py +++ b/freqtrade/rpc/api_server/deps.py @@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)): return ApiServer._exchange -def get_channel_manager(): - return ApiServer._ws_channel_manager +def get_message_stream(): + return ApiServer._message_stream def is_webserver_mode(config=Depends(get_config)): diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index e9a12e4df..7e2c3f39f 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -1,7 +1,6 @@ import asyncio import logging from ipaddress import IPv4Address -from threading import Thread from typing import Any, Dict import orjson @@ -15,7 +14,7 @@ from starlette.responses import JSONResponse from freqtrade.constants import Config from freqtrade.exceptions import OperationalException from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer -from freqtrade.rpc.api_server.ws import ChannelManager +from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler @@ -51,9 +50,10 @@ class ApiServer(RPCHandler): # Exchange - only available in webserver mode. _exchange = None # websocket message queue stuff - _ws_channel_manager = None - _ws_thread = None - _ws_loop = None + # _ws_channel_manager = None + # _ws_thread = None + # _ws_loop = None + _message_stream = None def __new__(cls, *args, **kwargs): """ @@ -71,14 +71,15 @@ class ApiServer(RPCHandler): return self._standalone: bool = standalone self._server = None + self._ws_queue = None - self._ws_background_task = None + self._ws_publisher_task = None ApiServer.__initialized = True api_config = self._config['api_server'] - ApiServer._ws_channel_manager = ChannelManager() + # ApiServer._ws_channel_manager = ChannelManager() self.app = FastAPI(title="Freqtrade API", docs_url='/docs' if api_config.get('enable_openapi', False) else None, @@ -107,18 +108,18 @@ class ApiServer(RPCHandler): logger.info("Stopping API Server") self._server.cleanup() - if self._ws_thread and self._ws_loop: - logger.info("Stopping API Server background tasks") + # if self._ws_thread and self._ws_loop: + # logger.info("Stopping API Server background tasks") - if self._ws_background_task: - # Cancel the queue task - self._ws_background_task.cancel() + # if self._ws_background_task: + # # Cancel the queue task + # self._ws_background_task.cancel() - self._ws_thread.join() + # self._ws_thread.join() - self._ws_thread = None - self._ws_loop = None - self._ws_background_task = None + # self._ws_thread = None + # self._ws_loop = None + # self._ws_background_task = None @classmethod def shutdown(cls): @@ -170,51 +171,102 @@ class ApiServer(RPCHandler): ) app.add_exception_handler(RPCException, self.handle_rpc_exception) + app.add_event_handler( + event_type="startup", + func=self._api_startup_event + ) + app.add_event_handler( + event_type="shutdown", + func=self._api_shutdown_event + ) - def start_message_queue(self): - if self._ws_thread: - return + async def _api_startup_event(self): + if not ApiServer._message_stream: + ApiServer._message_stream = MessageStream() - # Create a new loop, as it'll be just for the background thread - self._ws_loop = asyncio.new_event_loop() + if not self._ws_queue: + self._ws_queue = ThreadedQueue() - # Start the thread - self._ws_thread = Thread(target=self._ws_loop.run_forever) - self._ws_thread.start() + if not self._ws_publisher_task: + self._ws_publisher_task = asyncio.create_task( + self._publish_messages() + ) - # Finally, submit the coro to the thread - self._ws_background_task = asyncio.run_coroutine_threadsafe( - self._broadcast_queue_data(), loop=self._ws_loop) + async def _api_shutdown_event(self): + if ApiServer._message_stream: + ApiServer._message_stream = None - async def _broadcast_queue_data(self): - # Instantiate the queue in this coroutine so it's attached to our loop - self._ws_queue = ThreadedQueue() - async_queue = self._ws_queue.async_q - - try: - while True: - logger.debug("Getting queue messages...") - # Get data from queue - message: WSMessageSchemaType = await async_queue.get() - logger.debug(f"Found message of type: {message.get('type')}") - async_queue.task_done() - # Broadcast it - await self._ws_channel_manager.broadcast(message) - except asyncio.CancelledError: - pass - - # For testing, shouldn't happen when stable - except Exception as e: - logger.exception(f"Exception happened in background task: {e}") - - finally: - # Disconnect channels and stop the loop on cancel - await self._ws_channel_manager.disconnect_all() - self._ws_loop.stop() - # Avoid adding more items to the queue if they aren't - # going to get broadcasted. + if self._ws_queue: self._ws_queue = None + if self._ws_publisher_task: + self._ws_publisher_task.cancel() + + async def _publish_messages(self): + """ + Background task that reads messages from the queue and adds them + to the message stream + """ + try: + async_queue = self._ws_queue.async_q + message_stream = ApiServer._message_stream + + while message_stream: + message: WSMessageSchemaType = await async_queue.get() + message_stream.publish(message) + + # Make sure to throttle how fast we + # publish messages as some clients will be + # slower than others + await asyncio.sleep(0.01) + async_queue.task_done() + finally: + self._ws_queue = None + + # def start_message_queue(self): + # if self._ws_thread: + # return + + # # Create a new loop, as it'll be just for the background thread + # self._ws_loop = asyncio.new_event_loop() + + # # Start the thread + # self._ws_thread = Thread(target=self._ws_loop.run_forever) + # self._ws_thread.start() + + # # Finally, submit the coro to the thread + # self._ws_background_task = asyncio.run_coroutine_threadsafe( + # self._broadcast_queue_data(), loop=self._ws_loop) + + # async def _broadcast_queue_data(self): + # # Instantiate the queue in this coroutine so it's attached to our loop + # self._ws_queue = ThreadedQueue() + # async_queue = self._ws_queue.async_q + + # try: + # while True: + # logger.debug("Getting queue messages...") + # # Get data from queue + # message: WSMessageSchemaType = await async_queue.get() + # logger.debug(f"Found message of type: {message.get('type')}") + # async_queue.task_done() + # # Broadcast it + # await self._ws_channel_manager.broadcast(message) + # except asyncio.CancelledError: + # pass + + # # For testing, shouldn't happen when stable + # except Exception as e: + # logger.exception(f"Exception happened in background task: {e}") + + # finally: + # # Disconnect channels and stop the loop on cancel + # await self._ws_channel_manager.disconnect_all() + # self._ws_loop.stop() + # # Avoid adding more items to the queue if they aren't + # # going to get broadcasted. + # self._ws_queue = None + def start_api(self): """ Start API ... should be run in thread. @@ -253,7 +305,7 @@ class ApiServer(RPCHandler): if self._standalone: self._server.run() else: - self.start_message_queue() + # self.start_message_queue() self._server.run_in_thread() except Exception: logger.exception("Api server failed to start.") diff --git a/freqtrade/rpc/api_server/ws/__init__.py b/freqtrade/rpc/api_server/ws/__init__.py index 055b20a9d..0b94d3fee 100644 --- a/freqtrade/rpc/api_server/ws/__init__.py +++ b/freqtrade/rpc/api_server/ws/__init__.py @@ -3,4 +3,5 @@ from freqtrade.rpc.api_server.ws.types import WebSocketType from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer -from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel +from freqtrade.rpc.api_server.ws.channel import WebSocketChannel +from freqtrade.rpc.api_server.ws.message_stream import MessageStream diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 88b4db9ba..b98bd13c9 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,12 +1,9 @@ import asyncio import logging -import time -from threading import RLock +from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 -from fastapi import WebSocket as FastAPIWebSocket - from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) @@ -21,32 +18,21 @@ class WebSocketChannel: """ Object to help facilitate managing a websocket connection """ - def __init__( self, websocket: WebSocketType, channel_id: Optional[str] = None, - drain_timeout: int = 3, - throttle: float = 0.01, serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer ): - self.channel_id = channel_id if channel_id else uuid4().hex[:8] - - # The WebSocket object self._websocket = WebSocketProxy(websocket) - self.drain_timeout = drain_timeout - self.throttle = throttle - - self._subscriptions: List[str] = [] - # 32 is the size of the receiving queue in websockets package - self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) - self._relay_task = asyncio.create_task(self.relay()) - # Internal event to signify a closed websocket self._closed = asyncio.Event() + # Throttle how fast we send messages + self._throttle = 0.01 + # Wrap the WebSocket in the Serializing class self._wrapped_ws = serializer_cls(self._websocket) @@ -61,40 +47,16 @@ class WebSocketChannel: def remote_addr(self): return self._websocket.remote_addr - async def _send(self, data): + async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]): """ - Send data on the wrapped websocket + Send a message on the wrapped websocket """ - await self._wrapped_ws.send(data) - - async def send(self, data) -> bool: - """ - Add the data to the queue to be sent. - :returns: True if data added to queue, False otherwise - """ - - # This block only runs if the queue is full, it will wait - # until self.drain_timeout for the relay to drain the outgoing queue - # We can't use asyncio.wait_for here because the queue may have been created with a - # different eventloop - start = time.time() - while self.queue.full(): - await asyncio.sleep(1) - if (time.time() - start) > self.drain_timeout: - return False - - # If for some reason the queue is still full, just return False - try: - self.queue.put_nowait(data) - except asyncio.QueueFull: - return False - - # If we got here everything is ok - return True + await asyncio.sleep(self._throttle) + await self._wrapped_ws.send(message) async def recv(self): """ - Receive data on the wrapped websocket + Receive a message on the wrapped websocket """ return await self._wrapped_ws.recv() @@ -104,18 +66,23 @@ class WebSocketChannel: """ return await self._websocket.ping() + async def accept(self): + """ + Accept the underlying websocket connection + """ + return await self._websocket.accept() + async def close(self): """ Close the WebSocketChannel """ try: - await self.raw_websocket.close() + await self._websocket.close() except Exception: pass self._closed.set() - self._relay_task.cancel() def is_closed(self) -> bool: """ @@ -139,99 +106,243 @@ class WebSocketChannel: """ return message_type in self._subscriptions - async def relay(self): + async def __aiter__(self): """ - Relay messages from the channel's queue and send them out. This is started - as a task. + Generator for received messages """ - while not self._closed.is_set(): - message = await self.queue.get() + while True: try: - await self._send(message) - self.queue.task_done() + yield await self.recv() + except Exception: + break - # Limit messages per sec. - # Could cause problems with queue size if too low, and - # problems with network traffik if too high. - # 0.01 = 100/s - await asyncio.sleep(self.throttle) - except RuntimeError: - # The connection was closed, just exit the task - return - - -class ChannelManager: - def __init__(self): - self.channels = dict() - self._lock = RLock() # Re-entrant Lock - - async def on_connect(self, websocket: WebSocketType): + @asynccontextmanager + async def connect(self): """ - Wrap websocket connection into Channel and add to list - - :param websocket: The WebSocket object to attach to the Channel + Context manager for safely opening and closing the websocket connection """ - if isinstance(websocket, FastAPIWebSocket): - try: - await websocket.accept() - except RuntimeError: - # The connection was closed before we could accept it - return + try: + await self.accept() + yield self + finally: + await self.close() - ws_channel = WebSocketChannel(websocket) - with self._lock: - self.channels[websocket] = ws_channel +# class WebSocketChannel: +# """ +# Object to help facilitate managing a websocket connection +# """ - return ws_channel +# def __init__( +# self, +# websocket: WebSocketType, +# channel_id: Optional[str] = None, +# drain_timeout: int = 3, +# throttle: float = 0.01, +# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer +# ): - async def on_disconnect(self, websocket: WebSocketType): - """ - Call close on the channel if it's not, and remove from channel list +# self.channel_id = channel_id if channel_id else uuid4().hex[:8] - :param websocket: The WebSocket objet attached to the Channel - """ - with self._lock: - channel = self.channels.get(websocket) - if channel: - logger.info(f"Disconnecting channel {channel}") - if not channel.is_closed(): - await channel.close() +# # The WebSocket object +# self._websocket = WebSocketProxy(websocket) - del self.channels[websocket] +# self.drain_timeout = drain_timeout +# self.throttle = throttle - async def disconnect_all(self): - """ - Disconnect all Channels - """ - with self._lock: - for websocket in self.channels.copy().keys(): - await self.on_disconnect(websocket) +# self._subscriptions: List[str] = [] +# # 32 is the size of the receiving queue in websockets package +# self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) +# self._relay_task = asyncio.create_task(self.relay()) - async def broadcast(self, message: WSMessageSchemaType): - """ - Broadcast a message on all Channels +# # Internal event to signify a closed websocket +# self._closed = asyncio.Event() - :param message: The message to send - """ - with self._lock: - for channel in self.channels.copy().values(): - if channel.subscribed_to(message.get('type')): - await self.send_direct(channel, message) +# # Wrap the WebSocket in the Serializing class +# self._wrapped_ws = serializer_cls(self._websocket) - async def send_direct( - self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): - """ - Send a message directly through direct_channel only +# def __repr__(self): +# return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" - :param direct_channel: The WebSocketChannel object to send the message through - :param message: The message to send - """ - if not await channel.send(message): - await self.on_disconnect(channel.raw_websocket) +# @property +# def raw_websocket(self): +# return self._websocket.raw_websocket - def has_channels(self): - """ - Flag for more than 0 channels - """ - return len(self.channels) > 0 +# @property +# def remote_addr(self): +# return self._websocket.remote_addr + +# async def _send(self, data): +# """ +# Send data on the wrapped websocket +# """ +# await self._wrapped_ws.send(data) + +# async def send(self, data) -> bool: +# """ +# Add the data to the queue to be sent. +# :returns: True if data added to queue, False otherwise +# """ + +# # This block only runs if the queue is full, it will wait +# # until self.drain_timeout for the relay to drain the outgoing queue +# # We can't use asyncio.wait_for here because the queue may have been created with a +# # different eventloop +# start = time.time() +# while self.queue.full(): +# await asyncio.sleep(1) +# if (time.time() - start) > self.drain_timeout: +# return False + +# # If for some reason the queue is still full, just return False +# try: +# self.queue.put_nowait(data) +# except asyncio.QueueFull: +# return False + +# # If we got here everything is ok +# return True + +# async def recv(self): +# """ +# Receive data on the wrapped websocket +# """ +# return await self._wrapped_ws.recv() + +# async def ping(self): +# """ +# Ping the websocket +# """ +# return await self._websocket.ping() + +# async def close(self): +# """ +# Close the WebSocketChannel +# """ + +# try: +# await self.raw_websocket.close() +# except Exception: +# pass + +# self._closed.set() +# self._relay_task.cancel() + +# def is_closed(self) -> bool: +# """ +# Closed flag +# """ +# return self._closed.is_set() + +# def set_subscriptions(self, subscriptions: List[str] = []) -> None: +# """ +# Set which subscriptions this channel is subscribed to + +# :param subscriptions: List of subscriptions, List[str] +# """ +# self._subscriptions = subscriptions + +# def subscribed_to(self, message_type: str) -> bool: +# """ +# Check if this channel is subscribed to the message_type + +# :param message_type: The message type to check +# """ +# return message_type in self._subscriptions + +# async def relay(self): +# """ +# Relay messages from the channel's queue and send them out. This is started +# as a task. +# """ +# while not self._closed.is_set(): +# message = await self.queue.get() +# try: +# await self._send(message) +# self.queue.task_done() + +# # Limit messages per sec. +# # Could cause problems with queue size if too low, and +# # problems with network traffik if too high. +# # 0.01 = 100/s +# await asyncio.sleep(self.throttle) +# except RuntimeError: +# # The connection was closed, just exit the task +# return + + +# class ChannelManager: +# def __init__(self): +# self.channels = dict() +# self._lock = RLock() # Re-entrant Lock + +# async def on_connect(self, websocket: WebSocketType): +# """ +# Wrap websocket connection into Channel and add to list + +# :param websocket: The WebSocket object to attach to the Channel +# """ +# if isinstance(websocket, FastAPIWebSocket): +# try: +# await websocket.accept() +# except RuntimeError: +# # The connection was closed before we could accept it +# return + +# ws_channel = WebSocketChannel(websocket) + +# with self._lock: +# self.channels[websocket] = ws_channel + +# return ws_channel + +# async def on_disconnect(self, websocket: WebSocketType): +# """ +# Call close on the channel if it's not, and remove from channel list + +# :param websocket: The WebSocket objet attached to the Channel +# """ +# with self._lock: +# channel = self.channels.get(websocket) +# if channel: +# logger.info(f"Disconnecting channel {channel}") +# if not channel.is_closed(): +# await channel.close() + +# del self.channels[websocket] + +# async def disconnect_all(self): +# """ +# Disconnect all Channels +# """ +# with self._lock: +# for websocket in self.channels.copy().keys(): +# await self.on_disconnect(websocket) + +# async def broadcast(self, message: WSMessageSchemaType): +# """ +# Broadcast a message on all Channels + +# :param message: The message to send +# """ +# with self._lock: +# for channel in self.channels.copy().values(): +# if channel.subscribed_to(message.get('type')): +# await self.send_direct(channel, message) + +# async def send_direct( +# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): +# """ +# Send a message directly through direct_channel only + +# :param direct_channel: The WebSocketChannel object to send the message through +# :param message: The message to send +# """ +# if not await channel.send(message): +# await self.on_disconnect(channel.raw_websocket) + +# def has_channels(self): +# """ +# Flag for more than 0 channels +# """ +# return len(self.channels) > 0 diff --git a/freqtrade/rpc/api_server/ws/message_stream.py b/freqtrade/rpc/api_server/ws/message_stream.py new file mode 100644 index 000000000..f77242719 --- /dev/null +++ b/freqtrade/rpc/api_server/ws/message_stream.py @@ -0,0 +1,23 @@ +import asyncio + + +class MessageStream: + """ + A message stream for consumers to subscribe to, + and for producers to publish to. + """ + def __init__(self): + self._loop = asyncio.get_running_loop() + self._waiter = self._loop.create_future() + + def publish(self, message): + waiter, self._waiter = self._waiter, self._loop.create_future() + waiter.set_result((message, self._waiter)) + + async def subscribe(self): + waiter = self._waiter + while True: + message, waiter = await waiter + yield message + + __aiter__ = subscribe diff --git a/freqtrade/rpc/api_server/ws/serializer.py b/freqtrade/rpc/api_server/ws/serializer.py index 6c402a100..85703136b 100644 --- a/freqtrade/rpc/api_server/ws/serializer.py +++ b/freqtrade/rpc/api_server/ws/serializer.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from typing import Any, Dict, Union import orjson import rapidjson @@ -7,6 +8,7 @@ from pandas import DataFrame from freqtrade.misc import dataframe_to_json, json_to_dataframe from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy +from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ class WebSocketSerializer(ABC): def _deserialize(self, data): raise NotImplementedError() - async def send(self, data: bytes): + async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]): await self._websocket.send(self._serialize(data)) async def recv(self) -> bytes: @@ -32,8 +34,8 @@ class WebSocketSerializer(ABC): return self._deserialize(data) - async def close(self, code: int = 1000): - await self._websocket.close(code) + # async def close(self, code: int = 1000): + # await self._websocket.close(code) class HybridJSONWebSocketSerializer(WebSocketSerializer):