add kwargs, reduce duplicated code

This commit is contained in:
robcaulk 2022-09-07 18:58:55 +02:00
parent 97077ba18a
commit 4c9ac6b7c0
7 changed files with 31 additions and 30 deletions

View File

@ -661,11 +661,20 @@ class IFreqaiModel(ABC):
self.train_time = 0 self.train_time = 0
return return
def get_init_model(self, pair: str) -> Any:
if pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[pair]
return init_model
# Following methods which are overridden by user made prediction models. # Following methods which are overridden by user made prediction models.
# See freqai/prediction_models/CatboostPredictionModel.py for an example. # See freqai/prediction_models/CatboostPredictionModel.py for an example.
@abstractmethod @abstractmethod
def train(self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen) -> Any: def train(self, unfiltered_dataframe: DataFrame, pair: str,
dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
Filter the training data and train a model to it. Train makes heavy use of the datahandler Filter the training data and train a model to it. Train makes heavy use of the datahandler
for storing, saving, loading, and analyzing the data. for storing, saving, loading, and analyzing the data.
@ -675,7 +684,7 @@ class IFreqaiModel(ABC):
""" """
@abstractmethod @abstractmethod
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
Most regressors use the same function names and arguments e.g. user Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data can drop in LGBMRegressor in place of CatBoostRegressor and all data
@ -688,7 +697,7 @@ class IFreqaiModel(ABC):
@abstractmethod @abstractmethod
def predict( def predict(
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True, **kwargs
) -> Tuple[DataFrame, NDArray[np.int_]]: ) -> Tuple[DataFrame, NDArray[np.int_]]:
""" """
Filter the prediction features data and predict with it. Filter the prediction features data and predict with it.

View File

@ -2,6 +2,7 @@ import logging
from typing import Any, Dict from typing import Any, Dict
from catboost import CatBoostClassifier, Pool from catboost import CatBoostClassifier, Pool
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
@ -16,7 +17,7 @@ class CatboostClassifier(BaseClassifierModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:params: :params:
@ -36,10 +37,7 @@ class CatboostClassifier(BaseClassifierModel):
**self.model_training_parameters, **self.model_training_parameters,
) )
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: init_model = self.get_init_model(dk.pair)
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
cbr.fit(train_data, init_model=init_model) cbr.fit(train_data, init_model=init_model)

View File

@ -1,10 +1,9 @@
import gc
import logging import logging
from typing import Any, Dict from typing import Any, Dict
from catboost import CatBoostRegressor, Pool from catboost import CatBoostRegressor, Pool
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
@ -18,7 +17,7 @@ class CatboostRegressor(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary constructed by DataHandler to hold
@ -39,10 +38,7 @@ class CatboostRegressor(BaseRegressionModel):
weight=data_dictionary["test_weights"], weight=data_dictionary["test_weights"],
) )
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: init_model = self.get_init_model(dk.pair)
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = CatBoostRegressor( model = CatBoostRegressor(
allow_writing_files=False, allow_writing_files=False,

View File

@ -3,6 +3,7 @@ from typing import Any, Dict
from catboost import CatBoostRegressor # , Pool from catboost import CatBoostRegressor # , Pool
from sklearn.multioutput import MultiOutputRegressor from sklearn.multioutput import MultiOutputRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
@ -17,7 +18,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary constructed by DataHandler to hold

View File

@ -3,8 +3,9 @@ from typing import Any, Dict
from lightgbm import LGBMClassifier from lightgbm import LGBMClassifier
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +17,7 @@ class LightGBMClassifier(BaseClassifierModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:params: :params:
@ -35,10 +36,7 @@ class LightGBMClassifier(BaseClassifierModel):
y = data_dictionary["train_labels"].to_numpy()[:, 0] y = data_dictionary["train_labels"].to_numpy()[:, 0]
train_weights = data_dictionary["train_weights"] train_weights = data_dictionary["train_weights"]
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: init_model = self.get_init_model(dk.pair)
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = LGBMClassifier(**self.model_training_parameters) model = LGBMClassifier(**self.model_training_parameters)

View File

@ -3,8 +3,9 @@ from typing import Any, Dict
from lightgbm import LGBMRegressor from lightgbm import LGBMRegressor
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +17,7 @@ class LightGBMRegressor(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
Most regressors use the same function names and arguments e.g. user Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data can drop in LGBMRegressor in place of CatBoostRegressor and all data
@ -35,10 +36,7 @@ class LightGBMRegressor(BaseRegressionModel):
y = data_dictionary["train_labels"] y = data_dictionary["train_labels"]
train_weights = data_dictionary["train_weights"] train_weights = data_dictionary["train_weights"]
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: init_model = self.get_init_model(dk.pair)
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = LGBMRegressor(**self.model_training_parameters) model = LGBMRegressor(**self.model_training_parameters)

View File

@ -4,8 +4,9 @@ from typing import Any, Dict
from lightgbm import LGBMRegressor from lightgbm import LGBMRegressor
from sklearn.multioutput import MultiOutputRegressor from sklearn.multioutput import MultiOutputRegressor
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,7 +18,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed. has its own DataHandler where data is held, saved, loaded, and managed.
""" """
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
User sets up the training and test data to fit their desired model here User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold :param data_dictionary: the dictionary constructed by DataHandler to hold