diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 0ca28d2e9..99ee44e3b 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional import pandas as pd import torch @@ -21,7 +21,7 @@ class PyTorchModelTrainer: device: str, batch_size: int, max_iters: int, - eval_iters: int, + max_n_eval_batches: int, init_model: Dict, model_meta_data: Dict[str, Any] = {}, ): @@ -34,7 +34,7 @@ class PyTorchModelTrainer: :param max_iters: The number of training iterations to run. iteration here refers to the number of times we call self.optimizer.step(). used to calculate n_epochs. - :param eval_iters: The number of iterations used to estimate the loss. + :param max_n_eval_batches: The maximum number batches to use for evaluation. :param init_model: A dictionary containing the initial model/optimizer state_dict and model_meta_data saved by self.save() method. :param model_meta_data: Additional metadata about the model (optional). @@ -46,7 +46,7 @@ class PyTorchModelTrainer: self.device = device self.max_iters = max_iters self.batch_size = batch_size - self.eval_iters = eval_iters + self.max_n_eval_batches = max_n_eval_batches if init_model: self.load_from_checkpoint(init_model) @@ -67,7 +67,7 @@ class PyTorchModelTrainer: ) for epoch in range(epochs): # evaluation - losses = self.estimate_loss(data_loaders_dictionary, data_dictionary) + losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches) logger.info( f"epoch ({epoch}/{epochs}):" f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}" @@ -88,27 +88,27 @@ class PyTorchModelTrainer: def estimate_loss( self, data_loader_dictionary: Dict[str, DataLoader], - data_dictionary: Dict[str, pd.DataFrame] + max_n_eval_batches: Optional[int] ) -> Dict[str, float]: self.model.eval() - epochs = self.calc_n_epochs( - n_obs=len(data_dictionary["test_features"]), - batch_size=self.batch_size, - n_iters=self.eval_iters - ) loss_dictionary = {} + n_batches = 0 for split in ["train", "test"]: - losses = torch.zeros(epochs) + losses = [] for i, batch in enumerate(data_loader_dictionary[split]): + if max_n_eval_batches and i > max_n_eval_batches: + n_batches += 1 + break + xb, yb = batch xb = xb.to(self.device) yb = yb.to(self.device) yb_pred = self.model(xb) loss = self.criterion(yb_pred, yb) - losses[i] = loss.item() + losses.append(loss.item()) - loss_dictionary[split] = losses.mean().item() + loss_dictionary[split] = sum(losses) / len(losses) self.model.train() return loss_dictionary diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index a5b8b1591..f951778bf 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -39,7 +39,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): self.max_iters = model_training_parameters.get("max_iters", 100) self.batch_size = model_training_parameters.get("batch_size", 64) self.learning_rate = model_training_parameters.get("learning_rate", 3e-4) - self.eval_iters = model_training_parameters.get("eval_iters", 10) + self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None) self.class_name_to_index = None self.index_to_class_name = None @@ -79,7 +79,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): device=self.device, batch_size=self.batch_size, max_iters=self.max_iters, - eval_iters=self.eval_iters, + max_n_eval_batches=self.max_n_eval_batches, init_model=init_model ) trainer.fit(data_dictionary)