add kwargs, reduce duplicated code
This commit is contained in:
parent
97077ba18a
commit
4c9ac6b7c0
@ -661,11 +661,20 @@ class IFreqaiModel(ABC):
|
|||||||
self.train_time = 0
|
self.train_time = 0
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def get_init_model(self, pair: str) -> Any:
|
||||||
|
if pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||||
|
init_model = None
|
||||||
|
else:
|
||||||
|
init_model = self.dd.model_dictionary[pair]
|
||||||
|
|
||||||
|
return init_model
|
||||||
|
|
||||||
# Following methods which are overridden by user made prediction models.
|
# Following methods which are overridden by user made prediction models.
|
||||||
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
|
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def train(self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen) -> Any:
|
def train(self, unfiltered_dataframe: DataFrame, pair: str,
|
||||||
|
dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Filter the training data and train a model to it. Train makes heavy use of the datahandler
|
Filter the training data and train a model to it. Train makes heavy use of the datahandler
|
||||||
for storing, saving, loading, and analyzing the data.
|
for storing, saving, loading, and analyzing the data.
|
||||||
@ -675,7 +684,7 @@ class IFreqaiModel(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Most regressors use the same function names and arguments e.g. user
|
Most regressors use the same function names and arguments e.g. user
|
||||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||||
@ -688,7 +697,7 @@ class IFreqaiModel(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(
|
def predict(
|
||||||
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True
|
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True, **kwargs
|
||||||
) -> Tuple[DataFrame, NDArray[np.int_]]:
|
) -> Tuple[DataFrame, NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
|
@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from catboost import CatBoostClassifier, Pool
|
from catboost import CatBoostClassifier, Pool
|
||||||
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
|
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class CatboostClassifier(BaseClassifierModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:params:
|
:params:
|
||||||
@ -36,10 +37,7 @@ class CatboostClassifier(BaseClassifierModel):
|
|||||||
**self.model_training_parameters,
|
**self.model_training_parameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
init_model = self.get_init_model(dk.pair)
|
||||||
init_model = None
|
|
||||||
else:
|
|
||||||
init_model = self.dd.model_dictionary[dk.pair]
|
|
||||||
|
|
||||||
cbr.fit(train_data, init_model=init_model)
|
cbr.fit(train_data, init_model=init_model)
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from catboost import CatBoostRegressor, Pool
|
from catboost import CatBoostRegressor, Pool
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
|
||||||
|
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +17,7 @@ class CatboostRegressor(BaseRegressionModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
@ -39,10 +38,7 @@ class CatboostRegressor(BaseRegressionModel):
|
|||||||
weight=data_dictionary["test_weights"],
|
weight=data_dictionary["test_weights"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
init_model = self.get_init_model(dk.pair)
|
||||||
init_model = None
|
|
||||||
else:
|
|
||||||
init_model = self.dd.model_dictionary[dk.pair]
|
|
||||||
|
|
||||||
model = CatBoostRegressor(
|
model = CatBoostRegressor(
|
||||||
allow_writing_files=False,
|
allow_writing_files=False,
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
from catboost import CatBoostRegressor # , Pool
|
from catboost import CatBoostRegressor # , Pool
|
||||||
from sklearn.multioutput import MultiOutputRegressor
|
from sklearn.multioutput import MultiOutputRegressor
|
||||||
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
@ -3,8 +3,9 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
from lightgbm import LGBMClassifier
|
from lightgbm import LGBMClassifier
|
||||||
|
|
||||||
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class LightGBMClassifier(BaseClassifierModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:params:
|
:params:
|
||||||
@ -35,10 +36,7 @@ class LightGBMClassifier(BaseClassifierModel):
|
|||||||
y = data_dictionary["train_labels"].to_numpy()[:, 0]
|
y = data_dictionary["train_labels"].to_numpy()[:, 0]
|
||||||
train_weights = data_dictionary["train_weights"]
|
train_weights = data_dictionary["train_weights"]
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
init_model = self.get_init_model(dk.pair)
|
||||||
init_model = None
|
|
||||||
else:
|
|
||||||
init_model = self.dd.model_dictionary[dk.pair]
|
|
||||||
|
|
||||||
model = LGBMClassifier(**self.model_training_parameters)
|
model = LGBMClassifier(**self.model_training_parameters)
|
||||||
|
|
||||||
|
@ -3,8 +3,9 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
from lightgbm import LGBMRegressor
|
from lightgbm import LGBMRegressor
|
||||||
|
|
||||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class LightGBMRegressor(BaseRegressionModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Most regressors use the same function names and arguments e.g. user
|
Most regressors use the same function names and arguments e.g. user
|
||||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||||
@ -35,10 +36,7 @@ class LightGBMRegressor(BaseRegressionModel):
|
|||||||
y = data_dictionary["train_labels"]
|
y = data_dictionary["train_labels"]
|
||||||
train_weights = data_dictionary["train_weights"]
|
train_weights = data_dictionary["train_weights"]
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
init_model = self.get_init_model(dk.pair)
|
||||||
init_model = None
|
|
||||||
else:
|
|
||||||
init_model = self.dd.model_dictionary[dk.pair]
|
|
||||||
|
|
||||||
model = LGBMRegressor(**self.model_training_parameters)
|
model = LGBMRegressor(**self.model_training_parameters)
|
||||||
|
|
||||||
|
@ -4,8 +4,9 @@ from typing import Any, Dict
|
|||||||
from lightgbm import LGBMRegressor
|
from lightgbm import LGBMRegressor
|
||||||
from sklearn.multioutput import MultiOutputRegressor
|
from sklearn.multioutput import MultiOutputRegressor
|
||||||
|
|
||||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
|||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
Loading…
Reference in New Issue
Block a user