import logging import time from typing import Any, Dict from fastapi import APIRouter, Depends from fastapi.websockets import WebSocket from pydantic import ValidationError from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.misc import sync_to_async_iter from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel from freqtrade.rpc.api_server.ws.message_stream import MessageStream from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC logger = logging.getLogger(__name__) # Private router, protected by API Key authentication router = APIRouter() async def channel_reader(channel: WebSocketChannel, rpc: RPC): """ Iterate over the messages from the channel and process the request """ async for message in channel: await _process_consumer_request(message, channel, rpc) async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream): """ Iterate over messages in the message stream and send them """ async for message, ts in message_stream: if channel.subscribed_to(message.get('type')): # Log a warning if this channel is behind # on the message stream by a lot if (time.time() - ts) > 60: logger.warning(f"Channel {channel} is behind MessageStream by 1 minute," " this can cause a memory leak if you see this message" " often, consider reducing pair list size or amount of" " consumers.") await channel.send(message, timeout=True) async def _process_consumer_request( request: Dict[str, Any], channel: WebSocketChannel, rpc: RPC ): """ Validate and handle a request from a websocket consumer """ # Validate the request, makes sure it matches the schema try: websocket_request = WSRequestSchema.parse_obj(request) except ValidationError as e: logger.error(f"Invalid request from {channel}: {e}") return type, data = websocket_request.type, websocket_request.data response: WSMessageSchema logger.debug(f"Request of type {type} from {channel}") # If we have a request of type SUBSCRIBE, set the topics in this channel if type == RPCRequestType.SUBSCRIBE: # If the request is empty, do nothing if not data: return # If all topics passed are a valid RPCMessageType, set subscriptions on channel if all([any(x.value == topic for x in RPCMessageType) for topic in data]): channel.set_subscriptions(data) # We don't send a response for subscriptions return elif type == RPCRequestType.WHITELIST: # Get whitelist whitelist = rpc._ws_request_whitelist() # Format response response = WSWhitelistMessage(data=whitelist) await channel.send(response.dict(exclude_none=True)) elif type == RPCRequestType.ANALYZED_DF: # Limit the amount of candles per dataframe to 'limit' or 1500 limit = min(data.get('limit', 1500), 1500) if data else None # For every pair in the generator, send a separate message async for message in sync_to_async_iter(rpc._ws_request_analyzed_df(limit)): # Format response response = WSAnalyzedDFMessage(data=message) await channel.send(response.dict(exclude_none=True)) @router.websocket("/message/ws") async def message_endpoint( websocket: WebSocket, token: str = Depends(validate_ws_token), rpc: RPC = Depends(get_rpc), message_stream: MessageStream = Depends(get_message_stream) ): if token: async with create_channel(websocket) as channel: await channel.run_channel_tasks( channel_reader(channel, rpc), channel_broadcaster(channel, message_stream) )