import logging from threading import RLock from typing import List, Optional, Type from uuid import uuid4 from fastapi import WebSocket as FastAPIWebSocket from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) from freqtrade.rpc.api_server.ws.types import WebSocketType logger = logging.getLogger(__name__) class WebSocketChannel: """ Object to help facilitate managing a websocket connection """ def __init__( self, websocket: WebSocketType, channel_id: Optional[str] = None, serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer ): self.channel_id = channel_id if channel_id else uuid4().hex[:8] # The WebSocket object self._websocket = WebSocketProxy(websocket) # The Serializing class for the WebSocket object self._serializer_cls = serializer_cls self._subscriptions: List[str] = [] # Internal event to signify a closed websocket self._closed = False # Wrap the WebSocket in the Serializing class self._wrapped_ws = self._serializer_cls(self._websocket) def __repr__(self): return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" @property def remote_addr(self): return self._websocket.remote_addr async def send(self, data): """ Send data on the wrapped websocket """ await self._wrapped_ws.send(data) async def recv(self): """ Receive data on the wrapped websocket """ return await self._wrapped_ws.recv() async def ping(self): """ Ping the websocket """ return await self._websocket.ping() async def close(self): """ Close the WebSocketChannel """ self._closed = True def is_closed(self) -> bool: """ Closed flag """ return self._closed def set_subscriptions(self, subscriptions: List[str] = []) -> None: """ Set which subscriptions this channel is subscribed to :param subscriptions: List of subscriptions, List[str] """ self._subscriptions = subscriptions def subscribed_to(self, message_type: str) -> bool: """ Check if this channel is subscribed to the message_type :param message_type: The message type to check """ return message_type in self._subscriptions class ChannelManager: def __init__(self): self.channels = dict() self._lock = RLock() # Re-entrant Lock async def on_connect(self, websocket: WebSocketType): """ Wrap websocket connection into Channel and add to list :param websocket: The WebSocket object to attach to the Channel """ if isinstance(websocket, FastAPIWebSocket): try: await websocket.accept() except RuntimeError: # The connection was closed before we could accept it return ws_channel = WebSocketChannel(websocket) with self._lock: self.channels[websocket] = ws_channel return ws_channel async def on_disconnect(self, websocket: WebSocketType): """ Call close on the channel if it's not, and remove from channel list :param websocket: The WebSocket objet attached to the Channel """ with self._lock: channel = self.channels.get(websocket) if channel: if not channel.is_closed(): await channel.close() del self.channels[websocket] async def disconnect_all(self): """ Disconnect all Channels """ with self._lock: for websocket, channel in self.channels.items(): if not channel.is_closed(): await channel.close() self.channels = dict() async def broadcast(self, data): """ Broadcast data on all Channels :param data: The data to send """ with self._lock: message_type = data.get('type') for websocket, channel in self.channels.items(): try: if channel.subscribed_to(message_type): await channel.send(data) except RuntimeError: # Handle cannot send after close cases await self.on_disconnect(websocket) async def send_direct(self, channel, data): """ Send data directly through direct_channel only :param direct_channel: The WebSocketChannel object to send data through :param data: The data to send """ await channel.send(data) def has_channels(self): """ Flag for more than 0 channels """ return len(self.channels) > 0