convert single quotes to double quotes

This commit is contained in:
Yinon Polak 2023-03-09 13:29:11 +02:00
parent 2ef11faba7
commit e88a0d5248
3 changed files with 16 additions and 15 deletions

View File

@ -19,9 +19,9 @@ class BasePyTorchModel(IFreqaiModel):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(config=kwargs['config']) super().__init__(config=kwargs["config"])
self.dd.model_type = 'pytorch' self.dd.model_type = "pytorch"
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = "cuda" if torch.cuda.is_available() else "cpu"
def train( def train(
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs

View File

@ -61,7 +61,7 @@ class PyTorchModelTrainer:
""" """
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary) data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
epochs = self.calc_n_epochs( epochs = self.calc_n_epochs(
n_obs=len(data_dictionary['train_features']), n_obs=len(data_dictionary["train_features"]),
batch_size=self.batch_size, batch_size=self.batch_size,
n_iters=self.max_iters n_iters=self.max_iters
) )
@ -73,7 +73,7 @@ class PyTorchModelTrainer:
f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}" f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
) )
# training # training
for batch_data in data_loaders_dictionary['train']: for batch_data in data_loaders_dictionary["train"]:
xb, yb = batch_data xb, yb = batch_data
xb = xb.to(self.device) xb = xb.to(self.device)
yb = yb.to(self.device) yb = yb.to(self.device)
@ -93,12 +93,12 @@ class PyTorchModelTrainer:
self.model.eval() self.model.eval()
epochs = self.calc_n_epochs( epochs = self.calc_n_epochs(
n_obs=len(data_dictionary['test_features']), n_obs=len(data_dictionary["test_features"]),
batch_size=self.batch_size, batch_size=self.batch_size,
n_iters=self.eval_iters n_iters=self.eval_iters
) )
loss_dictionary = {} loss_dictionary = {}
for split in ['train', 'test']: for split in ["train", "test"]:
losses = torch.zeros(epochs) losses = torch.zeros(epochs)
for i, batch in enumerate(data_loader_dictionary[split]): for i, batch in enumerate(data_loader_dictionary[split]):
xb, yb = batch xb, yb = batch
@ -121,12 +121,12 @@ class PyTorchModelTrainer:
Converts the input data to PyTorch tensors using a data loader. Converts the input data to PyTorch tensors using a data loader.
""" """
data_loader_dictionary = {} data_loader_dictionary = {}
for split in ['train', 'test']: for split in ["train", "test"]:
labels_shape = data_dictionary[f'{split}_labels'].shape labels_shape = data_dictionary[f"{split}_labels"].shape
labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape
dataset = TensorDataset( dataset = TensorDataset(
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) torch.from_numpy(data_dictionary[f"{split}_labels"].astype(float).values)
.long() .long()
.view(labels_view) .view(labels_view)
) )
@ -148,6 +148,7 @@ class PyTorchModelTrainer:
Calculates the number of epochs required to reach the maximum number Calculates the number of epochs required to reach the maximum number
of iterations specified in the model training parameters. of iterations specified in the model training parameters.
""" """
n_batches = n_obs // batch_size n_batches = n_obs // batch_size
epochs = n_iters // n_batches epochs = n_iters // n_batches
return epochs return epochs
@ -160,9 +161,9 @@ class PyTorchModelTrainer:
""" """
torch.save({ torch.save({
'model_state_dict': self.model.state_dict(), "model_state_dict": self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(),
'model_meta_data': self.model_meta_data, "model_meta_data": self.model_meta_data,
}, path) }, path)
def load_from_file(self, path: Path): def load_from_file(self, path: Path):

View File

@ -59,7 +59,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
self.init_class_names_to_index_mapping(self.class_names) self.init_class_names_to_index_mapping(self.class_names)
self.encode_classes_name(data_dictionary, dk) self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1] n_features = data_dictionary["train_features"].shape[-1]
model = PyTorchMLPModel( model = PyTorchMLPModel(
input_dim=n_features, input_dim=n_features,
hidden_dim=self.n_hidden, hidden_dim=self.n_hidden,