import logging
from typing import Any, Dict

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from pydantic import ValidationError
# fastapi does not make this available through it, so import directly from starlette
from starlette.websockets import WebSocketState

from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import get_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
                                                 WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC


logger = logging.getLogger(__name__)

# Private router, protected by API Key authentication
router = APIRouter()


async def is_websocket_alive(ws: WebSocket) -> bool:
    """
    Check if a FastAPI Websocket is still open

    :param ws: The websocket to inspect
    :return: True only if both the application side and the client side
        of the connection report the CONNECTED state
    """
    return (
        ws.application_state == WebSocketState.CONNECTED
        and ws.client_state == WebSocketState.CONNECTED
    )


async def _process_consumer_request(
    request: Dict[str, Any],
    channel: WebSocketChannel,
    rpc: RPC
):
    """
    Validate and handle a request from a websocket consumer

    :param request: Raw request payload received from the consumer
    :param channel: Channel the request arrived on; responses are sent back through it
    :param rpc: RPC instance used to build whitelist / analyzed dataframe responses
    :return: None. Invalid requests are logged and dropped, subscription
        requests never produce a response.
    """
    # Validate the request, makes sure it matches the schema
    try:
        websocket_request = WSRequestSchema.parse_obj(request)
    except ValidationError as e:
        logger.error(f"Invalid request from {channel}: {e}")
        return

    # Named `request_type` so the builtin `type` is not shadowed
    request_type, data = websocket_request.type, websocket_request.data
    response: WSMessageSchema

    logger.debug(f"Request of type {request_type} from {channel}")

    # If we have a request of type SUBSCRIBE, set the topics in this channel
    if request_type == RPCRequestType.SUBSCRIBE:
        # If the request is empty, do nothing
        if not data:
            return

        # If all topics passed are a valid RPCMessageType, set subscriptions on channel.
        # A request containing any unknown topic is ignored as a whole.
        if all(any(x.value == topic for x in RPCMessageType) for topic in data):
            channel.set_subscriptions(data)

        # We don't send a response for subscriptions
        return

    elif request_type == RPCRequestType.WHITELIST:
        # Get whitelist
        whitelist = rpc._ws_request_whitelist()

        # Format response
        response = WSWhitelistMessage(data=whitelist)
        # Send it back
        await channel.send(response.dict(exclude_none=True))

    elif request_type == RPCRequestType.ANALYZED_DF:
        limit = None

        if data:
            # Limit the amount of candles per dataframe to 'limit' or 1500,
            # whichever is smaller: `min` caps the request, so a consumer can ask
            # for fewer candles but never for more than 1500.
            limit = min(data.get('limit', 1500), 1500)

        # They requested the full historical analyzed dataframes
        analyzed_df = rpc._ws_request_analyzed_df(limit)

        # For every dataframe, send as a separate message
        for message in analyzed_df.values():
            response = WSAnalyzedDFMessage(data=message)
            await channel.send(response.dict(exclude_none=True))


@router.websocket("/message/ws")
async def message_endpoint(
    ws: WebSocket,
    rpc: RPC = Depends(get_rpc),
    channel_manager=Depends(get_channel_manager),
    token: str = Depends(get_ws_token)
):
    """
    Message WebSocket endpoint, facilitates sending RPC messages

    :param ws: The incoming websocket connection
    :param rpc: RPC instance, injected via `get_rpc`
    :param channel_manager: Channel manager, injected via `get_channel_manager`
    :param token: Resolved through the `get_ws_token` dependency; not used in the body
    """
    try:
        # TODO:
        # Return a channel ID, pass that instead of ws to the rest of the methods
        channel = await channel_manager.on_connect(ws)

        if await is_websocket_alive(ws):

            logger.info(f"Consumer connected - {channel}")

            # Keep connection open until explicitly closed, and process requests
            try:
                while not channel.is_closed():
                    request = await channel.recv()

                    # Process the request here
                    await _process_consumer_request(request, channel, rpc)

            except WebSocketDisconnect:
                # Handle client disconnects
                logger.info(f"Consumer disconnected - {channel}")
                await channel_manager.on_disconnect(ws)
            except Exception as e:
                logger.info(f"Consumer connection failed - {channel}")
                logger.exception(e)
                # Handle cases like -
                # RuntimeError('Cannot call "send" once a closed message has been sent')
                await channel_manager.on_disconnect(ws)

        else:
            # `close` is a coroutine: without `await` it would never run and the
            # socket would stay open. A failure while closing is handled by the
            # outer `except` below.
            await ws.close()

    except Exception as e:
        logger.error(f"Failed to serve - {ws.client}")
        # Log tracebacks to keep track of what errors are happening
        logger.exception(e)
        await channel_manager.on_disconnect(ws)