Improve type specifitivity

This commit is contained in:
Matthias 2022-09-12 20:00:01 +02:00
parent 0052e58917
commit 867d59b930

View File

@ -8,7 +8,7 @@ import asyncio
import logging import logging
import socket import socket
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any, Callable, Dict, List
import websockets import websockets
from pydantic import ValidationError from pydantic import ValidationError
@ -80,7 +80,7 @@ class ExternalMessageConsumer:
] ]
# Specify which function to use for which RPCMessageType # Specify which function to use for which RPCMessageType
self._message_handlers = { self._message_handlers: Dict[str, Callable[[str, WSMessageSchema], None]] = {
RPCMessageType.WHITELIST: self._consume_whitelist_message, RPCMessageType.WHITELIST: self._consume_whitelist_message,
RPCMessageType.ANALYZED_DF: self._consume_analyzed_df_message, RPCMessageType.ANALYZED_DF: self._consume_analyzed_df_message,
} }
@ -279,7 +279,7 @@ class ExternalMessageConsumer:
message_handler(producer_name, producer_message) message_handler(producer_name, producer_message)
def _consume_whitelist_message(self, producer_name: str, message: Any): def _consume_whitelist_message(self, producer_name: str, message: WSMessageSchema):
try: try:
# Validate the message # Validate the message
message = WSWhitelistMessage.parse_obj(message) message = WSWhitelistMessage.parse_obj(message)
@ -292,7 +292,7 @@ class ExternalMessageConsumer:
logger.debug(f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`") logger.debug(f"Consumed message from `{producer_name}` of type `RPCMessageType.WHITELIST`")
def _consume_analyzed_df_message(self, producer_name: str, message: Any): def _consume_analyzed_df_message(self, producer_name: str, message: WSMessageSchema):
try: try:
message = WSAnalyzedDFMessage.parse_obj(message) message = WSAnalyzedDFMessage.parse_obj(message)
except ValidationError as e: except ValidationError as e: