From b1c02674492993571c1cbc10144277eb29fda7a9 Mon Sep 17 00:00:00 2001 From: Timothy Pogue Date: Tue, 6 Sep 2022 12:40:58 -0600 Subject: [PATCH] mypy fixes --- freqtrade/rpc/api_server/ws/channel.py | 4 +++- freqtrade/rpc/api_server/ws/proxy.py | 10 ++++++---- freqtrade/rpc/api_server/ws/types.py | 2 +- freqtrade/rpc/external_message_consumer.py | 7 ++++++- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 952b3b9f5..cffe3092d 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -3,6 +3,8 @@ from threading import RLock from typing import List, Optional, Type from uuid import uuid4 +from fastapi import WebSocket as FastAPIWebSocket + from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer, WebSocketSerializer) @@ -105,7 +107,7 @@ class ChannelManager: :param websocket: The WebSocket object to attach to the Channel """ - if hasattr(websocket, "accept"): + if isinstance(websocket, FastAPIWebSocket): try: await websocket.accept() except RuntimeError: diff --git a/freqtrade/rpc/api_server/ws/proxy.py b/freqtrade/rpc/api_server/ws/proxy.py index e43ce6441..da3e04887 100644 --- a/freqtrade/rpc/api_server/ws/proxy.py +++ b/freqtrade/rpc/api_server/ws/proxy.py @@ -1,7 +1,7 @@ from typing import Any, Tuple, Union from fastapi import WebSocket as FastAPIWebSocket -from websockets import WebSocketClientProtocol as WebSocket +from websockets.client import WebSocketClientProtocol as WebSocket from freqtrade.rpc.api_server.ws.types import WebSocketType @@ -17,10 +17,12 @@ class WebSocketProxy: @property def remote_addr(self) -> Tuple[Any, ...]: - if hasattr(self._websocket, "remote_address"): + if isinstance(self._websocket, WebSocket): return self._websocket.remote_address - elif hasattr(self._websocket, "client"): - return tuple(self._websocket.client) + elif isinstance(self._websocket, FastAPIWebSocket): + if self._websocket.client: + client, port = self._websocket.client.host, self._websocket.client.port + return (client, port) return ("unknown", 0) async def send(self, data): diff --git a/freqtrade/rpc/api_server/ws/types.py b/freqtrade/rpc/api_server/ws/types.py index 814fe6649..9855f9e06 100644 --- a/freqtrade/rpc/api_server/ws/types.py +++ b/freqtrade/rpc/api_server/ws/types.py @@ -1,7 +1,7 @@ from typing import Any, Dict, TypeVar from fastapi import WebSocket as FastAPIWebSocket -from websockets import WebSocketClientProtocol as WebSocket +from websockets.client import WebSocketClientProtocol as WebSocket WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket) diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index 28628c1f6..c1ad0512e 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -8,7 +8,7 @@ import asyncio import logging import socket from threading import Thread -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import websockets @@ -18,6 +18,11 @@ from freqtrade.misc import remove_entry_exit_signals from freqtrade.rpc.api_server.ws.channel import WebSocketChannel +if TYPE_CHECKING: + import websockets.connect + import websockets.exceptions + + logger = logging.getLogger(__name__)