initial revision

This commit is contained in:
Timothy Pogue
2022-11-14 20:27:45 -07:00
parent a951b49541
commit 659c8c237f
7 changed files with 494 additions and 241 deletions

View File

@@ -3,4 +3,5 @@
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream

View File

@@ -1,12 +1,9 @@
import asyncio
import logging
import time
from threading import RLock
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer)
@@ -21,32 +18,21 @@ class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
"""
def __init__(
self,
websocket: WebSocketType,
channel_id: Optional[str] = None,
drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
):
self.channel_id = channel_id if channel_id else uuid4().hex[:8]
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = []
# 32 is the size of the receiving queue in websockets package
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = asyncio.Event()
# Throttle how fast we send messages
self._throttle = 0.01
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket)
@@ -61,40 +47,16 @@ class WebSocketChannel:
def remote_addr(self):
return self._websocket.remote_addr
async def _send(self, data):
async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]):
"""
Send data on the wrapped websocket
Send a message on the wrapped websocket
"""
await self._wrapped_ws.send(data)
async def send(self, data) -> bool:
"""
Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
"""
# This block only runs if the queue is full, it will wait
# until self.drain_timeout for the relay to drain the outgoing queue
# We can't use asyncio.wait_for here because the queue may have been created with a
# different eventloop
start = time.time()
while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False
# If for some reason the queue is still full, just return False
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
# If we got here everything is ok
return True
await asyncio.sleep(self._throttle)
await self._wrapped_ws.send(message)
async def recv(self):
"""
Receive data on the wrapped websocket
Receive a message on the wrapped websocket
"""
return await self._wrapped_ws.recv()
@@ -104,18 +66,23 @@ class WebSocketChannel:
"""
return await self._websocket.ping()
async def accept(self):
"""
Accept the underlying websocket connection
"""
return await self._websocket.accept()
async def close(self):
"""
Close the WebSocketChannel
"""
try:
await self.raw_websocket.close()
await self._websocket.close()
except Exception:
pass
self._closed.set()
self._relay_task.cancel()
def is_closed(self) -> bool:
"""
@@ -139,99 +106,243 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def relay(self):
async def __aiter__(self):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
Generator for received messages
"""
while not self._closed.is_set():
message = await self.queue.get()
while True:
try:
await self._send(message)
self.queue.task_done()
yield await self.recv()
except Exception:
break
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
# 0.01 = 100/s
await asyncio.sleep(self.throttle)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
@asynccontextmanager
async def connect(self):
"""
Wrap websocket connection into Channel and add to list
:param websocket: The WebSocket object to attach to the Channel
Context manager for safely opening and closing the websocket connection
"""
if isinstance(websocket, FastAPIWebSocket):
try:
await websocket.accept()
except RuntimeError:
# The connection was closed before we could accept it
return
try:
await self.accept()
yield self
finally:
await self.close()
ws_channel = WebSocketChannel(websocket)
with self._lock:
self.channels[websocket] = ws_channel
# class WebSocketChannel:
# """
# Object to help facilitate managing a websocket connection
# """
return ws_channel
# def __init__(
# self,
# websocket: WebSocketType,
# channel_id: Optional[str] = None,
# drain_timeout: int = 3,
# throttle: float = 0.01,
# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
# ):
async def on_disconnect(self, websocket: WebSocketType):
"""
Call close on the channel if it's not, and remove from channel list
# self.channel_id = channel_id if channel_id else uuid4().hex[:8]
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed():
await channel.close()
# # The WebSocket object
# self._websocket = WebSocketProxy(websocket)
del self.channels[websocket]
# self.drain_timeout = drain_timeout
# self.throttle = throttle
async def disconnect_all(self):
"""
Disconnect all Channels
"""
with self._lock:
for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket)
# self._subscriptions: List[str] = []
# # 32 is the size of the receiving queue in websockets package
# self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
# self._relay_task = asyncio.create_task(self.relay())
async def broadcast(self, message: WSMessageSchemaType):
"""
Broadcast a message on all Channels
# # Internal event to signify a closed websocket
# self._closed = asyncio.Event()
:param message: The message to send
"""
with self._lock:
for channel in self.channels.copy().values():
if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message)
# # Wrap the WebSocket in the Serializing class
# self._wrapped_ws = serializer_cls(self._websocket)
async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
"""
Send a message directly through direct_channel only
# def __repr__(self):
# return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
:param direct_channel: The WebSocketChannel object to send the message through
:param message: The message to send
"""
if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
# @property
# def raw_websocket(self):
# return self._websocket.raw_websocket
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0
# @property
# def remote_addr(self):
# return self._websocket.remote_addr
# async def _send(self, data):
# """
# Send data on the wrapped websocket
# """
# await self._wrapped_ws.send(data)
# async def send(self, data) -> bool:
# """
# Add the data to the queue to be sent.
# :returns: True if data added to queue, False otherwise
# """
# # This block only runs if the queue is full, it will wait
# # until self.drain_timeout for the relay to drain the outgoing queue
# # We can't use asyncio.wait_for here because the queue may have been created with a
# # different eventloop
# start = time.time()
# while self.queue.full():
# await asyncio.sleep(1)
# if (time.time() - start) > self.drain_timeout:
# return False
# # If for some reason the queue is still full, just return False
# try:
# self.queue.put_nowait(data)
# except asyncio.QueueFull:
# return False
# # If we got here everything is ok
# return True
# async def recv(self):
# """
# Receive data on the wrapped websocket
# """
# return await self._wrapped_ws.recv()
# async def ping(self):
# """
# Ping the websocket
# """
# return await self._websocket.ping()
# async def close(self):
# """
# Close the WebSocketChannel
# """
# try:
# await self.raw_websocket.close()
# except Exception:
# pass
# self._closed.set()
# self._relay_task.cancel()
# def is_closed(self) -> bool:
# """
# Closed flag
# """
# return self._closed.is_set()
# def set_subscriptions(self, subscriptions: List[str] = []) -> None:
# """
# Set which subscriptions this channel is subscribed to
# :param subscriptions: List of subscriptions, List[str]
# """
# self._subscriptions = subscriptions
# def subscribed_to(self, message_type: str) -> bool:
# """
# Check if this channel is subscribed to the message_type
# :param message_type: The message type to check
# """
# return message_type in self._subscriptions
# async def relay(self):
# """
# Relay messages from the channel's queue and send them out. This is started
# as a task.
# """
# while not self._closed.is_set():
# message = await self.queue.get()
# try:
# await self._send(message)
# self.queue.task_done()
# # Limit messages per sec.
# # Could cause problems with queue size if too low, and
# # problems with network traffik if too high.
# # 0.01 = 100/s
# await asyncio.sleep(self.throttle)
# except RuntimeError:
# # The connection was closed, just exit the task
# return
# class ChannelManager:
# def __init__(self):
# self.channels = dict()
# self._lock = RLock() # Re-entrant Lock
# async def on_connect(self, websocket: WebSocketType):
# """
# Wrap websocket connection into Channel and add to list
# :param websocket: The WebSocket object to attach to the Channel
# """
# if isinstance(websocket, FastAPIWebSocket):
# try:
# await websocket.accept()
# except RuntimeError:
# # The connection was closed before we could accept it
# return
# ws_channel = WebSocketChannel(websocket)
# with self._lock:
# self.channels[websocket] = ws_channel
# return ws_channel
# async def on_disconnect(self, websocket: WebSocketType):
# """
# Call close on the channel if it's not, and remove from channel list
# :param websocket: The WebSocket objet attached to the Channel
# """
# with self._lock:
# channel = self.channels.get(websocket)
# if channel:
# logger.info(f"Disconnecting channel {channel}")
# if not channel.is_closed():
# await channel.close()
# del self.channels[websocket]
# async def disconnect_all(self):
# """
# Disconnect all Channels
# """
# with self._lock:
# for websocket in self.channels.copy().keys():
# await self.on_disconnect(websocket)
# async def broadcast(self, message: WSMessageSchemaType):
# """
# Broadcast a message on all Channels
# :param message: The message to send
# """
# with self._lock:
# for channel in self.channels.copy().values():
# if channel.subscribed_to(message.get('type')):
# await self.send_direct(channel, message)
# async def send_direct(
# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
# """
# Send a message directly through direct_channel only
# :param direct_channel: The WebSocketChannel object to send the message through
# :param message: The message to send
# """
# if not await channel.send(message):
# await self.on_disconnect(channel.raw_websocket)
# def has_channels(self):
# """
# Flag for more than 0 channels
# """
# return len(self.channels) > 0

View File

@@ -0,0 +1,23 @@
import asyncio
class MessageStream:
"""
A message stream for consumers to subscribe to,
and for producers to publish to.
"""
def __init__(self):
self._loop = asyncio.get_running_loop()
self._waiter = self._loop.create_future()
def publish(self, message):
waiter, self._waiter = self._waiter, self._loop.create_future()
waiter.set_result((message, self._waiter))
async def subscribe(self):
waiter = self._waiter
while True:
message, waiter = await waiter
yield message
__aiter__ = subscribe

View File

@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
import orjson
import rapidjson
@@ -7,6 +8,7 @@ from pandas import DataFrame
from freqtrade.misc import dataframe_to_json, json_to_dataframe
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ class WebSocketSerializer(ABC):
def _deserialize(self, data):
raise NotImplementedError()
async def send(self, data: bytes):
async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes:
@@ -32,8 +34,8 @@ class WebSocketSerializer(ABC):
return self._deserialize(data)
async def close(self, code: int = 1000):
await self._websocket.close(code)
# async def close(self, code: int = 1000):
# await self._websocket.close(code)
class HybridJSONWebSocketSerializer(WebSocketSerializer):