revert to using model_training_parameters
This commit is contained in:
		| @@ -34,13 +34,15 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|         """ | ||||
|  | ||||
|         super().__init__(**kwargs) | ||||
|         trainer_kwargs = self.freqai_info.get("trainer_kwargs", {}) | ||||
|         self.n_hidden: int = trainer_kwargs.get("n_hidden", 1024) | ||||
|         self.max_iters: int = trainer_kwargs.get("max_iters", 100) | ||||
|         self.batch_size: int = trainer_kwargs.get("batch_size", 64) | ||||
|         self.learning_rate: float = trainer_kwargs.get("learning_rate", 3e-4) | ||||
|         self.max_n_eval_batches: Optional[int] = trainer_kwargs.get("max_n_eval_batches", None) | ||||
|         self.model_kwargs: Dict = trainer_kwargs.get("model_kwargs", {}) | ||||
|         model_training_params = self.freqai_info.get("model_training_parameters", {}) | ||||
|         self.n_hidden: int = model_training_params.get("n_hidden", 1024) | ||||
|         self.max_iters: int = model_training_params.get("max_iters", 100) | ||||
|         self.batch_size: int = model_training_params.get("batch_size", 64) | ||||
|         self.learning_rate: float = model_training_params.get("learning_rate", 3e-4) | ||||
|         self.max_n_eval_batches: Optional[int] = model_training_params.get( | ||||
|             "max_n_eval_batches", None | ||||
|         ) | ||||
|         self.model_kwargs: Dict = model_training_params.get("model_kwargs", {}) | ||||
|         self.class_name_to_index = None | ||||
|         self.index_to_class_name = None | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user