diff --git a/build_helpers/publish_docker_multi.sh b/build_helpers/publish_docker_multi.sh index 3e5e61564..6c5d11d94 100755 --- a/build_helpers/publish_docker_multi.sh +++ b/build_helpers/publish_docker_multi.sh @@ -7,6 +7,7 @@ TAG=$(echo "${BRANCH_NAME}" | sed -e "s/\//_/g") TAG_PLOT=${TAG}_plot TAG_FREQAI=${TAG}_freqai TAG_FREQAI_RL=${TAG_FREQAI}rl +TAG_FREQAI_RL=${TAG_FREQAI}torch TAG_PI="${TAG}_pi" PI_PLATFORM="linux/arm/v7" @@ -64,6 +65,7 @@ docker build --cache-from freqtrade:${TAG_FREQAI} --build-arg sourceimage=${CACH docker tag freqtrade:$TAG_PLOT ${CACHE_IMAGE}:$TAG_PLOT docker tag freqtrade:$TAG_FREQAI ${CACHE_IMAGE}:$TAG_FREQAI docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_RL +docker tag freqtrade:$TAG_FREQAI_RL ${CACHE_IMAGE}:$TAG_FREQAI_TORCH # Run backtest docker run --rm -v $(pwd)/config_examples/config_bittrex.example.json:/freqtrade/config.json:ro -v $(pwd)/tests:/tests freqtrade:${TAG} backtesting --datadir /tests/testdata --strategy-path /tests/strategy/strats/ --strategy StrategyTestV3 diff --git a/docs/freqai-configuration.md b/docs/freqai-configuration.md index 442705b53..8f1aa5079 100644 --- a/docs/freqai-configuration.md +++ b/docs/freqai-configuration.md @@ -237,7 +237,7 @@ df['&s-up_or_down'] = np.where( df["close"].shift(-100) > df["close"], 'up', 'do df['&s-up_or_down'] = np.where( df["close"].shift(-100) == df["close"], 'same', df['&s-up_or_down']) ``` -## PyTorch Models +## PyTorch Module ### Quick start @@ -247,14 +247,16 @@ The easiest way to quickly run a pytorch model is with the following command (fo freqtrade trade --config config_examples/config_freqai.example.json --strategy FreqaiExampleStrategy --freqaimodel PyTorchMLPRegressor --strategy-path freqtrade/templates ``` +!!! note "Installation/docker" + The PyTorch module requires large packages such as `torch`, which should be explicitly requested during `./setup.sh -i` by answering "y" to the question "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]?". + Users who prefer docker should ensure they use the docker image appended with `_freqaitorch`. + ### Structure #### Model -You can use any pytorch model. Here is an example of logistic regression model implementation using pytorch (should be used with nn.BCELoss criterion) for classification tasks. +You can construct your own Neural Network architecture in PyTorch by simply defining your `nn.Module` class inside your custom [`IFreqaiModel` file](#using-different-prediction-models) and then using that class in your `def train()` function. Here is an example of logistic regression model implementation using PyTorch (should be used with nn.BCELoss criterion) for classification tasks. ```python -import torch.nn as nn -import torch class LogisticRegression(nn.Module): def __init__(self, input_size: int): @@ -268,11 +270,59 @@ class LogisticRegression(nn.Module): out = self.linear(x) out = self.activation(out) return out + +class MyCoolPyTorchClassifier(BasePyTorchClassifier): + """ + This is a custom IFreqaiModel showing how a user might setup their own + custom Neural Network architecture for their training. + """ + + @property + def data_convertor(self) -> PyTorchDataConvertor: + return DefaultPyTorchDataConvertor(target_tensor_type=torch.float) + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + config = self.freqai_info.get("model_training_parameters", {}) + self.learning_rate: float = config.get("learning_rate", 3e-4) + self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {}) + self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {}) + + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: + """ + User sets up the training and test data to fit their desired model here + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model + """ + + class_names = self.get_class_names() + self.convert_label_column_to_int(data_dictionary, dk, class_names) + n_features = data_dictionary["train_features"].shape[-1] + model = LogisticRegression( + input_dim=n_features + ) + model.to(self.device) + optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) + criterion = torch.nn.CrossEntropyLoss() + init_model = self.get_init_model(dk.pair) + trainer = PyTorchModelTrainer( + model=model, + optimizer=optimizer, + criterion=criterion, + model_meta_data={"class_names": class_names}, + device=self.device, + init_model=init_model, + data_convertor=self.data_convertor, + **self.trainer_kwargs, + ) + trainer.fit(data_dictionary, self.splits) + return trainer + ``` - #### Trainer -The `PyTorchModelTrainer` performs the idiomatic pytorch train loop: +The `PyTorchModelTrainer` performs the idiomatic PyTorch train loop: Define our model, loss function, and optimizer, and then move them to the appropriate device (GPU or CPU). Inside the loop, we iterate through the batches in the dataloader, move the data to the device, compute the prediction and loss, backpropagate, and update the model parameters using the optimizer. In addition, the trainer is responsible for the following: diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index c8dadb171..b68a9dcad 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -539,7 +539,9 @@ class FreqaiDataDrawer: model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model") elif self.model_type == 'pytorch': import torch - model = torch.load(dk.data_path / f"{dk.model_filename}_model.zip") + zip = torch.load(dk.data_path / f"{dk.model_filename}_model.zip") + model = zip["pytrainer"] + model = model.load_from_checkpoint(zip) if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file(): dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib") diff --git a/freqtrade/freqai/prediction_models/CatboostClassifier.py b/freqtrade/freqai/prediction_models/CatboostClassifier.py index ca1d8ece0..b9904e40d 100644 --- a/freqtrade/freqai/prediction_models/CatboostClassifier.py +++ b/freqtrade/freqai/prediction_models/CatboostClassifier.py @@ -14,16 +14,20 @@ logger = logging.getLogger(__name__) class CatboostClassifier(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ train_data = Pool( diff --git a/freqtrade/freqai/prediction_models/CatboostClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostClassifierMultiTarget.py index c6f900fad..58c47566a 100644 --- a/freqtrade/freqai/prediction_models/CatboostClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostClassifierMultiTarget.py @@ -15,16 +15,20 @@ logger = logging.getLogger(__name__) class CatboostClassifierMultiTarget(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ cbc = CatBoostClassifier( diff --git a/freqtrade/freqai/prediction_models/CatboostRegressor.py b/freqtrade/freqai/prediction_models/CatboostRegressor.py index 4b17a703b..28b1b11cc 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressor.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressor.py @@ -14,16 +14,20 @@ logger = logging.getLogger(__name__) class CatboostRegressor(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ train_data = Pool( diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index 976d0b29b..1562c2024 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -15,16 +15,20 @@ logger = logging.getLogger(__name__) class CatboostRegressorMultiTarget(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ cbr = CatBoostRegressor( diff --git a/freqtrade/freqai/prediction_models/LightGBMClassifier.py b/freqtrade/freqai/prediction_models/LightGBMClassifier.py index e467ad3c1..45f3a31d0 100644 --- a/freqtrade/freqai/prediction_models/LightGBMClassifier.py +++ b/freqtrade/freqai/prediction_models/LightGBMClassifier.py @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) class LightGBMClassifier(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: diff --git a/freqtrade/freqai/prediction_models/LightGBMClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMClassifierMultiTarget.py index d1eb6daa2..72a8ee259 100644 --- a/freqtrade/freqai/prediction_models/LightGBMClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMClassifierMultiTarget.py @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) class LightGBMClassifierMultiTarget(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ lgb = LGBMClassifier(**self.model_training_parameters) diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressor.py b/freqtrade/freqai/prediction_models/LightGBMRegressor.py index 85c9b691c..3d1c30ed3 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressor.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressor.py @@ -12,18 +12,20 @@ logger = logging.getLogger(__name__) class LightGBMRegressor(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ - Most regressors use the same function names and arguments e.g. user - can drop in LGBMRegressor in place of CatBoostRegressor and all data - management will be properly handled by Freqai. - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + User sets up the training and test data to fit their desired model here + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 37c6bb186..663a611f0 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) class LightGBMRegressorMultiTarget(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ lgb = LGBMRegressor(**self.model_training_parameters) diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index 8694453be..ea7981405 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -57,8 +57,9 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model :raises ValueError: If self.class_names is not defined in the parent class. """ diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py index 5ca3486e1..64f0f4b03 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py @@ -55,8 +55,9 @@ class PyTorchMLPRegressor(BasePyTorchRegressor): def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ n_features = data_dictionary["train_features"].shape[-1] diff --git a/freqtrade/freqai/prediction_models/XGBoostClassifier.py b/freqtrade/freqai/prediction_models/XGBoostClassifier.py index 67c7c7783..b6f04b497 100644 --- a/freqtrade/freqai/prediction_models/XGBoostClassifier.py +++ b/freqtrade/freqai/prediction_models/XGBoostClassifier.py @@ -18,16 +18,20 @@ logger = logging.getLogger(__name__) class XGBoostClassifier(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ X = data_dictionary["train_features"].to_numpy() diff --git a/freqtrade/freqai/prediction_models/XGBoostRFClassifier.py b/freqtrade/freqai/prediction_models/XGBoostRFClassifier.py index 470c283ea..20156e9fd 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRFClassifier.py +++ b/freqtrade/freqai/prediction_models/XGBoostRFClassifier.py @@ -18,16 +18,20 @@ logger = logging.getLogger(__name__) class XGBoostRFClassifier(BaseClassifierModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ X = data_dictionary["train_features"].to_numpy() diff --git a/freqtrade/freqai/prediction_models/XGBoostRFRegressor.py b/freqtrade/freqai/prediction_models/XGBoostRFRegressor.py index e7cc27f2e..1aefbf19a 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRFRegressor.py +++ b/freqtrade/freqai/prediction_models/XGBoostRFRegressor.py @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) class XGBoostRFRegressor(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ X = data_dictionary["train_features"] diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressor.py b/freqtrade/freqai/prediction_models/XGBoostRegressor.py index 9a280286b..93dfb319e 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRegressor.py +++ b/freqtrade/freqai/prediction_models/XGBoostRegressor.py @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) class XGBoostRegressor(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ X = data_dictionary["train_features"] diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py index 920745ec9..a0330485e 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) class XGBoostRegressorMultiTarget(BaseRegressionModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. """ def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param data_dictionary: the dictionary constructed by DataHandler to hold - all the training and test data/labels. + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model """ xgb = XGBRegressor(**self.model_training_parameters) diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index 6449d98b5..9c1a1cb6e 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional import pandas as pd import torch -import torch.nn as nn +from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, TensorDataset @@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): n_batches = math.ceil(n_obs // batch_size) epochs = math.ceil(n_iters // n_batches) + if epochs <= 10: + logger.warning("User set `max_iters` in such a way that the trainer will only perform " + f" {epochs} epochs. Please consider increasing this value accordingly") + if epochs <= 1: + logger.warning("Epochs set to 1. Please review your `max_iters` value") + epochs = 1 return epochs def save(self, path: Path): @@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): "model_state_dict": self.model.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), "model_meta_data": self.model_meta_data, + "pytrainer": self }, path) def load(self, path: Path): @@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): you can access this dict from any class that inherits IFreqaiModel by calling get_init_model method. """ - self.model.load_state_dict(checkpoint["model_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) self.model_meta_data = checkpoint["model_meta_data"] diff --git a/freqtrade/freqai/torch/PyTorchTrainerInterface.py b/freqtrade/freqai/torch/PyTorchTrainerInterface.py index 6686555f9..840c145f7 100644 --- a/freqtrade/freqai/torch/PyTorchTrainerInterface.py +++ b/freqtrade/freqai/torch/PyTorchTrainerInterface.py @@ -4,7 +4,7 @@ from typing import Dict, List import pandas as pd import torch -import torch.nn as nn +from torch import nn class PyTorchTrainerInterface(ABC): diff --git a/setup.sh b/setup.sh index a9ff36536..77c77000d 100755 --- a/setup.sh +++ b/setup.sh @@ -85,7 +85,7 @@ function updateenv() { if [[ $REPLY =~ ^[Yy]$ ]] then REQUIREMENTS_FREQAI="-r requirements-freqai.txt --use-pep517" - read -p "Do you also want dependencies for freqai-rl (~700mb additional space required) [y/N]? " + read -p "Do you also want dependencies for freqai-rl or PyTorch (~700mb additional space required) [y/N]? " if [[ $REPLY =~ ^[Yy]$ ]] then REQUIREMENTS_FREQAI="-r requirements-freqai-rl.txt"