diff --git a/freqtrade/freqai/base_models/BaseTensorFlowModel.py b/freqtrade/freqai/base_models/BaseTensorFlowModel.py index 2f95b6314..a86a7f5e3 100644 --- a/freqtrade/freqai/base_models/BaseTensorFlowModel.py +++ b/freqtrade/freqai/base_models/BaseTensorFlowModel.py @@ -43,7 +43,6 @@ class BaseTensorFlowModel(IFreqaiModel): start_time = time() - # filter the features requested by user in the configuration file and elegantly handle NaNs features_filtered, labels_filtered = dk.filter_features( unfiltered_df, dk.training_features_list, diff --git a/freqtrade/freqai/prediction_models/CNNPredictionModel.py b/freqtrade/freqai/prediction_models/CNNPredictionModel.py index 3b4de4ca3..5a957f59e 100644 --- a/freqtrade/freqai/prediction_models/CNNPredictionModel.py +++ b/freqtrade/freqai/prediction_models/CNNPredictionModel.py @@ -14,9 +14,6 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen logger = logging.getLogger(__name__) -# tf.config.run_functions_eagerly(True) -# tf.data.experimental.enable_debug_mode() - class CNNPredictionModel(BaseTensorFlowModel): """ @@ -49,7 +46,8 @@ class CNNPredictionModel(BaseTensorFlowModel): # we need to remove batch_size from the model_training_params because # we dont want fit() to get the incorrect assignment (we use the WindowGenerator) # to handle our batches. - self.model_training_parameters.pop('batch_size') + if 'batch_size' in self.model_training_parameters: + self.model_training_parameters.pop('batch_size') input_dims = [BATCH_SIZE, self.CONV_WIDTH, n_features] w1 = WindowGenerator( diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index c53137093..f7251949c 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -34,7 +34,8 @@ def is_mac() -> bool: ('CatboostRegressor', False, False, False), ('ReinforcementLearner', False, True, False), ('ReinforcementLearner_multiproc', False, False, False), - ('ReinforcementLearner_test_4ac', False, False, False) + ('ReinforcementLearner_test_4ac', False, False, False), + ('CNNPredictionModel', False, False, False) ]) def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32): if is_arm() and model == 'CatboostRegressor': @@ -71,6 +72,10 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'test_4ac' in model: freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") + if 'CNNPredictionModel' in model: + freqai_conf['freqai']['model_training_parameters'].pop('n_estimators') + model_save_ext = 'h5' + strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange)