diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py index 2cde31441..5c9eef03f 100644 --- a/freqtrade/freqai/data_kitchen.py +++ b/freqtrade/freqai/data_kitchen.py @@ -112,7 +112,7 @@ class FreqaiDataKitchen: self.unique_class_list: list = [] self.backtest_live_models_data: Dict[str, Any] = {} self.normalizer: Normalization = normalization_factory(config, self.data, self.pkl_data, - self.unique_class_list) + self.unique_class_list) def set_paths( self, diff --git a/freqtrade/freqai/normalization.py b/freqtrade/freqai/normalization.py index 641bfcdaa..ae395008a 100644 --- a/freqtrade/freqai/normalization.py +++ b/freqtrade/freqai/normalization.py @@ -16,20 +16,19 @@ def normalization_factory( config: Config, meta_data: Dict[str, Any], pickle_meta_data: Dict[str, Any], - unique_class_list: list - ): - freqai_config: Dict[str, Any] = config["freqai"] - norm_config_id = freqai_config["feature_parameters"].get("data_normalization", "legacy") - if norm_config_id.lower() == "legacy": - return LegacyNormalization(config, meta_data, pickle_meta_data, unique_class_list) - elif norm_config_id.lower() == "standard": - return StandardNormalization(config, meta_data, pickle_meta_data, unique_class_list) - elif norm_config_id.lower() == "minmax": - return MinMaxNormalization(config, meta_data, pickle_meta_data, unique_class_list) - elif norm_config_id.lower() == "quantile": - return QuantileNormalization(config, meta_data, pickle_meta_data, unique_class_list) - else: - raise OperationalException(f"Invalid data normalization identifier '{norm_config_id}'") + unique_class_list: list): + freqai_config: Dict[str, Any] = config["freqai"] + norm_config_id = freqai_config["feature_parameters"].get("data_normalization", "legacy") + if norm_config_id.lower() == "legacy": + return LegacyNormalization(config, meta_data, pickle_meta_data, unique_class_list) + elif norm_config_id.lower() == "standard": + return StandardNormalization(config, meta_data, pickle_meta_data, unique_class_list) + elif norm_config_id.lower() == "minmax": + return MinMaxNormalization(config, meta_data, pickle_meta_data, unique_class_list) + elif norm_config_id.lower() == "quantile": + return QuantileNormalization(config, meta_data, pickle_meta_data, unique_class_list) + else: + raise OperationalException(f"Invalid data normalization identifier '{norm_config_id}'") class Normalization(ABC): @@ -268,5 +267,3 @@ class QuantileNormalization(SKLearnNormalization): unique_class_list: list): super().__init__(config, meta_data, pickle_meta_data, unique_class_list, QuantileTransformer) - - diff --git a/tests/freqai/test_normalization.py b/tests/freqai/test_normalization.py index 36294215e..375f6c477 100644 --- a/tests/freqai/test_normalization.py +++ b/tests/freqai/test_normalization.py @@ -84,6 +84,7 @@ def test_assertion_invalid_normalization_id(mocker, freqai_conf): assert str(e_info).startswith("Invalid data normalization identifier"), \ "unexpected exception string" + @pytest.mark.parametrize( "config_id", [