fix config example in pytorch mlp documentation
This commit is contained in:
parent
026b6a39a9
commit
b795a70102
@ -26,7 +26,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
|||||||
"trainer_kwargs": {
|
"trainer_kwargs": {
|
||||||
"max_iters": 5000,
|
"max_iters": 5000,
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"max_n_eval_batches": None,
|
"max_n_eval_batches": null,
|
||||||
},
|
},
|
||||||
"model_kwargs": {
|
"model_kwargs": {
|
||||||
"hidden_dim": 512,
|
"hidden_dim": 512,
|
||||||
@ -49,7 +49,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
|||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
all the training and test data/labels.
|
all the training and test data/labels.
|
||||||
:raises ValueError: If self.class_names is not defined in the parent class.
|
:raises ValueError: If self.class_names is not defined in the parent class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
|||||||
"trainer_kwargs": {
|
"trainer_kwargs": {
|
||||||
"max_iters": 5000,
|
"max_iters": 5000,
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"max_n_eval_batches": None,
|
"max_n_eval_batches": null,
|
||||||
},
|
},
|
||||||
"model_kwargs": {
|
"model_kwargs": {
|
||||||
"hidden_dim": 512,
|
"hidden_dim": 512,
|
||||||
@ -50,7 +50,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
|||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
all the training and test data/labels.
|
all the training and test data/labels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_features = data_dictionary["train_features"].shape[-1]
|
n_features = data_dictionary["train_features"].shape[-1]
|
||||||
|
Loading…
Reference in New Issue
Block a user