Merge branch 'freqtrade:feat/freqai' into feat/freqai

This commit is contained in:
lolong
2022-07-13 08:44:49 +02:00
committed by GitHub
6 changed files with 131 additions and 29 deletions

View File

@@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BaseRegressionModel(IFreqaiModel):
"""
User created prediction model. The class needs to override three necessary
functions, predict(), train(), fit(). The class inherits ModelHandler which
has its own DataHandler where data is held, saved, loaded, and managed.
Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
User *must* inherit from this class and set fit() and predict(). See example scripts
such as prediction_models/CatboostPredictionModel.py for guidance.
"""
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
@@ -71,7 +71,8 @@ class BaseRegressionModel(IFreqaiModel):
data_dictionary['train_features'], model, dk, pair)
elif self.freqai_info.get('fit_live_predictions_candles', 0):
dk.fit_live_predictions()
self.dd.save_historic_predictions_to_disk()
self.dd.save_historic_predictions_to_disk()
logger.info(f"--------------------done training {pair}--------------------")

View File

@@ -0,0 +1,78 @@
import logging
from typing import Tuple
from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
logger = logging.getLogger(__name__)
class BaseTensorFlowModel(IFreqaiModel):
"""
Base class for TensorFlow type models.
User *must* inherit from this class and set fit() and predict().
"""
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
"""
User uses this function to add any additional return values to the dataframe.
e.g.
dataframe['volatility'] = dk.volatility_values
"""
return dataframe
def train(
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
) -> Tuple[DataFrame, DataFrame]:
"""
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
for storing, saving, loading, and analyzing the data.
:params:
:unfiltered_dataframe: Full dataframe for the current training period
:metadata: pair metadata from strategy.
:returns:
:model: Trained model which can be used to inference (self.predict)
"""
logger.info("--------------------Starting training " f"{pair} --------------------")
# filter the features requested by user in the configuration file and elegantly handle NaNs
features_filtered, labels_filtered = dk.filter_features(
unfiltered_dataframe,
dk.training_features_list,
dk.label_list,
training_filter=True,
)
# split data into train/test data.
data_dictionary = dk.make_train_test_datasets(features_filtered, labels_filtered)
if not self.freqai_info.get('fit_live_predictions', 0):
dk.fit_labels()
# normalize all data based on train_dataset only
data_dictionary = dk.normalize_data(data_dictionary)
# optional additional data cleaning/analysis
self.data_cleaning_train(dk)
logger.info(
f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
)
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
model = self.fit(data_dictionary)
if pair not in self.dd.historic_predictions:
self.set_initial_historic_predictions(
data_dictionary['train_features'], model, dk, pair)
elif self.freqai_info.get('fit_live_predictions_candles', 0):
dk.fit_live_predictions()
self.dd.save_historic_predictions_to_disk()
logger.info(f"--------------------done training {pair}--------------------")
return model