initial revision

This commit is contained in:
Timothy Pogue 2022-11-14 20:27:45 -07:00
parent a951b49541
commit 659c8c237f
7 changed files with 494 additions and 241 deletions

View File

@ -1,16 +1,17 @@
import asyncio
import logging import logging
from typing import Any, Dict from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket, WebSocketState from fastapi.websockets import WebSocket, WebSocketDisconnect
from pydantic import ValidationError from pydantic import ValidationError
from websockets.exceptions import WebSocketException from websockets.exceptions import ConnectionClosed
from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage) WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC from freqtrade.rpc.rpc import RPC
@ -22,23 +23,63 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
async def is_websocket_alive(ws: WebSocket) -> bool: # async def is_websocket_alive(ws: WebSocket) -> bool:
# """
# Check if a FastAPI Websocket is still open
# """
# if (
# ws.application_state == WebSocketState.CONNECTED and
# ws.client_state == WebSocketState.CONNECTED
# ):
# return True
# return False
class WebSocketChannelClosed(Exception):
""" """
Check if a FastAPI Websocket is still open General WebSocket exception to signal closing the channel
""" """
if ( pass
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
"""
Iterate over the messages from the channel and process the request
"""
try:
async for message in channel:
await _process_consumer_request(message, channel, rpc)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
): ):
return True raise WebSocketChannelClosed
return False except asyncio.CancelledError:
return
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
"""
Iterate over messages in the message stream and send them
"""
try:
async for message in message_stream:
await channel.send(message)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def _process_consumer_request( async def _process_consumer_request(
request: Dict[str, Any], request: Dict[str, Any],
channel: WebSocketChannel, channel: WebSocketChannel,
rpc: RPC, rpc: RPC
channel_manager: ChannelManager
): ):
""" """
Validate and handle a request from a websocket consumer Validate and handle a request from a websocket consumer
@ -75,7 +116,7 @@ async def _process_consumer_request(
# Format response # Format response
response = WSWhitelistMessage(data=whitelist) response = WSWhitelistMessage(data=whitelist)
# Send it back # Send it back
await channel_manager.send_direct(channel, response.dict(exclude_none=True)) await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF: elif type == RPCRequestType.ANALYZED_DF:
limit = None limit = None
@ -86,53 +127,76 @@ async def _process_consumer_request(
# For every pair in the generator, send a separate message # For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit): for message in rpc._ws_request_analyzed_df(limit):
# Format response
response = WSAnalyzedDFMessage(data=message) response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response.dict(exclude_none=True)) await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws") @router.websocket("/message/ws")
async def message_endpoint( async def message_endpoint(
ws: WebSocket, websocket: WebSocket,
token: str = Depends(validate_ws_token),
rpc: RPC = Depends(get_rpc), rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager), message_stream: MessageStream = Depends(get_message_stream)
token: str = Depends(validate_ws_token)
): ):
""" async with WebSocketChannel(websocket).connect() as channel:
Message WebSocket endpoint, facilitates sending RPC messages try:
""" logger.info(f"Channel connected - {channel}")
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}") channel_tasks = asyncio.gather(
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)
await channel_tasks
# Keep connection open until explicitly closed, and process requests finally:
try: logger.info(f"Channel disconnected - {channel}")
while not channel.is_closed(): channel_tasks.cancel()
request = await channel.recv()
# Process the request here
await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException): # @router.websocket("/message/ws")
# Handle client disconnects # async def message_endpoint(
logger.info(f"Consumer disconnected - {channel}") # ws: WebSocket,
except RuntimeError: # rpc: RPC = Depends(get_rpc),
# Handle cases like - # channel_manager=Depends(get_channel_manager),
# RuntimeError('Cannot call "send" once a closed message has been sent') # token: str = Depends(validate_ws_token)
pass # ):
except Exception as e: # """
logger.info(f"Consumer connection failed - {channel}: {e}") # Message WebSocket endpoint, facilitates sending RPC messages
logger.debug(e, exc_info=e) # """
# try:
# channel = await channel_manager.on_connect(ws)
# if await is_websocket_alive(ws):
except RuntimeError: # logger.info(f"Consumer connected - {channel}")
# WebSocket was closed
# Do nothing # # Keep connection open until explicitly closed, and process requests
pass # try:
except Exception as e: # while not channel.is_closed():
logger.error(f"Failed to serve - {ws.client}") # request = await channel.recv()
# Log tracebacks to keep track of what errors are happening
logger.exception(e) # # Process the request here
finally: # await _process_consumer_request(request, channel, rpc, channel_manager)
if channel:
await channel_manager.on_disconnect(ws) # except (WebSocketDisconnect, WebSocketException):
# # Handle client disconnects
# logger.info(f"Consumer disconnected - {channel}")
# except RuntimeError:
# # Handle cases like -
# # RuntimeError('Cannot call "send" once a closed message has been sent')
# pass
# except Exception as e:
# logger.info(f"Consumer connection failed - {channel}: {e}")
# logger.debug(e, exc_info=e)
# except RuntimeError:
# # WebSocket was closed
# # Do nothing
# pass
# except Exception as e:
# logger.error(f"Failed to serve - {ws.client}")
# # Log tracebacks to keep track of what errors are happening
# logger.exception(e)
# finally:
# if channel:
# await channel_manager.on_disconnect(ws)

View File

@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
return ApiServer._exchange return ApiServer._exchange
def get_channel_manager(): def get_message_stream():
return ApiServer._ws_channel_manager return ApiServer._message_stream
def is_webserver_mode(config=Depends(get_config)): def is_webserver_mode(config=Depends(get_config)):

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import logging import logging
from ipaddress import IPv4Address from ipaddress import IPv4Address
from threading import Thread
from typing import Any, Dict from typing import Any, Dict
import orjson import orjson
@ -15,7 +14,7 @@ from starlette.responses import JSONResponse
from freqtrade.constants import Config from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -51,9 +50,10 @@ class ApiServer(RPCHandler):
# Exchange - only available in webserver mode. # Exchange - only available in webserver mode.
_exchange = None _exchange = None
# websocket message queue stuff # websocket message queue stuff
_ws_channel_manager = None # _ws_channel_manager = None
_ws_thread = None # _ws_thread = None
_ws_loop = None # _ws_loop = None
_message_stream = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
@ -71,14 +71,15 @@ class ApiServer(RPCHandler):
return return
self._standalone: bool = standalone self._standalone: bool = standalone
self._server = None self._server = None
self._ws_queue = None self._ws_queue = None
self._ws_background_task = None self._ws_publisher_task = None
ApiServer.__initialized = True ApiServer.__initialized = True
api_config = self._config['api_server'] api_config = self._config['api_server']
ApiServer._ws_channel_manager = ChannelManager() # ApiServer._ws_channel_manager = ChannelManager()
self.app = FastAPI(title="Freqtrade API", self.app = FastAPI(title="Freqtrade API",
docs_url='/docs' if api_config.get('enable_openapi', False) else None, docs_url='/docs' if api_config.get('enable_openapi', False) else None,
@ -107,18 +108,18 @@ class ApiServer(RPCHandler):
logger.info("Stopping API Server") logger.info("Stopping API Server")
self._server.cleanup() self._server.cleanup()
if self._ws_thread and self._ws_loop: # if self._ws_thread and self._ws_loop:
logger.info("Stopping API Server background tasks") # logger.info("Stopping API Server background tasks")
if self._ws_background_task: # if self._ws_background_task:
# Cancel the queue task # # Cancel the queue task
self._ws_background_task.cancel() # self._ws_background_task.cancel()
self._ws_thread.join() # self._ws_thread.join()
self._ws_thread = None # self._ws_thread = None
self._ws_loop = None # self._ws_loop = None
self._ws_background_task = None # self._ws_background_task = None
@classmethod @classmethod
def shutdown(cls): def shutdown(cls):
@ -170,51 +171,102 @@ class ApiServer(RPCHandler):
) )
app.add_exception_handler(RPCException, self.handle_rpc_exception) app.add_exception_handler(RPCException, self.handle_rpc_exception)
app.add_event_handler(
event_type="startup",
func=self._api_startup_event
)
app.add_event_handler(
event_type="shutdown",
func=self._api_shutdown_event
)
def start_message_queue(self): async def _api_startup_event(self):
if self._ws_thread: if not ApiServer._message_stream:
return ApiServer._message_stream = MessageStream()
# Create a new loop, as it'll be just for the background thread if not self._ws_queue:
self._ws_loop = asyncio.new_event_loop() self._ws_queue = ThreadedQueue()
# Start the thread if not self._ws_publisher_task:
self._ws_thread = Thread(target=self._ws_loop.run_forever) self._ws_publisher_task = asyncio.create_task(
self._ws_thread.start() self._publish_messages()
)
# Finally, submit the coro to the thread async def _api_shutdown_event(self):
self._ws_background_task = asyncio.run_coroutine_threadsafe( if ApiServer._message_stream:
self._broadcast_queue_data(), loop=self._ws_loop) ApiServer._message_stream = None
async def _broadcast_queue_data(self): if self._ws_queue:
# Instantiate the queue in this coroutine so it's attached to our loop
self._ws_queue = ThreadedQueue()
async_queue = self._ws_queue.async_q
try:
while True:
logger.debug("Getting queue messages...")
# Get data from queue
message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}")
async_queue.task_done()
# Broadcast it
await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError:
pass
# For testing, shouldn't happen when stable
except Exception as e:
logger.exception(f"Exception happened in background task: {e}")
finally:
# Disconnect channels and stop the loop on cancel
await self._ws_channel_manager.disconnect_all()
self._ws_loop.stop()
# Avoid adding more items to the queue if they aren't
# going to get broadcasted.
self._ws_queue = None self._ws_queue = None
if self._ws_publisher_task:
self._ws_publisher_task.cancel()
async def _publish_messages(self):
"""
Background task that reads messages from the queue and adds them
to the message stream
"""
try:
async_queue = self._ws_queue.async_q
message_stream = ApiServer._message_stream
while message_stream:
message: WSMessageSchemaType = await async_queue.get()
message_stream.publish(message)
# Make sure to throttle how fast we
# publish messages as some clients will be
# slower than others
await asyncio.sleep(0.01)
async_queue.task_done()
finally:
self._ws_queue = None
# def start_message_queue(self):
# if self._ws_thread:
# return
# # Create a new loop, as it'll be just for the background thread
# self._ws_loop = asyncio.new_event_loop()
# # Start the thread
# self._ws_thread = Thread(target=self._ws_loop.run_forever)
# self._ws_thread.start()
# # Finally, submit the coro to the thread
# self._ws_background_task = asyncio.run_coroutine_threadsafe(
# self._broadcast_queue_data(), loop=self._ws_loop)
# async def _broadcast_queue_data(self):
# # Instantiate the queue in this coroutine so it's attached to our loop
# self._ws_queue = ThreadedQueue()
# async_queue = self._ws_queue.async_q
# try:
# while True:
# logger.debug("Getting queue messages...")
# # Get data from queue
# message: WSMessageSchemaType = await async_queue.get()
# logger.debug(f"Found message of type: {message.get('type')}")
# async_queue.task_done()
# # Broadcast it
# await self._ws_channel_manager.broadcast(message)
# except asyncio.CancelledError:
# pass
# # For testing, shouldn't happen when stable
# except Exception as e:
# logger.exception(f"Exception happened in background task: {e}")
# finally:
# # Disconnect channels and stop the loop on cancel
# await self._ws_channel_manager.disconnect_all()
# self._ws_loop.stop()
# # Avoid adding more items to the queue if they aren't
# # going to get broadcasted.
# self._ws_queue = None
def start_api(self): def start_api(self):
""" """
Start API ... should be run in thread. Start API ... should be run in thread.
@ -253,7 +305,7 @@ class ApiServer(RPCHandler):
if self._standalone: if self._standalone:
self._server.run() self._server.run()
else: else:
self.start_message_queue() # self.start_message_queue()
self._server.run_in_thread() self._server.run_in_thread()
except Exception: except Exception:
logger.exception("Api server failed to start.") logger.exception("Api server failed to start.")

View File

@ -3,4 +3,5 @@
from freqtrade.rpc.api_server.ws.types import WebSocketType from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream

View File

@ -1,12 +1,9 @@
import asyncio import asyncio
import logging import logging
import time from contextlib import asynccontextmanager
from threading import RLock
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer) WebSocketSerializer)
@ -21,32 +18,21 @@ class WebSocketChannel:
""" """
Object to help facilitate managing a websocket connection Object to help facilitate managing a websocket connection
""" """
def __init__( def __init__(
self, self,
websocket: WebSocketType, websocket: WebSocketType,
channel_id: Optional[str] = None, channel_id: Optional[str] = None,
drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
): ):
self.channel_id = channel_id if channel_id else uuid4().hex[:8] self.channel_id = channel_id if channel_id else uuid4().hex[:8]
# The WebSocket object
self._websocket = WebSocketProxy(websocket) self._websocket = WebSocketProxy(websocket)
self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = []
# 32 is the size of the receiving queue in websockets package
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket # Internal event to signify a closed websocket
self._closed = asyncio.Event() self._closed = asyncio.Event()
# Throttle how fast we send messages
self._throttle = 0.01
# Wrap the WebSocket in the Serializing class # Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket) self._wrapped_ws = serializer_cls(self._websocket)
@ -61,40 +47,16 @@ class WebSocketChannel:
def remote_addr(self): def remote_addr(self):
return self._websocket.remote_addr return self._websocket.remote_addr
async def _send(self, data): async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]):
""" """
Send data on the wrapped websocket Send a message on the wrapped websocket
""" """
await self._wrapped_ws.send(data) await asyncio.sleep(self._throttle)
await self._wrapped_ws.send(message)
async def send(self, data) -> bool:
"""
Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
"""
# This block only runs if the queue is full, it will wait
# until self.drain_timeout for the relay to drain the outgoing queue
# We can't use asyncio.wait_for here because the queue may have been created with a
# different eventloop
start = time.time()
while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False
# If for some reason the queue is still full, just return False
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
# If we got here everything is ok
return True
async def recv(self): async def recv(self):
""" """
Receive data on the wrapped websocket Receive a message on the wrapped websocket
""" """
return await self._wrapped_ws.recv() return await self._wrapped_ws.recv()
@ -104,18 +66,23 @@ class WebSocketChannel:
""" """
return await self._websocket.ping() return await self._websocket.ping()
async def accept(self):
"""
Accept the underlying websocket connection
"""
return await self._websocket.accept()
async def close(self): async def close(self):
""" """
Close the WebSocketChannel Close the WebSocketChannel
""" """
try: try:
await self.raw_websocket.close() await self._websocket.close()
except Exception: except Exception:
pass pass
self._closed.set() self._closed.set()
self._relay_task.cancel()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
@ -139,99 +106,243 @@ class WebSocketChannel:
""" """
return message_type in self._subscriptions return message_type in self._subscriptions
async def relay(self): async def __aiter__(self):
""" """
Relay messages from the channel's queue and send them out. This is started Generator for received messages
as a task.
""" """
while not self._closed.is_set(): while True:
message = await self.queue.get()
try: try:
await self._send(message) yield await self.recv()
self.queue.task_done() except Exception:
break
# Limit messages per sec. @asynccontextmanager
# Could cause problems with queue size if too low, and async def connect(self):
# problems with network traffik if too high.
# 0.01 = 100/s
await asyncio.sleep(self.throttle)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
""" """
Wrap websocket connection into Channel and add to list Context manager for safely opening and closing the websocket connection
:param websocket: The WebSocket object to attach to the Channel
""" """
if isinstance(websocket, FastAPIWebSocket): try:
try: await self.accept()
await websocket.accept() yield self
except RuntimeError: finally:
# The connection was closed before we could accept it await self.close()
return
ws_channel = WebSocketChannel(websocket)
with self._lock: # class WebSocketChannel:
self.channels[websocket] = ws_channel # """
# Object to help facilitate managing a websocket connection
# """
return ws_channel # def __init__(
# self,
# websocket: WebSocketType,
# channel_id: Optional[str] = None,
# drain_timeout: int = 3,
# throttle: float = 0.01,
# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
# ):
async def on_disconnect(self, websocket: WebSocketType): # self.channel_id = channel_id if channel_id else uuid4().hex[:8]
"""
Call close on the channel if it's not, and remove from channel list
:param websocket: The WebSocket objet attached to the Channel # # The WebSocket object
""" # self._websocket = WebSocketProxy(websocket)
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed():
await channel.close()
del self.channels[websocket] # self.drain_timeout = drain_timeout
# self.throttle = throttle
async def disconnect_all(self): # self._subscriptions: List[str] = []
""" # # 32 is the size of the receiving queue in websockets package
Disconnect all Channels # self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
""" # self._relay_task = asyncio.create_task(self.relay())
with self._lock:
for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket)
async def broadcast(self, message: WSMessageSchemaType): # # Internal event to signify a closed websocket
""" # self._closed = asyncio.Event()
Broadcast a message on all Channels
:param message: The message to send # # Wrap the WebSocket in the Serializing class
""" # self._wrapped_ws = serializer_cls(self._websocket)
with self._lock:
for channel in self.channels.copy().values():
if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message)
async def send_direct( # def __repr__(self):
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]): # return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
"""
Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send the message through # @property
:param message: The message to send # def raw_websocket(self):
""" # return self._websocket.raw_websocket
if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
def has_channels(self): # @property
""" # def remote_addr(self):
Flag for more than 0 channels # return self._websocket.remote_addr
"""
return len(self.channels) > 0 # async def _send(self, data):
# """
# Send data on the wrapped websocket
# """
# await self._wrapped_ws.send(data)
# async def send(self, data) -> bool:
# """
# Add the data to the queue to be sent.
# :returns: True if data added to queue, False otherwise
# """
# # This block only runs if the queue is full, it will wait
# # until self.drain_timeout for the relay to drain the outgoing queue
# # We can't use asyncio.wait_for here because the queue may have been created with a
# # different eventloop
# start = time.time()
# while self.queue.full():
# await asyncio.sleep(1)
# if (time.time() - start) > self.drain_timeout:
# return False
# # If for some reason the queue is still full, just return False
# try:
# self.queue.put_nowait(data)
# except asyncio.QueueFull:
# return False
# # If we got here everything is ok
# return True
# async def recv(self):
# """
# Receive data on the wrapped websocket
# """
# return await self._wrapped_ws.recv()
# async def ping(self):
# """
# Ping the websocket
# """
# return await self._websocket.ping()
# async def close(self):
# """
# Close the WebSocketChannel
# """
# try:
# await self.raw_websocket.close()
# except Exception:
# pass
# self._closed.set()
# self._relay_task.cancel()
# def is_closed(self) -> bool:
# """
# Closed flag
# """
# return self._closed.is_set()
# def set_subscriptions(self, subscriptions: List[str] = []) -> None:
# """
# Set which subscriptions this channel is subscribed to
# :param subscriptions: List of subscriptions, List[str]
# """
# self._subscriptions = subscriptions
# def subscribed_to(self, message_type: str) -> bool:
# """
# Check if this channel is subscribed to the message_type
# :param message_type: The message type to check
# """
# return message_type in self._subscriptions
# async def relay(self):
# """
# Relay messages from the channel's queue and send them out. This is started
# as a task.
# """
# while not self._closed.is_set():
# message = await self.queue.get()
# try:
# await self._send(message)
# self.queue.task_done()
# # Limit messages per sec.
# # Could cause problems with queue size if too low, and
# # problems with network traffik if too high.
# # 0.01 = 100/s
# await asyncio.sleep(self.throttle)
# except RuntimeError:
# # The connection was closed, just exit the task
# return
# class ChannelManager:
# def __init__(self):
# self.channels = dict()
# self._lock = RLock() # Re-entrant Lock
# async def on_connect(self, websocket: WebSocketType):
# """
# Wrap websocket connection into Channel and add to list
# :param websocket: The WebSocket object to attach to the Channel
# """
# if isinstance(websocket, FastAPIWebSocket):
# try:
# await websocket.accept()
# except RuntimeError:
# # The connection was closed before we could accept it
# return
# ws_channel = WebSocketChannel(websocket)
# with self._lock:
# self.channels[websocket] = ws_channel
# return ws_channel
# async def on_disconnect(self, websocket: WebSocketType):
# """
# Call close on the channel if it's not, and remove from channel list
# :param websocket: The WebSocket objet attached to the Channel
# """
# with self._lock:
# channel = self.channels.get(websocket)
# if channel:
# logger.info(f"Disconnecting channel {channel}")
# if not channel.is_closed():
# await channel.close()
# del self.channels[websocket]
# async def disconnect_all(self):
# """
# Disconnect all Channels
# """
# with self._lock:
# for websocket in self.channels.copy().keys():
# await self.on_disconnect(websocket)
# async def broadcast(self, message: WSMessageSchemaType):
# """
# Broadcast a message on all Channels
# :param message: The message to send
# """
# with self._lock:
# for channel in self.channels.copy().values():
# if channel.subscribed_to(message.get('type')):
# await self.send_direct(channel, message)
# async def send_direct(
# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
# """
# Send a message directly through direct_channel only
# :param direct_channel: The WebSocketChannel object to send the message through
# :param message: The message to send
# """
# if not await channel.send(message):
# await self.on_disconnect(channel.raw_websocket)
# def has_channels(self):
# """
# Flag for more than 0 channels
# """
# return len(self.channels) > 0

View File

@ -0,0 +1,23 @@
import asyncio
class MessageStream:
"""
A message stream for consumers to subscribe to,
and for producers to publish to.
"""
def __init__(self):
self._loop = asyncio.get_running_loop()
self._waiter = self._loop.create_future()
def publish(self, message):
waiter, self._waiter = self._waiter, self._loop.create_future()
waiter.set_result((message, self._waiter))
async def subscribe(self):
waiter = self._waiter
while True:
message, waiter = await waiter
yield message
__aiter__ = subscribe

View File

@ -1,5 +1,6 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Union
import orjson import orjson
import rapidjson import rapidjson
@ -7,6 +8,7 @@ from pandas import DataFrame
from freqtrade.misc import dataframe_to_json, json_to_dataframe from freqtrade.misc import dataframe_to_json, json_to_dataframe
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,7 +26,7 @@ class WebSocketSerializer(ABC):
def _deserialize(self, data): def _deserialize(self, data):
raise NotImplementedError() raise NotImplementedError()
async def send(self, data: bytes): async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
await self._websocket.send(self._serialize(data)) await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes: async def recv(self) -> bytes:
@ -32,8 +34,8 @@ class WebSocketSerializer(ABC):
return self._deserialize(data) return self._deserialize(data)
async def close(self, code: int = 1000): # async def close(self, code: int = 1000):
await self._websocket.close(code) # await self._websocket.close(code)
class HybridJSONWebSocketSerializer(WebSocketSerializer): class HybridJSONWebSocketSerializer(WebSocketSerializer):