From 72b1d1c9aea5c88b959964e24ba0bb215104564f Mon Sep 17 00:00:00 2001 From: robcaulk Date: Mon, 5 Dec 2022 20:55:05 +0100 Subject: [PATCH] allow users to pass 0 test data --- docs/freqai-parameter-table.md | 2 +- freqtrade/freqai/base_models/BaseTensorFlowModel.py | 1 + freqtrade/freqai/prediction_models/CNNPredictionModel.py | 7 ++++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/freqai-parameter-table.md b/docs/freqai-parameter-table.md index d05ce80f3..c0d82004b 100644 --- a/docs/freqai-parameter-table.md +++ b/docs/freqai-parameter-table.md @@ -89,6 +89,6 @@ Mandatory parameters are marked as **Required** and have to be set in one of the | Parameter | Description | |------------|-------------| | | **Extraneous parameters** -| `freqai.keras` | If the selected model makes use of Keras (typical for Tensorflow-based prediction models), this flag needs to be activated so that the model save/loading follows Keras standards.
**Datatype:** Boolean.
Default: `False`. +| `freqai.keras` | If the selected model makes use of Keras (typical for Tensorflow-based prediction models), this flag should be activated so that the model save/loading follows Keras standards. If the the provided `CNNPredictionModel` is used, then this is handled automatically.
**Datatype:** Boolean.
Default: `False`. | `freqai.conv_width` | The width of a convolutional neural network input tensor. This replaces the need for shifting candles (`include_shifted_candles`) by feeding in historical data points as the second dimension of the tensor. Technically, this parameter can also be used for regressors, but it only adds computational overhead and does not change the model training/prediction.
**Datatype:** Integer.
Default: `2`. | `freqai.reduce_df_footprint` | Recast all numeric columns to float32/int32, with the objective of reducing ram/disk usage and decreasing train/inference timing. This parameter is set in the main level of the Freqtrade configuration file (not inside FreqAI).
**Datatype:** Boolean.
Default: `False`. diff --git a/freqtrade/freqai/base_models/BaseTensorFlowModel.py b/freqtrade/freqai/base_models/BaseTensorFlowModel.py index a12a6a9ef..f7aec1f8b 100644 --- a/freqtrade/freqai/base_models/BaseTensorFlowModel.py +++ b/freqtrade/freqai/base_models/BaseTensorFlowModel.py @@ -23,6 +23,7 @@ class BaseTensorFlowModel(IFreqaiModel): if self.ft_params.get("DI_threshold", 0): self.ft_params["DI_threshold"] = 0 logger.warning("DI threshold is not configured for Keras models yet. Deactivating.") + self.dd.model_type = 'keras' def train( self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs diff --git a/freqtrade/freqai/prediction_models/CNNPredictionModel.py b/freqtrade/freqai/prediction_models/CNNPredictionModel.py index 80e3c447f..b6ab73138 100644 --- a/freqtrade/freqai/prediction_models/CNNPredictionModel.py +++ b/freqtrade/freqai/prediction_models/CNNPredictionModel.py @@ -75,11 +75,16 @@ class CNNPredictionModel(BaseTensorFlowModel): metrics=[tf.metrics.MeanAbsoluteError()], ) + if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: + val_data = None + else: + val_data = w1.val + model.fit( w1.train, epochs=MAX_EPOCHS, shuffle=False, - validation_data=w1.val, + validation_data=val_data, callbacks=[early_stopping], verbose=1, )