diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 4a091f52c..a934814ef 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -26,29 +26,6 @@ class PyTorchModelTrainer: model_meta_data: Dict[str, Any] = {}, ): """ - A class for training PyTorch models. - Implements the training loop logic, load/save methods. - - fit method - training loop logic: - - Calculates the predicted output for the batch using the PyTorch model. - - Calculates the loss between the predicted and actual output using a loss function. - - Computes the gradients of the loss with respect to the model's parameters using - backpropagation. - - Updates the model's parameters using an optimizer. - - save method: - called by DataDrawer - - Saving any nn.Module state_dict - - Saving model_meta_data, this dict should contain any additional data that the - user needs to store. e.g class_names for classification models. - - load method: - currently DataDrawer is responsible for the actual loading. - when using continual_learning the DataDrawer will load the dict - (saved by self.save(path)). and this class will populate the necessary - state_dict of the self.model & self.optimizer and self.model_meta_data. - - :param model: The PyTorch model to be trained. :param optimizer: The optimizer to use for training. :param criterion: The loss function to use for training. @@ -76,7 +53,11 @@ class PyTorchModelTrainer: def fit(self, data_dictionary: Dict[str, pd.DataFrame]): """ - General training loop. + - Calculates the predicted output for the batch using the PyTorch model. + - Calculates the loss between the predicted and actual output using a loss function. + - Computes the gradients of the loss with respect to the model's parameters using + backpropagation. + - Updates the model's parameters using an optimizer. """ data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary) epochs = self.calc_n_epochs( @@ -172,6 +153,12 @@ class PyTorchModelTrainer: return epochs def save(self, path: Path): + """ + - Saving any nn.Module state_dict + - Saving model_meta_data, this dict should contain any additional data that the + user needs to store. e.g class_names for classification models. + """ + torch.save({ 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), @@ -183,7 +170,14 @@ class PyTorchModelTrainer: return self.load_from_checkpoint(checkpoint) def load_from_checkpoint(self, checkpoint: Dict): - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + """ + when using continual_learning, DataDrawer will load the dictionary + (containing state dicts and model_meta_data) by calling torch.load(path). + you can access this dict from any class that inherits IFreqaiModel by calling + get_init_model method. + """ + + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) self.model_meta_data = checkpoint["model_meta_data"] return self