refactor classifiers class names
This commit is contained in:
parent
501e746c52
commit
601c37f862
@ -4,11 +4,11 @@ import torch
|
|||||||
|
|
||||||
from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
|
from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.PyTorchClassifierClassifier import PyTorchClassifier
|
from freqtrade.freqai.prediction_models.PyTorchClassifier import PyTorchClassifier
|
||||||
from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
|
from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
|
||||||
|
|
||||||
|
|
||||||
class MLPPyTorchClassifier(PyTorchClassifier):
|
class PyTorchMLPClassifier(PyTorchClassifier):
|
||||||
"""
|
"""
|
||||||
This class implements the fit method of IFreqaiModel.
|
This class implements the fit method of IFreqaiModel.
|
||||||
int the fit method we initialize the model and trainer objects.
|
int the fit method we initialize the model and trainer objects.
|
@ -49,8 +49,8 @@ class PyTorchMLPModel(nn.Module):
|
|||||||
x = self.relu(self.input_layer(x))
|
x = self.relu(self.input_layer(x))
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.blocks(x)
|
x = self.blocks(x)
|
||||||
logits = self.output_layer(x)
|
x = self.output_layer(x)
|
||||||
return logits
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user