diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 6a4b128e3..2ef4b57c9 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -22,6 +22,7 @@ class PyTorchModelTrainer: device: str, init_model: Dict, target_tensor_type: torch.dtype, + squeeze_target_tensor: bool = False, model_meta_data: Dict[str, Any] = {}, **kwargs ): @@ -35,11 +36,14 @@ class PyTorchModelTrainer: :param target_tensor_type: type of target tensor, for classification usually torch.long, for regressor usually torch.float. :param model_meta_data: Additional metadata about the model (optional). + :param squeeze_target_tensor: controls the target shape, used for loss functions + that requires 0D or 1D. :param max_iters: The number of training iterations to run. iteration here refers to the number of times we call self.optimizer.step(). used to calculate n_epochs. :param batch_size: The size of the batches to use during training. :param max_n_eval_batches: The maximum number batches to use for evaluation. + """ self.model = model self.optimizer = optimizer @@ -50,6 +54,7 @@ class PyTorchModelTrainer: self.max_iters: int = kwargs.get("max_iters", 100) self.batch_size: int = kwargs.get("batch_size", 64) self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) + self.squeeze_target_tensor = squeeze_target_tensor if init_model: self.load_from_checkpoint(init_model) @@ -124,15 +129,14 @@ class PyTorchModelTrainer: """ data_loader_dictionary = {} for split in ["train", "test"]: - labels_shape = data_dictionary[f"{split}_labels"].shape - labels_view = (labels_shape[0], 1) if labels_shape[1] == 1 else labels_shape - dataset = TensorDataset( - torch.from_numpy(data_dictionary[f"{split}_features"].values).float(), - torch.from_numpy(data_dictionary[f"{split}_labels"].values) + x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float() + y = torch.from_numpy(data_dictionary[f"{split}_labels"].values)\ .to(self.target_tensor_type) - .view(labels_view) - ) + if self.squeeze_target_tensor: + y = y.squeeze() + + dataset = TensorDataset(x, y) data_loader = DataLoader( dataset, batch_size=self.batch_size, diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index e26b8b52c..b8f2df28b 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier): device=self.device, init_model=init_model, target_tensor_type=torch.long, + squeeze_target_tensor=True, **self.trainer_kwargs, ) trainer.fit(data_dictionary)