add kwargs, reduce duplicated code

This commit is contained in:
robcaulk
2022-09-07 18:58:55 +02:00
parent 97077ba18a
commit 4c9ac6b7c0
7 changed files with 31 additions and 30 deletions

View File

@@ -661,11 +661,20 @@ class IFreqaiModel(ABC):
self.train_time = 0
return
def get_init_model(self, pair: str) -> Any:
if pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[pair]
return init_model
# Following methods which are overridden by user made prediction models.
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
@abstractmethod
def train(self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen) -> Any:
def train(self, unfiltered_dataframe: DataFrame, pair: str,
dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
Filter the training data and train a model to it. Train makes heavy use of the datahandler
for storing, saving, loading, and analyzing the data.
@@ -675,7 +684,7 @@ class IFreqaiModel(ABC):
"""
@abstractmethod
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any:
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data
@@ -688,7 +697,7 @@ class IFreqaiModel(ABC):
@abstractmethod
def predict(
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True, **kwargs
) -> Tuple[DataFrame, NDArray[np.int_]]:
"""
Filter the prediction features data and predict with it.