Add ground work for TensorFlow models, add protections from common mistakes

This commit is contained in:
robcaulk
2022-07-12 18:09:17 +02:00
parent fea63fba12
commit ef409dd345
4 changed files with 44 additions and 21 deletions

View File

@@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BaseRegressionModel(IFreqaiModel):
"""
User created prediction model. The class needs to override three necessary
functions, predict(), train(), fit(). The class inherits ModelHandler which
has its own DataHandler where data is held, saved, loaded, and managed.
Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
User *must* inherit from this class and set fit() and predict(). See example scripts
such as prediction_models/CatboostPredictionModel.py for guidance.
"""
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame: