initial revision
This commit is contained in:
		| @@ -1,16 +1,17 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| from typing import Any, Dict | ||||
|  | ||||
| from fastapi import APIRouter, Depends, WebSocketDisconnect | ||||
| from fastapi.websockets import WebSocket, WebSocketState | ||||
| from fastapi import APIRouter, Depends | ||||
| from fastapi.websockets import WebSocket, WebSocketDisconnect | ||||
| from pydantic import ValidationError | ||||
| from websockets.exceptions import WebSocketException | ||||
| from websockets.exceptions import ConnectionClosed | ||||
|  | ||||
| from freqtrade.enums import RPCMessageType, RPCRequestType | ||||
| from freqtrade.rpc.api_server.api_auth import validate_ws_token | ||||
| from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc | ||||
| from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc | ||||
| from freqtrade.rpc.api_server.ws import WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.channel import ChannelManager | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
| from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, | ||||
|                                                  WSRequestSchema, WSWhitelistMessage) | ||||
| from freqtrade.rpc.rpc import RPC | ||||
| @@ -22,23 +23,63 @@ logger = logging.getLogger(__name__) | ||||
| router = APIRouter() | ||||
|  | ||||
|  | ||||
| async def is_websocket_alive(ws: WebSocket) -> bool: | ||||
| # async def is_websocket_alive(ws: WebSocket) -> bool: | ||||
| #     """ | ||||
| #     Check if a FastAPI Websocket is still open | ||||
| #     """ | ||||
| #     if ( | ||||
| #         ws.application_state == WebSocketState.CONNECTED and | ||||
| #         ws.client_state == WebSocketState.CONNECTED | ||||
| #     ): | ||||
| #         return True | ||||
| #     return False | ||||
|  | ||||
|  | ||||
| class WebSocketChannelClosed(Exception): | ||||
|     """ | ||||
|     Check if a FastAPI Websocket is still open | ||||
|     General WebSocket exception to signal closing the channel | ||||
|     """ | ||||
|     if ( | ||||
|         ws.application_state == WebSocketState.CONNECTED and | ||||
|         ws.client_state == WebSocketState.CONNECTED | ||||
|     pass | ||||
|  | ||||
|  | ||||
| async def channel_reader(channel: WebSocketChannel, rpc: RPC): | ||||
|     """ | ||||
|     Iterate over the messages from the channel and process the request | ||||
|     """ | ||||
|     try: | ||||
|         async for message in channel: | ||||
|             await _process_consumer_request(message, channel, rpc) | ||||
|     except ( | ||||
|         RuntimeError, | ||||
|         WebSocketDisconnect, | ||||
|         ConnectionClosed | ||||
|     ): | ||||
|         return True | ||||
|     return False | ||||
|         raise WebSocketChannelClosed | ||||
|     except asyncio.CancelledError: | ||||
|         return | ||||
|  | ||||
|  | ||||
| async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): | ||||
|     """ | ||||
|     Iterate over messages in the message stream and send them | ||||
|     """ | ||||
|     try: | ||||
|         async for message in message_stream: | ||||
|             await channel.send(message) | ||||
|     except ( | ||||
|         RuntimeError, | ||||
|         WebSocketDisconnect, | ||||
|         ConnectionClosed | ||||
|     ): | ||||
|         raise WebSocketChannelClosed | ||||
|     except asyncio.CancelledError: | ||||
|         return | ||||
|  | ||||
|  | ||||
| async def _process_consumer_request( | ||||
|     request: Dict[str, Any], | ||||
|     channel: WebSocketChannel, | ||||
|     rpc: RPC, | ||||
|     channel_manager: ChannelManager | ||||
|     rpc: RPC | ||||
| ): | ||||
|     """ | ||||
|     Validate and handle a request from a websocket consumer | ||||
| @@ -75,7 +116,7 @@ async def _process_consumer_request( | ||||
|         # Format response | ||||
|         response = WSWhitelistMessage(data=whitelist) | ||||
|         # Send it back | ||||
|         await channel_manager.send_direct(channel, response.dict(exclude_none=True)) | ||||
|         await channel.send(response.dict(exclude_none=True)) | ||||
|  | ||||
|     elif type == RPCRequestType.ANALYZED_DF: | ||||
|         limit = None | ||||
| @@ -86,53 +127,76 @@ async def _process_consumer_request( | ||||
|  | ||||
|         # For every pair in the generator, send a separate message | ||||
|         for message in rpc._ws_request_analyzed_df(limit): | ||||
|             # Format response | ||||
|             response = WSAnalyzedDFMessage(data=message) | ||||
|             await channel_manager.send_direct(channel, response.dict(exclude_none=True)) | ||||
|             await channel.send(response.dict(exclude_none=True)) | ||||
|  | ||||
|  | ||||
| @router.websocket("/message/ws") | ||||
| async def message_endpoint( | ||||
|     ws: WebSocket, | ||||
|     websocket: WebSocket, | ||||
|     token: str = Depends(validate_ws_token), | ||||
|     rpc: RPC = Depends(get_rpc), | ||||
|     channel_manager=Depends(get_channel_manager), | ||||
|     token: str = Depends(validate_ws_token) | ||||
|     message_stream: MessageStream = Depends(get_message_stream) | ||||
| ): | ||||
|     """ | ||||
|     Message WebSocket endpoint, facilitates sending RPC messages | ||||
|     """ | ||||
|     try: | ||||
|         channel = await channel_manager.on_connect(ws) | ||||
|         if await is_websocket_alive(ws): | ||||
|     async with WebSocketChannel(websocket).connect() as channel: | ||||
|         try: | ||||
|             logger.info(f"Channel connected - {channel}") | ||||
|  | ||||
|             logger.info(f"Consumer connected - {channel}") | ||||
|             channel_tasks = asyncio.gather( | ||||
|                 channel_reader(channel, rpc), | ||||
|                 channel_broadcaster(channel, message_stream) | ||||
|             ) | ||||
|             await channel_tasks | ||||
|  | ||||
|             # Keep connection open until explicitly closed, and process requests | ||||
|             try: | ||||
|                 while not channel.is_closed(): | ||||
|                     request = await channel.recv() | ||||
|         finally: | ||||
|             logger.info(f"Channel disconnected - {channel}") | ||||
|             channel_tasks.cancel() | ||||
|  | ||||
|                     # Process the request here | ||||
|                     await _process_consumer_request(request, channel, rpc, channel_manager) | ||||
|  | ||||
|             except (WebSocketDisconnect, WebSocketException): | ||||
|                 # Handle client disconnects | ||||
|                 logger.info(f"Consumer disconnected - {channel}") | ||||
|             except RuntimeError: | ||||
|                 # Handle cases like - | ||||
|                 # RuntimeError('Cannot call "send" once a closed message has been sent') | ||||
|                 pass | ||||
|             except Exception as e: | ||||
|                 logger.info(f"Consumer connection failed - {channel}: {e}") | ||||
|                 logger.debug(e, exc_info=e) | ||||
| # @router.websocket("/message/ws") | ||||
| # async def message_endpoint( | ||||
| #     ws: WebSocket, | ||||
| #     rpc: RPC = Depends(get_rpc), | ||||
| #     channel_manager=Depends(get_channel_manager), | ||||
| #     token: str = Depends(validate_ws_token) | ||||
| # ): | ||||
| #     """ | ||||
| #     Message WebSocket endpoint, facilitates sending RPC messages | ||||
| #     """ | ||||
| #     try: | ||||
| #         channel = await channel_manager.on_connect(ws) | ||||
| #         if await is_websocket_alive(ws): | ||||
|  | ||||
|     except RuntimeError: | ||||
|         # WebSocket was closed | ||||
|         # Do nothing | ||||
|         pass | ||||
|     except Exception as e: | ||||
|         logger.error(f"Failed to serve - {ws.client}") | ||||
|         # Log tracebacks to keep track of what errors are happening | ||||
|         logger.exception(e) | ||||
|     finally: | ||||
|         if channel: | ||||
|             await channel_manager.on_disconnect(ws) | ||||
| #             logger.info(f"Consumer connected - {channel}") | ||||
|  | ||||
| #             # Keep connection open until explicitly closed, and process requests | ||||
| #             try: | ||||
| #                 while not channel.is_closed(): | ||||
| #                     request = await channel.recv() | ||||
|  | ||||
| #                     # Process the request here | ||||
| #                     await _process_consumer_request(request, channel, rpc, channel_manager) | ||||
|  | ||||
| #             except (WebSocketDisconnect, WebSocketException): | ||||
| #                 # Handle client disconnects | ||||
| #                 logger.info(f"Consumer disconnected - {channel}") | ||||
| #             except RuntimeError: | ||||
| #                 # Handle cases like - | ||||
| #                 # RuntimeError('Cannot call "send" once a closed message has been sent') | ||||
| #                 pass | ||||
| #             except Exception as e: | ||||
| #                 logger.info(f"Consumer connection failed - {channel}: {e}") | ||||
| #                 logger.debug(e, exc_info=e) | ||||
|  | ||||
| #     except RuntimeError: | ||||
| #         # WebSocket was closed | ||||
| #         # Do nothing | ||||
| #         pass | ||||
| #     except Exception as e: | ||||
| #         logger.error(f"Failed to serve - {ws.client}") | ||||
| #         # Log tracebacks to keep track of what errors are happening | ||||
| #         logger.exception(e) | ||||
| #     finally: | ||||
| #         if channel: | ||||
| #             await channel_manager.on_disconnect(ws) | ||||
|   | ||||
| @@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)): | ||||
|     return ApiServer._exchange | ||||
|  | ||||
|  | ||||
| def get_channel_manager(): | ||||
|     return ApiServer._ws_channel_manager | ||||
| def get_message_stream(): | ||||
|     return ApiServer._message_stream | ||||
|  | ||||
|  | ||||
| def is_webserver_mode(config=Depends(get_config)): | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| from ipaddress import IPv4Address | ||||
| from threading import Thread | ||||
| from typing import Any, Dict | ||||
|  | ||||
| import orjson | ||||
| @@ -15,7 +14,7 @@ from starlette.responses import JSONResponse | ||||
| from freqtrade.constants import Config | ||||
| from freqtrade.exceptions import OperationalException | ||||
| from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer | ||||
| from freqtrade.rpc.api_server.ws import ChannelManager | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
| from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType | ||||
| from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler | ||||
|  | ||||
| @@ -51,9 +50,10 @@ class ApiServer(RPCHandler): | ||||
|     # Exchange - only available in webserver mode. | ||||
|     _exchange = None | ||||
|     # websocket message queue stuff | ||||
|     _ws_channel_manager = None | ||||
|     _ws_thread = None | ||||
|     _ws_loop = None | ||||
|     # _ws_channel_manager = None | ||||
|     # _ws_thread = None | ||||
|     # _ws_loop = None | ||||
|     _message_stream = None | ||||
|  | ||||
|     def __new__(cls, *args, **kwargs): | ||||
|         """ | ||||
| @@ -71,14 +71,15 @@ class ApiServer(RPCHandler): | ||||
|             return | ||||
|         self._standalone: bool = standalone | ||||
|         self._server = None | ||||
|  | ||||
|         self._ws_queue = None | ||||
|         self._ws_background_task = None | ||||
|         self._ws_publisher_task = None | ||||
|  | ||||
|         ApiServer.__initialized = True | ||||
|  | ||||
|         api_config = self._config['api_server'] | ||||
|  | ||||
|         ApiServer._ws_channel_manager = ChannelManager() | ||||
|         # ApiServer._ws_channel_manager = ChannelManager() | ||||
|  | ||||
|         self.app = FastAPI(title="Freqtrade API", | ||||
|                            docs_url='/docs' if api_config.get('enable_openapi', False) else None, | ||||
| @@ -107,18 +108,18 @@ class ApiServer(RPCHandler): | ||||
|             logger.info("Stopping API Server") | ||||
|             self._server.cleanup() | ||||
|  | ||||
|         if self._ws_thread and self._ws_loop: | ||||
|             logger.info("Stopping API Server background tasks") | ||||
|         # if self._ws_thread and self._ws_loop: | ||||
|         #     logger.info("Stopping API Server background tasks") | ||||
|  | ||||
|             if self._ws_background_task: | ||||
|                 # Cancel the queue task | ||||
|                 self._ws_background_task.cancel() | ||||
|         #     if self._ws_background_task: | ||||
|         #         # Cancel the queue task | ||||
|         #         self._ws_background_task.cancel() | ||||
|  | ||||
|             self._ws_thread.join() | ||||
|         #     self._ws_thread.join() | ||||
|  | ||||
|         self._ws_thread = None | ||||
|         self._ws_loop = None | ||||
|         self._ws_background_task = None | ||||
|         # self._ws_thread = None | ||||
|         # self._ws_loop = None | ||||
|         # self._ws_background_task = None | ||||
|  | ||||
|     @classmethod | ||||
|     def shutdown(cls): | ||||
| @@ -170,51 +171,102 @@ class ApiServer(RPCHandler): | ||||
|         ) | ||||
|  | ||||
|         app.add_exception_handler(RPCException, self.handle_rpc_exception) | ||||
|         app.add_event_handler( | ||||
|             event_type="startup", | ||||
|             func=self._api_startup_event | ||||
|         ) | ||||
|         app.add_event_handler( | ||||
|             event_type="shutdown", | ||||
|             func=self._api_shutdown_event | ||||
|         ) | ||||
|  | ||||
|     def start_message_queue(self): | ||||
|         if self._ws_thread: | ||||
|             return | ||||
|     async def _api_startup_event(self): | ||||
|         if not ApiServer._message_stream: | ||||
|             ApiServer._message_stream = MessageStream() | ||||
|  | ||||
|         # Create a new loop, as it'll be just for the background thread | ||||
|         self._ws_loop = asyncio.new_event_loop() | ||||
|         if not self._ws_queue: | ||||
|             self._ws_queue = ThreadedQueue() | ||||
|  | ||||
|         # Start the thread | ||||
|         self._ws_thread = Thread(target=self._ws_loop.run_forever) | ||||
|         self._ws_thread.start() | ||||
|         if not self._ws_publisher_task: | ||||
|             self._ws_publisher_task = asyncio.create_task( | ||||
|                 self._publish_messages() | ||||
|             ) | ||||
|  | ||||
|         # Finally, submit the coro to the thread | ||||
|         self._ws_background_task = asyncio.run_coroutine_threadsafe( | ||||
|             self._broadcast_queue_data(), loop=self._ws_loop) | ||||
|     async def _api_shutdown_event(self): | ||||
|         if ApiServer._message_stream: | ||||
|             ApiServer._message_stream = None | ||||
|  | ||||
|     async def _broadcast_queue_data(self): | ||||
|         # Instantiate the queue in this coroutine so it's attached to our loop | ||||
|         self._ws_queue = ThreadedQueue() | ||||
|         async_queue = self._ws_queue.async_q | ||||
|  | ||||
|         try: | ||||
|             while True: | ||||
|                 logger.debug("Getting queue messages...") | ||||
|                 # Get data from queue | ||||
|                 message: WSMessageSchemaType = await async_queue.get() | ||||
|                 logger.debug(f"Found message of type: {message.get('type')}") | ||||
|                 async_queue.task_done() | ||||
|                 # Broadcast it | ||||
|                 await self._ws_channel_manager.broadcast(message) | ||||
|         except asyncio.CancelledError: | ||||
|             pass | ||||
|  | ||||
|         # For testing, shouldn't happen when stable | ||||
|         except Exception as e: | ||||
|             logger.exception(f"Exception happened in background task: {e}") | ||||
|  | ||||
|         finally: | ||||
|             # Disconnect channels and stop the loop on cancel | ||||
|             await self._ws_channel_manager.disconnect_all() | ||||
|             self._ws_loop.stop() | ||||
|             # Avoid adding more items to the queue if they aren't | ||||
|             # going to get broadcasted. | ||||
|         if self._ws_queue: | ||||
|             self._ws_queue = None | ||||
|  | ||||
|         if self._ws_publisher_task: | ||||
|             self._ws_publisher_task.cancel() | ||||
|  | ||||
|     async def _publish_messages(self): | ||||
|         """ | ||||
|         Background task that reads messages from the queue and adds them | ||||
|         to the message stream | ||||
|         """ | ||||
|         try: | ||||
|             async_queue = self._ws_queue.async_q | ||||
|             message_stream = ApiServer._message_stream | ||||
|  | ||||
|             while message_stream: | ||||
|                 message: WSMessageSchemaType = await async_queue.get() | ||||
|                 message_stream.publish(message) | ||||
|  | ||||
|                 # Make sure to throttle how fast we | ||||
|                 # publish messages as some clients will be | ||||
|                 # slower than others | ||||
|                 await asyncio.sleep(0.01) | ||||
|                 async_queue.task_done() | ||||
|         finally: | ||||
|             self._ws_queue = None | ||||
|  | ||||
|     # def start_message_queue(self): | ||||
|     #     if self._ws_thread: | ||||
|     #         return | ||||
|  | ||||
|     #     # Create a new loop, as it'll be just for the background thread | ||||
|     #     self._ws_loop = asyncio.new_event_loop() | ||||
|  | ||||
|     #     # Start the thread | ||||
|     #     self._ws_thread = Thread(target=self._ws_loop.run_forever) | ||||
|     #     self._ws_thread.start() | ||||
|  | ||||
|     #     # Finally, submit the coro to the thread | ||||
|     #     self._ws_background_task = asyncio.run_coroutine_threadsafe( | ||||
|     #         self._broadcast_queue_data(), loop=self._ws_loop) | ||||
|  | ||||
|     # async def _broadcast_queue_data(self): | ||||
|     #     # Instantiate the queue in this coroutine so it's attached to our loop | ||||
|     #     self._ws_queue = ThreadedQueue() | ||||
|     #     async_queue = self._ws_queue.async_q | ||||
|  | ||||
|     #     try: | ||||
|     #         while True: | ||||
|     #             logger.debug("Getting queue messages...") | ||||
|     #             # Get data from queue | ||||
|     #             message: WSMessageSchemaType = await async_queue.get() | ||||
|     #             logger.debug(f"Found message of type: {message.get('type')}") | ||||
|     #             async_queue.task_done() | ||||
|     #             # Broadcast it | ||||
|     #             await self._ws_channel_manager.broadcast(message) | ||||
|     #     except asyncio.CancelledError: | ||||
|     #         pass | ||||
|  | ||||
|     #     # For testing, shouldn't happen when stable | ||||
|     #     except Exception as e: | ||||
|     #         logger.exception(f"Exception happened in background task: {e}") | ||||
|  | ||||
|     #     finally: | ||||
|     #         # Disconnect channels and stop the loop on cancel | ||||
|     #         await self._ws_channel_manager.disconnect_all() | ||||
|     #         self._ws_loop.stop() | ||||
|     #         # Avoid adding more items to the queue if they aren't | ||||
|     #         # going to get broadcasted. | ||||
|     #         self._ws_queue = None | ||||
|  | ||||
|     def start_api(self): | ||||
|         """ | ||||
|         Start API ... should be run in thread. | ||||
| @@ -253,7 +305,7 @@ class ApiServer(RPCHandler): | ||||
|             if self._standalone: | ||||
|                 self._server.run() | ||||
|             else: | ||||
|                 self.start_message_queue() | ||||
|                 # self.start_message_queue() | ||||
|                 self._server.run_in_thread() | ||||
|         except Exception: | ||||
|             logger.exception("Api server failed to start.") | ||||
|   | ||||
| @@ -3,4 +3,5 @@ | ||||
| from freqtrade.rpc.api_server.ws.types import WebSocketType | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer | ||||
| from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.channel import WebSocketChannel | ||||
| from freqtrade.rpc.api_server.ws.message_stream import MessageStream | ||||
|   | ||||
| @@ -1,12 +1,9 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| import time | ||||
| from threading import RLock | ||||
| from contextlib import asynccontextmanager | ||||
| from typing import Any, Dict, List, Optional, Type, Union | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from fastapi import WebSocket as FastAPIWebSocket | ||||
|  | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, | ||||
|                                                     WebSocketSerializer) | ||||
| @@ -21,32 +18,21 @@ class WebSocketChannel: | ||||
|     """ | ||||
|     Object to help facilitate managing a websocket connection | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         websocket: WebSocketType, | ||||
|         channel_id: Optional[str] = None, | ||||
|         drain_timeout: int = 3, | ||||
|         throttle: float = 0.01, | ||||
|         serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer | ||||
|     ): | ||||
|  | ||||
|         self.channel_id = channel_id if channel_id else uuid4().hex[:8] | ||||
|  | ||||
|         # The WebSocket object | ||||
|         self._websocket = WebSocketProxy(websocket) | ||||
|  | ||||
|         self.drain_timeout = drain_timeout | ||||
|         self.throttle = throttle | ||||
|  | ||||
|         self._subscriptions: List[str] = [] | ||||
|         # 32 is the size of the receiving queue in websockets package | ||||
|         self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) | ||||
|         self._relay_task = asyncio.create_task(self.relay()) | ||||
|  | ||||
|         # Internal event to signify a closed websocket | ||||
|         self._closed = asyncio.Event() | ||||
|  | ||||
|         # Throttle how fast we send messages | ||||
|         self._throttle = 0.01 | ||||
|  | ||||
|         # Wrap the WebSocket in the Serializing class | ||||
|         self._wrapped_ws = serializer_cls(self._websocket) | ||||
|  | ||||
| @@ -61,40 +47,16 @@ class WebSocketChannel: | ||||
|     def remote_addr(self): | ||||
|         return self._websocket.remote_addr | ||||
|  | ||||
|     async def _send(self, data): | ||||
|     async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
|         """ | ||||
|         Send data on the wrapped websocket | ||||
|         Send a message on the wrapped websocket | ||||
|         """ | ||||
|         await self._wrapped_ws.send(data) | ||||
|  | ||||
|     async def send(self, data) -> bool: | ||||
|         """ | ||||
|         Add the data to the queue to be sent. | ||||
|         :returns: True if data added to queue, False otherwise | ||||
|         """ | ||||
|  | ||||
|         # This block only runs if the queue is full, it will wait | ||||
|         # until self.drain_timeout for the relay to drain the outgoing queue | ||||
|         # We can't use asyncio.wait_for here because the queue may have been created with a | ||||
|         # different eventloop | ||||
|         start = time.time() | ||||
|         while self.queue.full(): | ||||
|             await asyncio.sleep(1) | ||||
|             if (time.time() - start) > self.drain_timeout: | ||||
|                 return False | ||||
|  | ||||
|         # If for some reason the queue is still full, just return False | ||||
|         try: | ||||
|             self.queue.put_nowait(data) | ||||
|         except asyncio.QueueFull: | ||||
|             return False | ||||
|  | ||||
|         # If we got here everything is ok | ||||
|         return True | ||||
|         await asyncio.sleep(self._throttle) | ||||
|         await self._wrapped_ws.send(message) | ||||
|  | ||||
|     async def recv(self): | ||||
|         """ | ||||
|         Receive data on the wrapped websocket | ||||
|         Receive a message on the wrapped websocket | ||||
|         """ | ||||
|         return await self._wrapped_ws.recv() | ||||
|  | ||||
| @@ -104,18 +66,23 @@ class WebSocketChannel: | ||||
|         """ | ||||
|         return await self._websocket.ping() | ||||
|  | ||||
|     async def accept(self): | ||||
|         """ | ||||
|         Accept the underlying websocket connection | ||||
|         """ | ||||
|         return await self._websocket.accept() | ||||
|  | ||||
|     async def close(self): | ||||
|         """ | ||||
|         Close the WebSocketChannel | ||||
|         """ | ||||
|  | ||||
|         try: | ||||
|             await self.raw_websocket.close() | ||||
|             await self._websocket.close() | ||||
|         except Exception: | ||||
|             pass | ||||
|  | ||||
|         self._closed.set() | ||||
|         self._relay_task.cancel() | ||||
|  | ||||
|     def is_closed(self) -> bool: | ||||
|         """ | ||||
| @@ -139,99 +106,243 @@ class WebSocketChannel: | ||||
|         """ | ||||
|         return message_type in self._subscriptions | ||||
|  | ||||
|     async def relay(self): | ||||
|     async def __aiter__(self): | ||||
|         """ | ||||
|         Relay messages from the channel's queue and send them out. This is started | ||||
|         as a task. | ||||
|         Generator for received messages | ||||
|         """ | ||||
|         while not self._closed.is_set(): | ||||
|             message = await self.queue.get() | ||||
|         while True: | ||||
|             try: | ||||
|                 await self._send(message) | ||||
|                 self.queue.task_done() | ||||
|                 yield await self.recv() | ||||
|             except Exception: | ||||
|                 break | ||||
|  | ||||
|                 # Limit messages per sec. | ||||
|                 # Could cause problems with queue size if too low, and | ||||
|                 # problems with network traffik if too high. | ||||
|                 # 0.01 = 100/s | ||||
|                 await asyncio.sleep(self.throttle) | ||||
|             except RuntimeError: | ||||
|                 # The connection was closed, just exit the task | ||||
|                 return | ||||
|  | ||||
|  | ||||
| class ChannelManager: | ||||
|     def __init__(self): | ||||
|         self.channels = dict() | ||||
|         self._lock = RLock()  # Re-entrant Lock | ||||
|  | ||||
|     async def on_connect(self, websocket: WebSocketType): | ||||
|     @asynccontextmanager | ||||
|     async def connect(self): | ||||
|         """ | ||||
|         Wrap websocket connection into Channel and add to list | ||||
|  | ||||
|         :param websocket: The WebSocket object to attach to the Channel | ||||
|         Context manager for safely opening and closing the websocket connection | ||||
|         """ | ||||
|         if isinstance(websocket, FastAPIWebSocket): | ||||
|             try: | ||||
|                 await websocket.accept() | ||||
|             except RuntimeError: | ||||
|                 # The connection was closed before we could accept it | ||||
|                 return | ||||
|         try: | ||||
|             await self.accept() | ||||
|             yield self | ||||
|         finally: | ||||
|             await self.close() | ||||
|  | ||||
|         ws_channel = WebSocketChannel(websocket) | ||||
|  | ||||
|         with self._lock: | ||||
|             self.channels[websocket] = ws_channel | ||||
| # class WebSocketChannel: | ||||
| #     """ | ||||
| #     Object to help facilitate managing a websocket connection | ||||
| #     """ | ||||
|  | ||||
|         return ws_channel | ||||
| #     def __init__( | ||||
| #         self, | ||||
| #         websocket: WebSocketType, | ||||
| #         channel_id: Optional[str] = None, | ||||
| #         drain_timeout: int = 3, | ||||
| #         throttle: float = 0.01, | ||||
| #         serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer | ||||
| #     ): | ||||
|  | ||||
|     async def on_disconnect(self, websocket: WebSocketType): | ||||
|         """ | ||||
|         Call close on the channel if it's not, and remove from channel list | ||||
| #         self.channel_id = channel_id if channel_id else uuid4().hex[:8] | ||||
|  | ||||
|         :param websocket: The WebSocket objet attached to the Channel | ||||
|         """ | ||||
|         with self._lock: | ||||
|             channel = self.channels.get(websocket) | ||||
|             if channel: | ||||
|                 logger.info(f"Disconnecting channel {channel}") | ||||
|                 if not channel.is_closed(): | ||||
|                     await channel.close() | ||||
| #         # The WebSocket object | ||||
| #         self._websocket = WebSocketProxy(websocket) | ||||
|  | ||||
|                 del self.channels[websocket] | ||||
| #         self.drain_timeout = drain_timeout | ||||
| #         self.throttle = throttle | ||||
|  | ||||
|     async def disconnect_all(self): | ||||
|         """ | ||||
|         Disconnect all Channels | ||||
|         """ | ||||
|         with self._lock: | ||||
|             for websocket in self.channels.copy().keys(): | ||||
|                 await self.on_disconnect(websocket) | ||||
| #         self._subscriptions: List[str] = [] | ||||
| #         # 32 is the size of the receiving queue in websockets package | ||||
| #         self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) | ||||
| #         self._relay_task = asyncio.create_task(self.relay()) | ||||
|  | ||||
|     async def broadcast(self, message: WSMessageSchemaType): | ||||
|         """ | ||||
|         Broadcast a message on all Channels | ||||
| #         # Internal event to signify a closed websocket | ||||
| #         self._closed = asyncio.Event() | ||||
|  | ||||
|         :param message: The message to send | ||||
|         """ | ||||
|         with self._lock: | ||||
|             for channel in self.channels.copy().values(): | ||||
|                 if channel.subscribed_to(message.get('type')): | ||||
|                     await self.send_direct(channel, message) | ||||
| #         # Wrap the WebSocket in the Serializing class | ||||
| #         self._wrapped_ws = serializer_cls(self._websocket) | ||||
|  | ||||
|     async def send_direct( | ||||
|             self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
|         """ | ||||
|         Send a message directly through direct_channel only | ||||
| #     def __repr__(self): | ||||
| #         return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" | ||||
|  | ||||
|         :param direct_channel: The WebSocketChannel object to send the message through | ||||
|         :param message: The message to send | ||||
|         """ | ||||
|         if not await channel.send(message): | ||||
|             await self.on_disconnect(channel.raw_websocket) | ||||
| #     @property | ||||
| #     def raw_websocket(self): | ||||
| #         return self._websocket.raw_websocket | ||||
|  | ||||
|     def has_channels(self): | ||||
|         """ | ||||
|         Flag for more than 0 channels | ||||
|         """ | ||||
|         return len(self.channels) > 0 | ||||
| #     @property | ||||
| #     def remote_addr(self): | ||||
| #         return self._websocket.remote_addr | ||||
|  | ||||
| #     async def _send(self, data): | ||||
| #         """ | ||||
| #         Send data on the wrapped websocket | ||||
| #         """ | ||||
| #         await self._wrapped_ws.send(data) | ||||
|  | ||||
| #     async def send(self, data) -> bool: | ||||
| #         """ | ||||
| #         Add the data to the queue to be sent. | ||||
| #         :returns: True if data added to queue, False otherwise | ||||
| #         """ | ||||
|  | ||||
| #         # This block only runs if the queue is full, it will wait | ||||
| #         # until self.drain_timeout for the relay to drain the outgoing queue | ||||
| #         # We can't use asyncio.wait_for here because the queue may have been created with a | ||||
| #         # different eventloop | ||||
| #         start = time.time() | ||||
| #         while self.queue.full(): | ||||
| #             await asyncio.sleep(1) | ||||
| #             if (time.time() - start) > self.drain_timeout: | ||||
| #                 return False | ||||
|  | ||||
| #         # If for some reason the queue is still full, just return False | ||||
| #         try: | ||||
| #             self.queue.put_nowait(data) | ||||
| #         except asyncio.QueueFull: | ||||
| #             return False | ||||
|  | ||||
| #         # If we got here everything is ok | ||||
| #         return True | ||||
|  | ||||
| #     async def recv(self): | ||||
| #         """ | ||||
| #         Receive data on the wrapped websocket | ||||
| #         """ | ||||
| #         return await self._wrapped_ws.recv() | ||||
|  | ||||
| #     async def ping(self): | ||||
| #         """ | ||||
| #         Ping the websocket | ||||
| #         """ | ||||
| #         return await self._websocket.ping() | ||||
|  | ||||
| #     async def close(self): | ||||
| #         """ | ||||
| #         Close the WebSocketChannel | ||||
| #         """ | ||||
|  | ||||
| #         try: | ||||
| #             await self.raw_websocket.close() | ||||
| #         except Exception: | ||||
| #             pass | ||||
|  | ||||
| #         self._closed.set() | ||||
| #         self._relay_task.cancel() | ||||
|  | ||||
| #     def is_closed(self) -> bool: | ||||
| #         """ | ||||
| #         Closed flag | ||||
| #         """ | ||||
| #         return self._closed.is_set() | ||||
|  | ||||
| #     def set_subscriptions(self, subscriptions: List[str] = []) -> None: | ||||
| #         """ | ||||
| #         Set which subscriptions this channel is subscribed to | ||||
|  | ||||
| #         :param subscriptions: List of subscriptions, List[str] | ||||
| #         """ | ||||
| #         self._subscriptions = subscriptions | ||||
|  | ||||
| #     def subscribed_to(self, message_type: str) -> bool: | ||||
| #         """ | ||||
| #         Check if this channel is subscribed to the message_type | ||||
|  | ||||
| #         :param message_type: The message type to check | ||||
| #         """ | ||||
| #         return message_type in self._subscriptions | ||||
|  | ||||
| #     async def relay(self): | ||||
| #         """ | ||||
| #         Relay messages from the channel's queue and send them out. This is started | ||||
| #         as a task. | ||||
| #         """ | ||||
| #         while not self._closed.is_set(): | ||||
| #             message = await self.queue.get() | ||||
| #             try: | ||||
| #                 await self._send(message) | ||||
| #                 self.queue.task_done() | ||||
|  | ||||
| #                 # Limit messages per sec. | ||||
| #                 # Could cause problems with queue size if too low, and | ||||
| #                 # problems with network traffik if too high. | ||||
| #                 # 0.01 = 100/s | ||||
| #                 await asyncio.sleep(self.throttle) | ||||
| #             except RuntimeError: | ||||
| #                 # The connection was closed, just exit the task | ||||
| #                 return | ||||
|  | ||||
|  | ||||
| # class ChannelManager: | ||||
| #     def __init__(self): | ||||
| #         self.channels = dict() | ||||
| #         self._lock = RLock()  # Re-entrant Lock | ||||
|  | ||||
| #     async def on_connect(self, websocket: WebSocketType): | ||||
| #         """ | ||||
| #         Wrap websocket connection into Channel and add to list | ||||
|  | ||||
| #         :param websocket: The WebSocket object to attach to the Channel | ||||
| #         """ | ||||
| #         if isinstance(websocket, FastAPIWebSocket): | ||||
| #             try: | ||||
| #                 await websocket.accept() | ||||
| #             except RuntimeError: | ||||
| #                 # The connection was closed before we could accept it | ||||
| #                 return | ||||
|  | ||||
| #         ws_channel = WebSocketChannel(websocket) | ||||
|  | ||||
| #         with self._lock: | ||||
| #             self.channels[websocket] = ws_channel | ||||
|  | ||||
| #         return ws_channel | ||||
|  | ||||
| #     async def on_disconnect(self, websocket: WebSocketType): | ||||
| #         """ | ||||
| #         Call close on the channel if it's not, and remove from channel list | ||||
|  | ||||
| #         :param websocket: The WebSocket objet attached to the Channel | ||||
| #         """ | ||||
| #         with self._lock: | ||||
| #             channel = self.channels.get(websocket) | ||||
| #             if channel: | ||||
| #                 logger.info(f"Disconnecting channel {channel}") | ||||
| #                 if not channel.is_closed(): | ||||
| #                     await channel.close() | ||||
|  | ||||
| #                 del self.channels[websocket] | ||||
|  | ||||
| #     async def disconnect_all(self): | ||||
| #         """ | ||||
| #         Disconnect all Channels | ||||
| #         """ | ||||
| #         with self._lock: | ||||
| #             for websocket in self.channels.copy().keys(): | ||||
| #                 await self.on_disconnect(websocket) | ||||
|  | ||||
| #     async def broadcast(self, message: WSMessageSchemaType): | ||||
| #         """ | ||||
| #         Broadcast a message on all Channels | ||||
|  | ||||
| #         :param message: The message to send | ||||
| #         """ | ||||
| #         with self._lock: | ||||
| #             for channel in self.channels.copy().values(): | ||||
| #                 if channel.subscribed_to(message.get('type')): | ||||
| #                     await self.send_direct(channel, message) | ||||
|  | ||||
| #     async def send_direct( | ||||
| #             self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
| #         """ | ||||
| #         Send a message directly through direct_channel only | ||||
|  | ||||
| #         :param direct_channel: The WebSocketChannel object to send the message through | ||||
| #         :param message: The message to send | ||||
| #         """ | ||||
| #         if not await channel.send(message): | ||||
| #             await self.on_disconnect(channel.raw_websocket) | ||||
|  | ||||
| #     def has_channels(self): | ||||
| #         """ | ||||
| #         Flag for more than 0 channels | ||||
| #         """ | ||||
| #         return len(self.channels) > 0 | ||||
|   | ||||
							
								
								
									
										23
									
								
								freqtrade/rpc/api_server/ws/message_stream.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								freqtrade/rpc/api_server/ws/message_stream.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| import asyncio | ||||
|  | ||||
|  | ||||
| class MessageStream: | ||||
|     """ | ||||
|     A message stream for consumers to subscribe to, | ||||
|     and for producers to publish to. | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         self._loop = asyncio.get_running_loop() | ||||
|         self._waiter = self._loop.create_future() | ||||
|  | ||||
|     def publish(self, message): | ||||
|         waiter, self._waiter = self._waiter, self._loop.create_future() | ||||
|         waiter.set_result((message, self._waiter)) | ||||
|  | ||||
|     async def subscribe(self): | ||||
|         waiter = self._waiter | ||||
|         while True: | ||||
|             message, waiter = await waiter | ||||
|             yield message | ||||
|  | ||||
|     __aiter__ = subscribe | ||||
| @@ -1,5 +1,6 @@ | ||||
| import logging | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Any, Dict, Union | ||||
|  | ||||
| import orjson | ||||
| import rapidjson | ||||
| @@ -7,6 +8,7 @@ from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.misc import dataframe_to_json, json_to_dataframe | ||||
| from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy | ||||
| from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| @@ -24,7 +26,7 @@ class WebSocketSerializer(ABC): | ||||
|     def _deserialize(self, data): | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def send(self, data: bytes): | ||||
|     async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]): | ||||
|         await self._websocket.send(self._serialize(data)) | ||||
|  | ||||
|     async def recv(self) -> bytes: | ||||
| @@ -32,8 +34,8 @@ class WebSocketSerializer(ABC): | ||||
|  | ||||
|         return self._deserialize(data) | ||||
|  | ||||
|     async def close(self, code: int = 1000): | ||||
|         await self._websocket.close(code) | ||||
|     # async def close(self, code: int = 1000): | ||||
|     #     await self._websocket.close(code) | ||||
|  | ||||
|  | ||||
| class HybridJSONWebSocketSerializer(WebSocketSerializer): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user