initial revision
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
||||
from fastapi.websockets import WebSocket, WebSocketState
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.websockets import WebSocket, WebSocketDisconnect
|
||||
from pydantic import ValidationError
|
||||
from websockets.exceptions import WebSocketException
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
|
||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.channel import ChannelManager
|
||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||
WSRequestSchema, WSWhitelistMessage)
|
||||
from freqtrade.rpc.rpc import RPC
|
||||
@@ -22,23 +23,63 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
# async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
# """
|
||||
# Check if a FastAPI Websocket is still open
|
||||
# """
|
||||
# if (
|
||||
# ws.application_state == WebSocketState.CONNECTED and
|
||||
# ws.client_state == WebSocketState.CONNECTED
|
||||
# ):
|
||||
# return True
|
||||
# return False
|
||||
|
||||
|
||||
class WebSocketChannelClosed(Exception):
|
||||
"""
|
||||
Check if a FastAPI Websocket is still open
|
||||
General WebSocket exception to signal closing the channel
|
||||
"""
|
||||
if (
|
||||
ws.application_state == WebSocketState.CONNECTED and
|
||||
ws.client_state == WebSocketState.CONNECTED
|
||||
pass
|
||||
|
||||
|
||||
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
|
||||
"""
|
||||
Iterate over the messages from the channel and process the request
|
||||
"""
|
||||
try:
|
||||
async for message in channel:
|
||||
await _process_consumer_request(message, channel, rpc)
|
||||
except (
|
||||
RuntimeError,
|
||||
WebSocketDisconnect,
|
||||
ConnectionClosed
|
||||
):
|
||||
return True
|
||||
return False
|
||||
raise WebSocketChannelClosed
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
|
||||
"""
|
||||
Iterate over messages in the message stream and send them
|
||||
"""
|
||||
try:
|
||||
async for message in message_stream:
|
||||
await channel.send(message)
|
||||
except (
|
||||
RuntimeError,
|
||||
WebSocketDisconnect,
|
||||
ConnectionClosed
|
||||
):
|
||||
raise WebSocketChannelClosed
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
async def _process_consumer_request(
|
||||
request: Dict[str, Any],
|
||||
channel: WebSocketChannel,
|
||||
rpc: RPC,
|
||||
channel_manager: ChannelManager
|
||||
rpc: RPC
|
||||
):
|
||||
"""
|
||||
Validate and handle a request from a websocket consumer
|
||||
@@ -75,7 +116,7 @@ async def _process_consumer_request(
|
||||
# Format response
|
||||
response = WSWhitelistMessage(data=whitelist)
|
||||
# Send it back
|
||||
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
elif type == RPCRequestType.ANALYZED_DF:
|
||||
limit = None
|
||||
@@ -86,53 +127,76 @@ async def _process_consumer_request(
|
||||
|
||||
# For every pair in the generator, send a separate message
|
||||
for message in rpc._ws_request_analyzed_df(limit):
|
||||
# Format response
|
||||
response = WSAnalyzedDFMessage(data=message)
|
||||
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
|
||||
@router.websocket("/message/ws")
|
||||
async def message_endpoint(
|
||||
ws: WebSocket,
|
||||
websocket: WebSocket,
|
||||
token: str = Depends(validate_ws_token),
|
||||
rpc: RPC = Depends(get_rpc),
|
||||
channel_manager=Depends(get_channel_manager),
|
||||
token: str = Depends(validate_ws_token)
|
||||
message_stream: MessageStream = Depends(get_message_stream)
|
||||
):
|
||||
"""
|
||||
Message WebSocket endpoint, facilitates sending RPC messages
|
||||
"""
|
||||
try:
|
||||
channel = await channel_manager.on_connect(ws)
|
||||
if await is_websocket_alive(ws):
|
||||
async with WebSocketChannel(websocket).connect() as channel:
|
||||
try:
|
||||
logger.info(f"Channel connected - {channel}")
|
||||
|
||||
logger.info(f"Consumer connected - {channel}")
|
||||
channel_tasks = asyncio.gather(
|
||||
channel_reader(channel, rpc),
|
||||
channel_broadcaster(channel, message_stream)
|
||||
)
|
||||
await channel_tasks
|
||||
|
||||
# Keep connection open until explicitly closed, and process requests
|
||||
try:
|
||||
while not channel.is_closed():
|
||||
request = await channel.recv()
|
||||
finally:
|
||||
logger.info(f"Channel disconnected - {channel}")
|
||||
channel_tasks.cancel()
|
||||
|
||||
# Process the request here
|
||||
await _process_consumer_request(request, channel, rpc, channel_manager)
|
||||
|
||||
except (WebSocketDisconnect, WebSocketException):
|
||||
# Handle client disconnects
|
||||
logger.info(f"Consumer disconnected - {channel}")
|
||||
except RuntimeError:
|
||||
# Handle cases like -
|
||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.info(f"Consumer connection failed - {channel}: {e}")
|
||||
logger.debug(e, exc_info=e)
|
||||
# @router.websocket("/message/ws")
|
||||
# async def message_endpoint(
|
||||
# ws: WebSocket,
|
||||
# rpc: RPC = Depends(get_rpc),
|
||||
# channel_manager=Depends(get_channel_manager),
|
||||
# token: str = Depends(validate_ws_token)
|
||||
# ):
|
||||
# """
|
||||
# Message WebSocket endpoint, facilitates sending RPC messages
|
||||
# """
|
||||
# try:
|
||||
# channel = await channel_manager.on_connect(ws)
|
||||
# if await is_websocket_alive(ws):
|
||||
|
||||
except RuntimeError:
|
||||
# WebSocket was closed
|
||||
# Do nothing
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to serve - {ws.client}")
|
||||
# Log tracebacks to keep track of what errors are happening
|
||||
logger.exception(e)
|
||||
finally:
|
||||
if channel:
|
||||
await channel_manager.on_disconnect(ws)
|
||||
# logger.info(f"Consumer connected - {channel}")
|
||||
|
||||
# # Keep connection open until explicitly closed, and process requests
|
||||
# try:
|
||||
# while not channel.is_closed():
|
||||
# request = await channel.recv()
|
||||
|
||||
# # Process the request here
|
||||
# await _process_consumer_request(request, channel, rpc, channel_manager)
|
||||
|
||||
# except (WebSocketDisconnect, WebSocketException):
|
||||
# # Handle client disconnects
|
||||
# logger.info(f"Consumer disconnected - {channel}")
|
||||
# except RuntimeError:
|
||||
# # Handle cases like -
|
||||
# # RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||
# pass
|
||||
# except Exception as e:
|
||||
# logger.info(f"Consumer connection failed - {channel}: {e}")
|
||||
# logger.debug(e, exc_info=e)
|
||||
|
||||
# except RuntimeError:
|
||||
# # WebSocket was closed
|
||||
# # Do nothing
|
||||
# pass
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to serve - {ws.client}")
|
||||
# # Log tracebacks to keep track of what errors are happening
|
||||
# logger.exception(e)
|
||||
# finally:
|
||||
# if channel:
|
||||
# await channel_manager.on_disconnect(ws)
|
||||
|
||||
Reference in New Issue
Block a user