stable/freqtrade/rpc/api_server/api_ws.py

203 lines
6.3 KiB
Python
Raw Normal View History

2022-11-15 03:27:45 +00:00
import asyncio
import logging
from typing import Any, Dict
2022-11-15 03:27:45 +00:00
from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket, WebSocketDisconnect
2022-09-08 19:58:28 +00:00
from pydantic import ValidationError
2022-11-15 03:27:45 +00:00
from websockets.exceptions import ConnectionClosed
from freqtrade.enums import RPCMessageType, RPCRequestType
2022-09-10 12:19:11 +00:00
from freqtrade.rpc.api_server.api_auth import validate_ws_token
2022-11-15 03:27:45 +00:00
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
2022-09-13 18:42:24 +00:00
from freqtrade.rpc.api_server.ws import WebSocketChannel
2022-11-15 03:27:45 +00:00
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
2022-09-08 19:58:28 +00:00
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
# Private router, protected by API Key authentication.
# The websocket endpoint in this module registers itself on this router.
router = APIRouter()
2022-11-15 03:27:45 +00:00
# async def is_websocket_alive(ws: WebSocket) -> bool:
# """
# Check if a FastAPI Websocket is still open
# """
# if (
# ws.application_state == WebSocketState.CONNECTED and
# ws.client_state == WebSocketState.CONNECTED
# ):
# return True
# return False
class WebSocketChannelClosed(Exception):
    """
    Signal that a WebSocket channel has to be closed.

    Raised by the channel tasks in this module in place of the lower-level
    connection errors they catch.
    """
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
    """
    Iterate over the messages from the channel and process each request

    :param channel: The channel to read incoming messages from
    :param rpc: RPC instance handed to the request handler
    :raises WebSocketChannelClosed: If a RuntimeError, WebSocketDisconnect or
        ConnectionClosed occurs while reading or processing a message
    """
    try:
        async for message in channel:
            await _process_consumer_request(message, channel, rpc)
    except (
        RuntimeError,
        WebSocketDisconnect,
        ConnectionClosed
    ) as e:
        # Translate low-level connection errors into a single exception type,
        # keeping the original error attached as the explicit cause.
        raise WebSocketChannelClosed from e
    except asyncio.CancelledError:
        # Being cancelled is treated as a regular way to stop; end quietly.
        return
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
    """
    Iterate over messages in the message stream and send them

    :param channel: The channel every message is sent to
    :param message_stream: Async-iterable source of the messages to send
    :raises WebSocketChannelClosed: If a RuntimeError, WebSocketDisconnect or
        ConnectionClosed occurs while waiting for or sending a message
    """
    try:
        async for message in message_stream:
            await channel.send(message)
    except (
        RuntimeError,
        WebSocketDisconnect,
        ConnectionClosed
    ) as e:
        # Translate low-level connection errors into a single exception type,
        # keeping the original error attached as the explicit cause.
        raise WebSocketChannelClosed from e
    except asyncio.CancelledError:
        # Being cancelled is treated as a regular way to stop; end quietly.
        return
async def _process_consumer_request(
    request: Dict[str, Any],
    channel: WebSocketChannel,
    rpc: RPC
):
    """
    Validate and handle a request from a websocket consumer

    Requests that do not match WSRequestSchema are logged and dropped.
    Handled request types:

    * SUBSCRIBE: set the channel subscriptions, no response is sent
    * WHITELIST: send the whitelist back as a single message
    * ANALYZED_DF: send one message per item yielded by the RPC

    Any other request type is ignored.

    :param request: The raw request received from the consumer
    :param channel: The channel the request came from; responses go back to it
    :param rpc: RPC instance used to build the responses
    """
    # Validate the request, makes sure it matches the schema
    try:
        websocket_request = WSRequestSchema.parse_obj(request)
    except ValidationError as e:
        logger.error(f"Invalid request from {channel}: {e}")
        return

    # Named request_type (not `type`) to avoid shadowing the builtin
    request_type, data = websocket_request.type, websocket_request.data
    response: WSMessageSchema

    logger.debug(f"Request of type {request_type} from {channel}")

    # If we have a request of type SUBSCRIBE, set the topics in this channel
    if request_type == RPCRequestType.SUBSCRIBE:
        # If the request is empty, do nothing
        if not data:
            return

        # If all topics passed are a valid RPCMessageType, set subscriptions on channel
        if all(any(x.value == topic for x in RPCMessageType) for topic in data):
            channel.set_subscriptions(data)

        # We don't send a response for subscriptions
        return

    elif request_type == RPCRequestType.WHITELIST:
        # Get whitelist
        whitelist = rpc._ws_request_whitelist()

        # Format response
        response = WSWhitelistMessage(data=whitelist)
        # Send it back
        await channel.send(response.dict(exclude_none=True))

    elif request_type == RPCRequestType.ANALYZED_DF:
        limit = None

        if data:
            # Cap the amount of candles per dataframe at 1500: use the requested
            # 'limit' when it is smaller, 1500 otherwise. min() enforces the cap;
            # max() would turn 1500 into a lower bound instead.
            limit = min(data.get('limit', 1500), 1500)

        # Send a separate message for every item yielded by the RPC
        for message in rpc._ws_request_analyzed_df(limit):
            # Format response
            response = WSAnalyzedDFMessage(data=message)
            await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws")
async def message_endpoint(
    websocket: WebSocket,
    token: str = Depends(validate_ws_token),
    rpc: RPC = Depends(get_rpc),
    message_stream: MessageStream = Depends(get_message_stream)
):
    """
    Message WebSocket endpoint, runs the reader and broadcaster for a channel

    Opens a WebSocketChannel for the connection and runs channel_reader and
    channel_broadcaster concurrently until one of them stops with an error
    or this endpoint is cancelled. Errors raised by either task (for example
    WebSocketChannelClosed) propagate out of this function.

    :param websocket: The incoming FastAPI WebSocket connection
    :param token: Result of the validate_ws_token dependency; not used in the body
    :param rpc: RPC instance passed on to the channel reader
    :param message_stream: Message stream passed on to the channel broadcaster
    """
    async with WebSocketChannel(websocket).connect() as channel:
        logger.info(f"Channel connected - {channel}")
        # Keep a handle on each task. asyncio.gather() does not cancel the
        # remaining awaitables when one of them raises, and cancelling a gather
        # future that already finished is a no-op, so the tasks have to be
        # cancelled individually.
        tasks = [
            asyncio.create_task(channel_reader(channel, rpc)),
            asyncio.create_task(channel_broadcaster(channel, message_stream)),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            logger.info(f"Channel disconnected - {channel}")
            # Cancelling an already finished task does nothing, so this only
            # stops whichever task is still running.
            for task in tasks:
                task.cancel()
# @router.websocket("/message/ws")
# async def message_endpoint(
# ws: WebSocket,
# rpc: RPC = Depends(get_rpc),
# channel_manager=Depends(get_channel_manager),
# token: str = Depends(validate_ws_token)
# ):
# """
# Message WebSocket endpoint, facilitates sending RPC messages
# """
# try:
# channel = await channel_manager.on_connect(ws)
# if await is_websocket_alive(ws):
# logger.info(f"Consumer connected - {channel}")
# # Keep connection open until explicitly closed, and process requests
# try:
# while not channel.is_closed():
# request = await channel.recv()
# # Process the request here
# await _process_consumer_request(request, channel, rpc, channel_manager)
# except (WebSocketDisconnect, WebSocketException):
# # Handle client disconnects
# logger.info(f"Consumer disconnected - {channel}")
# except RuntimeError:
# # Handle cases like -
# # RuntimeError('Cannot call "send" once a closed message has been sent')
# pass
# except Exception as e:
# logger.info(f"Consumer connection failed - {channel}: {e}")
# logger.debug(e, exc_info=e)
# except RuntimeError:
# # WebSocket was closed
# # Do nothing
# pass
# except Exception as e:
# logger.error(f"Failed to serve - {ws.client}")
# # Log tracebacks to keep track of what errors are happening
# logger.exception(e)
# finally:
# if channel:
# await channel_manager.on_disconnect(ws)