move default attributes of pytorch classifier to initializer,
to prevent mypy from complaining
@@ -41,12 +41,18 @@ class PyTorchMLPClassifier(PyTorchClassifier):
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(
+            self,
+            learning_rate: float = 3e-4,
+            model_kwargs: Dict[str, Any] = {},
+            trainer_kwargs: Dict[str, Any] = {},
+            **kwargs
+    ):
         super().__init__(**kwargs)
-        model_training_params = self.freqai_info.get("model_training_parameters", {})
-        self.learning_rate: float = model_training_params.get("learning_rate", 3e-4)
-        self.model_kwargs: Dict[str, any] = model_training_params.get("model_kwargs", {})
-        self.trainer_kwargs: Dict[str, any] = model_training_params.get("trainer_kwargs", {})
+        config = self.freqai_info.get("model_training_parameters", {})
+        self.learning_rate: float = config.get("learning_rate", learning_rate)
+        self.model_kwargs: Dict[str, any] = config.get("model_kwargs", model_kwargs)
+        self.trainer_kwargs: Dict[str, any] = config.get("trainer_kwargs", trainer_kwargs)
 
     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
@@ -41,12 +41,18 @@ class PyTorchMLPRegressor(PyTorchRegressor):
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(
+            self,
+            learning_rate: float = 3e-4,
+            model_kwargs: Dict[str, Any] = {},
+            trainer_kwargs: Dict[str, Any] = {},
+            **kwargs
+    ):
         super().__init__(**kwargs)
-        model_training_params = self.freqai_info.get("model_training_parameters", {})
-        self.learning_rate: float = model_training_params.get("learning_rate", 3e-4)
-        self.model_kwargs: Dict[str, any] = model_training_params.get("model_kwargs", {})
-        self.trainer_kwargs: Dict[str, any] = model_training_params.get("trainer_kwargs", {})
+        config = self.freqai_info.get("model_training_parameters", {})
+        self.learning_rate: float = config.get("learning_rate", learning_rate)
+        self.model_kwargs: Dict[str, any] = config.get("model_kwargs", model_kwargs)
+        self.trainer_kwargs: Dict[str, any] = config.get("trainer_kwargs", trainer_kwargs)
 
     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
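For context, the sketch below is a minimal, standalone illustration (not the actual freqtrade classes) of the precedence this change introduces: a value found in "model_training_parameters" in the freqai config still wins, and the new keyword defaults are only used as fallbacks. The freqai_info attribute is stubbed here so the snippet runs on its own, and it uses typing.Any (capitalized), since the lowercase builtin any shown in the annotations above is not a valid type annotation and would itself be rejected by mypy.

from typing import Any, Dict


class SketchModel:
    """Standalone stand-in for PyTorchMLPClassifier/Regressor (illustration only)."""

    def __init__(
            self,
            learning_rate: float = 3e-4,
            # Mutable defaults mirror the diff above; they are only read here, never mutated.
            model_kwargs: Dict[str, Any] = {},
            trainer_kwargs: Dict[str, Any] = {},
            **kwargs
    ):
        # In freqtrade, freqai_info comes from the bot config via the base class;
        # it is stubbed from kwargs so this sketch is self-contained.
        self.freqai_info: Dict[str, Any] = kwargs.get("freqai_info", {})
        config = self.freqai_info.get("model_training_parameters", {})
        # Config values take precedence; the keyword arguments are fallbacks.
        self.learning_rate: float = config.get("learning_rate", learning_rate)
        self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", model_kwargs)
        self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", trainer_kwargs)


model = SketchModel(freqai_info={"model_training_parameters": {"learning_rate": 1e-3}})
print(model.learning_rate)   # 1e-3 -> taken from model_training_parameters
print(model.trainer_kwargs)  # {}   -> falls back to the keyword default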