add test for CNNPredictionModel
This commit is contained in:
		| @@ -43,7 +43,6 @@ class BaseTensorFlowModel(IFreqaiModel): | ||||
|  | ||||
|         start_time = time() | ||||
|  | ||||
|         # filter the features requested by user in the configuration file and elegantly handle NaNs | ||||
|         features_filtered, labels_filtered = dk.filter_features( | ||||
|             unfiltered_df, | ||||
|             dk.training_features_list, | ||||
|   | ||||
| @@ -14,9 +14,6 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # tf.config.run_functions_eagerly(True) | ||||
| # tf.data.experimental.enable_debug_mode() | ||||
|  | ||||
|  | ||||
| class CNNPredictionModel(BaseTensorFlowModel): | ||||
|     """ | ||||
| @@ -49,7 +46,8 @@ class CNNPredictionModel(BaseTensorFlowModel): | ||||
|         # we need to remove batch_size from the model_training_params because | ||||
|         # we dont want fit() to get the incorrect assignment (we use the WindowGenerator) | ||||
|         # to handle our batches. | ||||
|         self.model_training_parameters.pop('batch_size') | ||||
|         if 'batch_size' in self.model_training_parameters: | ||||
|             self.model_training_parameters.pop('batch_size') | ||||
|         input_dims = [BATCH_SIZE, self.CONV_WIDTH, n_features] | ||||
|  | ||||
|         w1 = WindowGenerator( | ||||
|   | ||||
| @@ -34,7 +34,8 @@ def is_mac() -> bool: | ||||
|     ('CatboostRegressor', False, False, False), | ||||
|     ('ReinforcementLearner', False, True, False), | ||||
|     ('ReinforcementLearner_multiproc', False, False, False), | ||||
|     ('ReinforcementLearner_test_4ac', False, False, False) | ||||
|     ('ReinforcementLearner_test_4ac', False, False, False), | ||||
|     ('CNNPredictionModel', False, False, False) | ||||
|     ]) | ||||
| def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32): | ||||
|     if is_arm() and model == 'CatboostRegressor': | ||||
| @@ -71,6 +72,10 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, | ||||
|     if 'test_4ac' in model: | ||||
|         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") | ||||
|  | ||||
|     if 'CNNPredictionModel' in model: | ||||
|         freqai_conf['freqai']['model_training_parameters'].pop('n_estimators') | ||||
|         model_save_ext = 'h5' | ||||
|  | ||||
|     strategy = get_patched_freqai_strategy(mocker, freqai_conf) | ||||
|     exchange = get_patched_exchange(mocker, freqai_conf) | ||||
|     strategy.dp = DataProvider(freqai_conf, exchange) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user