reformat code
This commit is contained in:
parent
348a08f1c4
commit
e6e747bcd8
@ -69,7 +69,7 @@ class PyTorchModelTrainer:
|
|||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
epochs = self.calc_n_epochs(
|
epochs = self.calc_n_epochs(
|
||||||
n_obs=len(data_dictionary[f'test_features']),
|
n_obs=len(data_dictionary['test_features']),
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
n_iters=self.eval_iters
|
n_iters=self.eval_iters
|
||||||
)
|
)
|
||||||
@ -101,8 +101,11 @@ class PyTorchModelTrainer:
|
|||||||
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
|
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
|
||||||
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
|
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
|
||||||
.long()
|
.long()
|
||||||
.view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
|
.view(labels_view)
|
||||||
)
|
)
|
||||||
|
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
|
||||||
|
# need to resolve it per ClassifierModel
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
|
@ -24,7 +24,6 @@ from freqtrade.exceptions import OperationalException
|
|||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.strategy.interface import IStrategy
|
from freqtrade.strategy.interface import IStrategy
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -90,8 +89,8 @@ class FreqaiDataDrawer:
|
|||||||
self.metric_tracker_lock = threading.Lock()
|
self.metric_tracker_lock = threading.Lock()
|
||||||
self.old_DBSCAN_eps: Dict[str, float] = {}
|
self.old_DBSCAN_eps: Dict[str, float] = {}
|
||||||
self.empty_pair_dict: pair_info = {
|
self.empty_pair_dict: pair_info = {
|
||||||
"model_filename": "", "trained_timestamp": 0,
|
"model_filename": "", "trained_timestamp": 0,
|
||||||
"data_path": "", "extras": {}}
|
"data_path": "", "extras": {}}
|
||||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||||
|
|
||||||
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
||||||
@ -446,9 +445,9 @@ class FreqaiDataDrawer:
|
|||||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
||||||
elif self.model_type == 'keras':
|
elif self.model_type == 'keras':
|
||||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||||
elif 'stable_baselines' in self.model_type or\
|
elif ('stable_baselines' in self.model_type or
|
||||||
'sb3_contrib' == self.model_type or\
|
'sb3_contrib' == self.model_type or
|
||||||
'pytorch' == self.model_type:
|
'pytorch' == self.model_type):
|
||||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||||
|
|
||||||
if dk.svm_model is not None:
|
if dk.svm_model is not None:
|
||||||
@ -581,16 +580,16 @@ class FreqaiDataDrawer:
|
|||||||
if len(df_dp.index) == 0:
|
if len(df_dp.index) == 0:
|
||||||
continue
|
continue
|
||||||
if str(hist_df.iloc[-1]["date"]) == str(
|
if str(hist_df.iloc[-1]["date"]) == str(
|
||||||
df_dp.iloc[-1:]["date"].iloc[-1]
|
df_dp.iloc[-1:]["date"].iloc[-1]
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
index = (
|
index = (
|
||||||
df_dp.loc[
|
df_dp.loc[
|
||||||
df_dp["date"] == hist_df.iloc[-1]["date"]
|
df_dp["date"] == hist_df.iloc[-1]["date"]
|
||||||
].index[0]
|
].index[0]
|
||||||
+ 1
|
+ 1
|
||||||
)
|
)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]:
|
if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]:
|
||||||
@ -643,7 +642,7 @@ class FreqaiDataDrawer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_base_and_corr_dataframes(
|
def get_base_and_corr_dataframes(
|
||||||
self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen
|
self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen
|
||||||
) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
|
) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
|
||||||
"""
|
"""
|
||||||
Searches through our historic_data in memory and returns the dataframes relevant
|
Searches through our historic_data in memory and returns the dataframes relevant
|
||||||
|
Loading…
Reference in New Issue
Block a user