fix model loading from disk bug, improve doc, clarify installation/docker instructions, add a torch tag to the freqairl docker image. Fix seriously outdated prediction_model docstrings
This commit is contained in:
		| @@ -539,7 +539,9 @@ class FreqaiDataDrawer: | ||||
|             model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model") | ||||
|         elif self.model_type == 'pytorch': | ||||
|             import torch | ||||
|             model = torch.load(dk.data_path / f"{dk.model_filename}_model.zip") | ||||
|             zip = torch.load(dk.data_path / f"{dk.model_filename}_model.zip") | ||||
|             model = zip["pytrainer"] | ||||
|             model = model.load_from_checkpoint(zip) | ||||
|  | ||||
|         if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file(): | ||||
|             dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib") | ||||
|   | ||||
| @@ -14,16 +14,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class CatboostClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         train_data = Pool( | ||||
|   | ||||
| @@ -15,16 +15,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class CatboostClassifierMultiTarget(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         cbc = CatBoostClassifier( | ||||
|   | ||||
| @@ -14,16 +14,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class CatboostRegressor(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         train_data = Pool( | ||||
|   | ||||
| @@ -15,16 +15,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class CatboostRegressorMultiTarget(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         cbr = CatBoostRegressor( | ||||
|   | ||||
| @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class LightGBMClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: | ||||
|   | ||||
| @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class LightGBMClassifierMultiTarget(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         lgb = LGBMClassifier(**self.model_training_parameters) | ||||
|   | ||||
| @@ -12,18 +12,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class LightGBMRegressor(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         Most regressors use the same function names and arguments e.g. user | ||||
|         can drop in LGBMRegressor in place of CatBoostRegressor and all data | ||||
|         management will be properly handled by Freqai. | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: | ||||
|   | ||||
| @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class LightGBMRegressorMultiTarget(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         lgb = LGBMRegressor(**self.model_training_parameters) | ||||
|   | ||||
| @@ -57,8 +57,9 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|         all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         :raises ValueError: If self.class_names is not defined in the parent class. | ||||
|         """ | ||||
|  | ||||
|   | ||||
| @@ -55,8 +55,9 @@ class PyTorchMLPRegressor(BasePyTorchRegressor): | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|         all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         n_features = data_dictionary["train_features"].shape[-1] | ||||
|   | ||||
| @@ -18,16 +18,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class XGBoostClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         X = data_dictionary["train_features"].to_numpy() | ||||
|   | ||||
| @@ -18,16 +18,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class XGBoostRFClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|             all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         X = data_dictionary["train_features"].to_numpy() | ||||
|   | ||||
| @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class XGBoostRFRegressor(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         X = data_dictionary["train_features"] | ||||
|   | ||||
| @@ -12,16 +12,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class XGBoostRegressor(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         X = data_dictionary["train_features"] | ||||
|   | ||||
| @@ -13,16 +13,20 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| class XGBoostRegressorMultiTarget(BaseRegressionModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|     has its own DataHandler where data is held, saved, loaded, and managed. | ||||
|     User created prediction model. The class inherits IFreqaiModel, which | ||||
|     means it has full access to all Frequency AI functionality. Typically, | ||||
|     users would use this to override the common `fit()`, `train()`, or | ||||
|     `predict()` methods to add their custom data handling tools or change | ||||
|     various aspects of the training that cannot be configured via the | ||||
|     top level config.json file. | ||||
|     """ | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :param data_dictionary: the dictionary holding all data for train, test, | ||||
|             labels, weights | ||||
|         :param dk: The datakitchen object for the current coin/model | ||||
|         """ | ||||
|  | ||||
|         xgb = XGBRegressor(**self.model_training_parameters) | ||||
|   | ||||
| @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional | ||||
|  | ||||
| import pandas as pd | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch import nn | ||||
| from torch.optim import Optimizer | ||||
| from torch.utils.data import DataLoader, TensorDataset | ||||
|  | ||||
| @@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): | ||||
|  | ||||
|         n_batches = math.ceil(n_obs // batch_size) | ||||
|         epochs = math.ceil(n_iters // n_batches) | ||||
|         if epochs <= 10: | ||||
|             logger.warning("User set `max_iters` in such a way that the trainer will only perform " | ||||
|                            f" {epochs} epochs. Please consider increasing this value accordingly") | ||||
|             if epochs <= 1: | ||||
|                 logger.warning("Epochs set to 1. Please review your `max_iters` value") | ||||
|                 epochs = 1 | ||||
|         return epochs | ||||
|  | ||||
|     def save(self, path: Path): | ||||
| @@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): | ||||
|             "model_state_dict": self.model.state_dict(), | ||||
|             "optimizer_state_dict": self.optimizer.state_dict(), | ||||
|             "model_meta_data": self.model_meta_data, | ||||
|             "pytrainer": self | ||||
|         }, path) | ||||
|  | ||||
|     def load(self, path: Path): | ||||
| @@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): | ||||
|         you can access this dict from any class that inherits IFreqaiModel by calling | ||||
|         get_init_model method. | ||||
|         """ | ||||
|  | ||||
|         self.model.load_state_dict(checkpoint["model_state_dict"]) | ||||
|         self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) | ||||
|         self.model_meta_data = checkpoint["model_meta_data"] | ||||
|   | ||||
| @@ -4,7 +4,7 @@ from typing import Dict, List | ||||
|  | ||||
| import pandas as pd | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch import nn | ||||
|  | ||||
|  | ||||
| class PyTorchTrainerInterface(ABC): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user