add CNN prediction model
This commit is contained in:
		| @@ -3,10 +3,10 @@ from time import time | ||||
| from typing import Any | ||||
|  | ||||
| from pandas import DataFrame | ||||
|  | ||||
| import numpy as np | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.freqai_interface import IFreqaiModel | ||||
|  | ||||
| import tensorflow as tf | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -17,6 +17,13 @@ class BaseTensorFlowModel(IFreqaiModel): | ||||
|     User *must* inherit from this class and set fit() and predict(). | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(config=kwargs['config']) | ||||
|         self.keras = True | ||||
|         if self.ft_params.get("DI_threshold", 0): | ||||
|             self.ft_params["DI_threshold"] = 0 | ||||
|             logger.warning("DI threshold is not configured for Keras models yet. Deactivating.") | ||||
|  | ||||
|     def train( | ||||
|         self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs | ||||
|     ) -> Any: | ||||
| @@ -68,3 +75,76 @@ class BaseTensorFlowModel(IFreqaiModel): | ||||
|                     f"({end_time - start_time:.2f} secs) --------------------") | ||||
|  | ||||
|         return model | ||||
|  | ||||
|  | ||||
| class WindowGenerator: | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_width, | ||||
|         label_width, | ||||
|         shift, | ||||
|         train_df=None, | ||||
|         val_df=None, | ||||
|         test_df=None, | ||||
|         train_labels=None, | ||||
|         val_labels=None, | ||||
|         test_labels=None, | ||||
|         batch_size=None, | ||||
|     ): | ||||
|         # Store the raw data. | ||||
|         self.train_df = train_df | ||||
|         self.val_df = val_df | ||||
|         self.test_df = test_df | ||||
|         self.train_labels = train_labels | ||||
|         self.val_labels = val_labels | ||||
|         self.test_labels = test_labels | ||||
|         self.batch_size = batch_size | ||||
|         self.input_width = input_width | ||||
|         self.label_width = label_width | ||||
|         self.shift = shift | ||||
|         self.total_window_size = input_width + shift | ||||
|         self.input_slice = slice(0, input_width) | ||||
|         self.input_indices = np.arange(self.total_window_size)[self.input_slice] | ||||
|  | ||||
|     def make_dataset(self, data, labels=None): | ||||
|         data = np.array(data, dtype=np.float32) | ||||
|         if labels is not None: | ||||
|             labels = np.array(labels, dtype=np.float32) | ||||
|         ds = tf.keras.preprocessing.timeseries_dataset_from_array( | ||||
|             data=data, | ||||
|             targets=labels, | ||||
|             sequence_length=self.total_window_size, | ||||
|             sequence_stride=1, | ||||
|             sampling_rate=1, | ||||
|             shuffle=False, | ||||
|             batch_size=self.batch_size, | ||||
|         ) | ||||
|  | ||||
|         return ds | ||||
|  | ||||
|     @property | ||||
|     def train(self): | ||||
|         return self.make_dataset(self.train_df, self.train_labels) | ||||
|  | ||||
|     @property | ||||
|     def val(self): | ||||
|         return self.make_dataset(self.val_df, self.val_labels) | ||||
|  | ||||
|     @property | ||||
|     def test(self): | ||||
|         return self.make_dataset(self.test_df, self.test_labels) | ||||
|  | ||||
|     @property | ||||
|     def inference(self): | ||||
|         return self.make_dataset(self.test_df) | ||||
|  | ||||
|     @property | ||||
|     def example(self): | ||||
|         """Get and cache an example batch of `inputs, labels` for plotting.""" | ||||
|         result = getattr(self, "_example", None) | ||||
|         if result is None: | ||||
|             # No example batch was found, so get one from the `.train` dataset | ||||
|             result = next(iter(self.train)) | ||||
|             # And cache it for next time | ||||
|             self._example = result | ||||
|         return result | ||||
|   | ||||
		Reference in New Issue
	
	Block a user