Merge pull request #7621 from wizrds/fix/channel-api
Improved WebSocketChannel API
This commit is contained in:
commit
110db8b241
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
@ -11,6 +10,7 @@ from freqtrade.enums import RPCMessageType, RPCRequestType
|
|||||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
||||||
|
from freqtrade.rpc.api_server.ws.channel import ChannelManager
|
||||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||||
WSRequestSchema, WSWhitelistMessage)
|
WSRequestSchema, WSWhitelistMessage)
|
||||||
from freqtrade.rpc.rpc import RPC
|
from freqtrade.rpc.rpc import RPC
|
||||||
@ -37,7 +37,8 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
|
|||||||
async def _process_consumer_request(
|
async def _process_consumer_request(
|
||||||
request: Dict[str, Any],
|
request: Dict[str, Any],
|
||||||
channel: WebSocketChannel,
|
channel: WebSocketChannel,
|
||||||
rpc: RPC
|
rpc: RPC,
|
||||||
|
channel_manager: ChannelManager
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Validate and handle a request from a websocket consumer
|
Validate and handle a request from a websocket consumer
|
||||||
@ -74,7 +75,7 @@ async def _process_consumer_request(
|
|||||||
# Format response
|
# Format response
|
||||||
response = WSWhitelistMessage(data=whitelist)
|
response = WSWhitelistMessage(data=whitelist)
|
||||||
# Send it back
|
# Send it back
|
||||||
await channel.send(response.dict(exclude_none=True))
|
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||||
|
|
||||||
elif type == RPCRequestType.ANALYZED_DF:
|
elif type == RPCRequestType.ANALYZED_DF:
|
||||||
limit = None
|
limit = None
|
||||||
@ -89,9 +90,7 @@ async def _process_consumer_request(
|
|||||||
# For every dataframe, send as a separate message
|
# For every dataframe, send as a separate message
|
||||||
for _, message in analyzed_df.items():
|
for _, message in analyzed_df.items():
|
||||||
response = WSAnalyzedDFMessage(data=message)
|
response = WSAnalyzedDFMessage(data=message)
|
||||||
await channel.send(response.dict(exclude_none=True))
|
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||||
# Throttle the messages to 50/s
|
|
||||||
await asyncio.sleep(0.02)
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/message/ws")
|
@router.websocket("/message/ws")
|
||||||
@ -116,7 +115,7 @@ async def message_endpoint(
|
|||||||
request = await channel.recv()
|
request = await channel.recv()
|
||||||
|
|
||||||
# Process the request here
|
# Process the request here
|
||||||
await _process_consumer_request(request, channel, rpc)
|
await _process_consumer_request(request, channel, rpc, channel_manager)
|
||||||
|
|
||||||
except (WebSocketDisconnect, WebSocketException):
|
except (WebSocketDisconnect, WebSocketException):
|
||||||
# Handle client disconnects
|
# Handle client disconnects
|
||||||
|
@ -16,6 +16,7 @@ from freqtrade.constants import Config
|
|||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
|
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
|
||||||
from freqtrade.rpc.api_server.ws import ChannelManager
|
from freqtrade.rpc.api_server.ws import ChannelManager
|
||||||
|
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||||
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
|
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +128,7 @@ class ApiServer(RPCHandler):
|
|||||||
cls._has_rpc = False
|
cls._has_rpc = False
|
||||||
cls._rpc = None
|
cls._rpc = None
|
||||||
|
|
||||||
def send_msg(self, msg: Dict[str, str]) -> None:
|
def send_msg(self, msg: Dict[str, Any]) -> None:
|
||||||
if self._ws_queue:
|
if self._ws_queue:
|
||||||
sync_q = self._ws_queue.sync_q
|
sync_q = self._ws_queue.sync_q
|
||||||
sync_q.put(msg)
|
sync_q.put(msg)
|
||||||
@ -194,14 +195,10 @@ class ApiServer(RPCHandler):
|
|||||||
while True:
|
while True:
|
||||||
logger.debug("Getting queue messages...")
|
logger.debug("Getting queue messages...")
|
||||||
# Get data from queue
|
# Get data from queue
|
||||||
message = await async_queue.get()
|
message: WSMessageSchemaType = await async_queue.get()
|
||||||
logger.debug(f"Found message of type: {message.get('type')}")
|
logger.debug(f"Found message of type: {message.get('type')}")
|
||||||
# Broadcast it
|
# Broadcast it
|
||||||
await self._ws_channel_manager.broadcast(message)
|
await self._ws_channel_manager.broadcast(message)
|
||||||
# Limit messages per sec.
|
|
||||||
# Could cause problems with queue size if too low, and
|
|
||||||
# problems with network traffik if too high.
|
|
||||||
await asyncio.sleep(0.001)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
@ -10,6 +10,7 @@ from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
|||||||
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
|
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
|
||||||
WebSocketSerializer)
|
WebSocketSerializer)
|
||||||
from freqtrade.rpc.api_server.ws.types import WebSocketType
|
from freqtrade.rpc.api_server.ws.types import WebSocketType
|
||||||
|
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -24,6 +25,8 @@ class WebSocketChannel:
|
|||||||
self,
|
self,
|
||||||
websocket: WebSocketType,
|
websocket: WebSocketType,
|
||||||
channel_id: Optional[str] = None,
|
channel_id: Optional[str] = None,
|
||||||
|
drain_timeout: int = 3,
|
||||||
|
throttle: float = 0.01,
|
||||||
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
|
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
|
||||||
):
|
):
|
||||||
|
|
||||||
@ -34,7 +37,11 @@ class WebSocketChannel:
|
|||||||
# The Serializing class for the WebSocket object
|
# The Serializing class for the WebSocket object
|
||||||
self._serializer_cls = serializer_cls
|
self._serializer_cls = serializer_cls
|
||||||
|
|
||||||
|
self.drain_timeout = drain_timeout
|
||||||
|
self.throttle = throttle
|
||||||
|
|
||||||
self._subscriptions: List[str] = []
|
self._subscriptions: List[str] = []
|
||||||
|
# 32 is the size of the receiving queue in websockets package
|
||||||
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
|
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
|
||||||
self._relay_task = asyncio.create_task(self.relay())
|
self._relay_task = asyncio.create_task(self.relay())
|
||||||
|
|
||||||
@ -47,6 +54,10 @@ class WebSocketChannel:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_websocket(self):
|
||||||
|
return self._websocket.raw_websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def remote_addr(self):
|
def remote_addr(self):
|
||||||
return self._websocket.remote_addr
|
return self._websocket.remote_addr
|
||||||
@ -57,11 +68,19 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
await self._wrapped_ws.send(data)
|
await self._wrapped_ws.send(data)
|
||||||
|
|
||||||
async def send(self, data):
|
async def send(self, data) -> bool:
|
||||||
"""
|
"""
|
||||||
Add the data to the queue to be sent
|
Add the data to the queue to be sent.
|
||||||
|
:returns: True if data added to queue, False otherwise
|
||||||
"""
|
"""
|
||||||
self.queue.put_nowait(data)
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.queue.put(data),
|
||||||
|
timeout=self.drain_timeout
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return False
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
"""
|
"""
|
||||||
@ -119,8 +138,8 @@ class WebSocketChannel:
|
|||||||
# Limit messages per sec.
|
# Limit messages per sec.
|
||||||
# Could cause problems with queue size if too low, and
|
# Could cause problems with queue size if too low, and
|
||||||
# problems with network traffik if too high.
|
# problems with network traffik if too high.
|
||||||
# 0.001 = 1000/s
|
# 0.01 = 100/s
|
||||||
await asyncio.sleep(0.001)
|
await asyncio.sleep(self.throttle)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# The connection was closed, just exit the task
|
# The connection was closed, just exit the task
|
||||||
return
|
return
|
||||||
@ -160,6 +179,7 @@ class ChannelManager:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
channel = self.channels.get(websocket)
|
channel = self.channels.get(websocket)
|
||||||
if channel:
|
if channel:
|
||||||
|
logger.info(f"Disconnecting channel {channel}")
|
||||||
if not channel.is_closed():
|
if not channel.is_closed():
|
||||||
await channel.close()
|
await channel.close()
|
||||||
|
|
||||||
@ -170,36 +190,30 @@ class ChannelManager:
|
|||||||
Disconnect all Channels
|
Disconnect all Channels
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for websocket, channel in self.channels.copy().items():
|
for websocket in self.channels.copy().keys():
|
||||||
if not channel.is_closed():
|
await self.on_disconnect(websocket)
|
||||||
await channel.close()
|
|
||||||
|
|
||||||
self.channels = dict()
|
async def broadcast(self, message: WSMessageSchemaType):
|
||||||
|
|
||||||
async def broadcast(self, data):
|
|
||||||
"""
|
"""
|
||||||
Broadcast data on all Channels
|
Broadcast a message on all Channels
|
||||||
|
|
||||||
:param data: The data to send
|
:param message: The message to send
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
message_type = data.get('type')
|
for channel in self.channels.copy().values():
|
||||||
for websocket, channel in self.channels.copy().items():
|
if channel.subscribed_to(message.get('type')):
|
||||||
if channel.subscribed_to(message_type):
|
await self.send_direct(channel, message)
|
||||||
if not channel.queue.full():
|
|
||||||
await channel.send(data)
|
|
||||||
else:
|
|
||||||
logger.info(f"Channel {channel} is too far behind, disconnecting")
|
|
||||||
await self.on_disconnect(websocket)
|
|
||||||
|
|
||||||
async def send_direct(self, channel, data):
|
async def send_direct(
|
||||||
|
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
Send data directly through direct_channel only
|
Send a message directly through direct_channel only
|
||||||
|
|
||||||
:param direct_channel: The WebSocketChannel object to send data through
|
:param direct_channel: The WebSocketChannel object to send the message through
|
||||||
:param data: The data to send
|
:param message: The message to send
|
||||||
"""
|
"""
|
||||||
await channel.send(data)
|
if not await channel.send(message):
|
||||||
|
await self.on_disconnect(channel.raw_websocket)
|
||||||
|
|
||||||
def has_channels(self):
|
def has_channels(self):
|
||||||
"""
|
"""
|
||||||
|
@ -15,6 +15,10 @@ class WebSocketProxy:
|
|||||||
def __init__(self, websocket: WebSocketType):
|
def __init__(self, websocket: WebSocketType):
|
||||||
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
|
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_websocket(self):
|
||||||
|
return self._websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def remote_addr(self) -> Tuple[Any, ...]:
|
def remote_addr(self) -> Tuple[Any, ...]:
|
||||||
if isinstance(self._websocket, WebSocket):
|
if isinstance(self._websocket, WebSocket):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -18,6 +18,12 @@ class WSRequestSchema(BaseArbitraryModel):
|
|||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WSMessageSchemaType(TypedDict):
|
||||||
|
# Type for typing to avoid doing pydantic typechecks.
|
||||||
|
type: RPCMessageType
|
||||||
|
data: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class WSMessageSchema(BaseArbitraryModel):
|
class WSMessageSchema(BaseArbitraryModel):
|
||||||
type: RPCMessageType
|
type: RPCMessageType
|
||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
|
@ -270,6 +270,11 @@ class ExternalMessageConsumer:
|
|||||||
logger.debug(f"Connection to {channel} still alive...")
|
logger.debug(f"Connection to {channel} still alive...")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
except (websockets.exceptions.ConnectionClosed):
|
||||||
|
# Just eat the error and continue reconnecting
|
||||||
|
logger.warning(f"Disconnection in {channel} - retrying in {self.sleep_time}s")
|
||||||
|
await asyncio.sleep(self.sleep_time)
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
|
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
|
||||||
logger.debug(e, exc_info=e)
|
logger.debug(e, exc_info=e)
|
||||||
|
Loading…
Reference in New Issue
Block a user