refactor classifiers class names
This commit is contained in:
parent
501e746c52
commit
601c37f862
@ -4,11 +4,11 @@ import torch
|
||||
|
||||
from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.PyTorchClassifierClassifier import PyTorchClassifier
|
||||
from freqtrade.freqai.prediction_models.PyTorchClassifier import PyTorchClassifier
|
||||
from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
|
||||
|
||||
|
||||
class MLPPyTorchClassifier(PyTorchClassifier):
|
||||
class PyTorchMLPClassifier(PyTorchClassifier):
|
||||
"""
|
||||
This class implements the fit method of IFreqaiModel.
|
||||
int the fit method we initialize the model and trainer objects.
|
@ -49,8 +49,8 @@ class PyTorchMLPModel(nn.Module):
|
||||
x = self.relu(self.input_layer(x))
|
||||
x = self.dropout(x)
|
||||
x = self.blocks(x)
|
||||
logits = self.output_layer(x)
|
||||
return logits
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user