reformat code

This commit is contained in:
Yinon Polak 2023-03-06 17:50:02 +02:00
parent 348a08f1c4
commit e6e747bcd8
2 changed files with 16 additions and 14 deletions

View File

@ -69,7 +69,7 @@ class PyTorchModelTrainer:
self.model.eval() self.model.eval()
epochs = self.calc_n_epochs( epochs = self.calc_n_epochs(
n_obs=len(data_dictionary[f'test_features']), n_obs=len(data_dictionary['test_features']),
batch_size=self.batch_size, batch_size=self.batch_size,
n_iters=self.eval_iters n_iters=self.eval_iters
) )
@ -101,8 +101,11 @@ class PyTorchModelTrainer:
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
.long() .long()
.view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel .view(labels_view)
) )
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
# need to resolve it per ClassifierModel
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,
batch_size=self.batch_size, batch_size=self.batch_size,

View File

@ -24,7 +24,6 @@ from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.strategy.interface import IStrategy from freqtrade.strategy.interface import IStrategy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -90,8 +89,8 @@ class FreqaiDataDrawer:
self.metric_tracker_lock = threading.Lock() self.metric_tracker_lock = threading.Lock()
self.old_DBSCAN_eps: Dict[str, float] = {} self.old_DBSCAN_eps: Dict[str, float] = {}
self.empty_pair_dict: pair_info = { self.empty_pair_dict: pair_info = {
"model_filename": "", "trained_timestamp": 0, "model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}} "data_path": "", "extras": {}}
self.model_type = self.freqai_info.get('model_save_type', 'joblib') self.model_type = self.freqai_info.get('model_save_type', 'joblib')
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None: def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
@ -446,9 +445,9 @@ class FreqaiDataDrawer:
dump(model, save_path / f"{dk.model_filename}_model.joblib") dump(model, save_path / f"{dk.model_filename}_model.joblib")
elif self.model_type == 'keras': elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5") model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in self.model_type or\ elif ('stable_baselines' in self.model_type or
'sb3_contrib' == self.model_type or\ 'sb3_contrib' == self.model_type or
'pytorch' == self.model_type: 'pytorch' == self.model_type):
model.save(save_path / f"{dk.model_filename}_model.zip") model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None: if dk.svm_model is not None:
@ -581,16 +580,16 @@ class FreqaiDataDrawer:
if len(df_dp.index) == 0: if len(df_dp.index) == 0:
continue continue
if str(hist_df.iloc[-1]["date"]) == str( if str(hist_df.iloc[-1]["date"]) == str(
df_dp.iloc[-1:]["date"].iloc[-1] df_dp.iloc[-1:]["date"].iloc[-1]
): ):
continue continue
try: try:
index = ( index = (
df_dp.loc[ df_dp.loc[
df_dp["date"] == hist_df.iloc[-1]["date"] df_dp["date"] == hist_df.iloc[-1]["date"]
].index[0] ].index[0]
+ 1 + 1
) )
except IndexError: except IndexError:
if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]: if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]:
@ -643,7 +642,7 @@ class FreqaiDataDrawer:
) )
def get_base_and_corr_dataframes( def get_base_and_corr_dataframes(
self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen
) -> Tuple[Dict[Any, Any], Dict[Any, Any]]: ) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
""" """
Searches through our historic_data in memory and returns the dataframes relevant Searches through our historic_data in memory and returns the dataframes relevant