mypy fixes

This commit is contained in:
Timothy Pogue 2022-09-06 12:40:58 -06:00
parent 3535aa7724
commit b1c0267449
4 changed files with 16 additions and 7 deletions

View File

@ -3,6 +3,8 @@ from threading import RLock
from typing import List, Optional, Type
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer)
@ -105,7 +107,7 @@ class ChannelManager:
:param websocket: The WebSocket object to attach to the Channel
"""
if hasattr(websocket, "accept"):
if isinstance(websocket, FastAPIWebSocket):
try:
await websocket.accept()
except RuntimeError:

View File

@ -1,7 +1,7 @@
from typing import Any, Tuple, Union
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
from websockets.client import WebSocketClientProtocol as WebSocket
from freqtrade.rpc.api_server.ws.types import WebSocketType
@ -17,10 +17,12 @@ class WebSocketProxy:
@property
def remote_addr(self) -> Tuple[Any, ...]:
if hasattr(self._websocket, "remote_address"):
if isinstance(self._websocket, WebSocket):
return self._websocket.remote_address
elif hasattr(self._websocket, "client"):
return tuple(self._websocket.client)
elif isinstance(self._websocket, FastAPIWebSocket):
if self._websocket.client:
client, port = self._websocket.client.host, self._websocket.client.port
return (client, port)
return ("unknown", 0)
async def send(self, data):

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, TypeVar
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
from websockets.client import WebSocketClientProtocol as WebSocket
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)

View File

@ -8,7 +8,7 @@ import asyncio
import logging
import socket
from threading import Thread
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import websockets
@ -18,6 +18,11 @@ from freqtrade.misc import remove_entry_exit_signals
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
if TYPE_CHECKING:
import websockets.connect
import websockets.exceptions
logger = logging.getLogger(__name__)