diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 99ee44e3b..1b328f4fe 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -66,14 +66,9 @@ class PyTorchModelTrainer: n_iters=self.max_iters ) for epoch in range(epochs): - # evaluation - losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches) - logger.info( - f"epoch ({epoch}/{epochs}):" - f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}" - ) # training - for batch_data in data_loaders_dictionary["train"]: + losses = [] + for i, batch_data in enumerate(data_loaders_dictionary["train"]): xb, yb = batch_data xb = xb.to(self.device) yb = yb.to(self.device) @@ -83,35 +78,40 @@ class PyTorchModelTrainer: self.optimizer.zero_grad(set_to_none=True) loss.backward() self.optimizer.step() + losses.append(loss.item()) + train_loss = sum(losses) / len(losses) + + # evaluation + test_loss = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches, "test") + logger.info( + f"epoch ({epoch}/{epochs}):" + f" train loss {train_loss:.4f} ; test loss {test_loss:.4f}" + ) @torch.no_grad() def estimate_loss( self, data_loader_dictionary: Dict[str, DataLoader], - max_n_eval_batches: Optional[int] - ) -> Dict[str, float]: - + max_n_eval_batches: Optional[int], + split: str, + ) -> float: self.model.eval() - loss_dictionary = {} n_batches = 0 - for split in ["train", "test"]: - losses = [] - for i, batch in enumerate(data_loader_dictionary[split]): - if max_n_eval_batches and i > max_n_eval_batches: - n_batches += 1 - break + losses = [] + for i, batch in enumerate(data_loader_dictionary[split]): + if max_n_eval_batches and i > max_n_eval_batches: + n_batches += 1 + break - xb, yb = batch - xb = xb.to(self.device) - yb = yb.to(self.device) - yb_pred = self.model(xb) - loss = self.criterion(yb_pred, yb) - losses.append(loss.item()) - - loss_dictionary[split] = sum(losses) / len(losses) + xb, yb = batch + xb = xb.to(self.device) + yb = yb.to(self.device) + yb_pred = self.model(xb) + loss = self.criterion(yb_pred, yb) + losses.append(loss.item()) self.model.train() - return loss_dictionary + return sum(losses) / len(losses) def create_data_loaders_dictionary( self,