Merge branch 'develop' into backtest_live_models
This commit is contained in:
@@ -245,6 +245,7 @@ class FreqaiDataKitchen:
|
||||
self.data["filter_drop_index_training"] = drop_index
|
||||
|
||||
else:
|
||||
filtered_df = self.check_pred_labels(filtered_df)
|
||||
# we are backtesting so we need to preserve row number to send back to strategy,
|
||||
# so now we use do_predict to avoid any prediction based on a NaN
|
||||
drop_index = pd.isnull(filtered_df).any(axis=1)
|
||||
@@ -487,6 +488,24 @@ class FreqaiDataKitchen:
|
||||
|
||||
return df
|
||||
|
||||
def check_pred_labels(self, df_predictions: DataFrame) -> DataFrame:
|
||||
"""
|
||||
Check that prediction feature labels match training feature labels.
|
||||
:params:
|
||||
:df_predictions: incoming predictions
|
||||
"""
|
||||
train_labels = self.data_dictionary["train_features"].columns
|
||||
pred_labels = df_predictions.columns
|
||||
num_diffs = len(pred_labels.difference(train_labels))
|
||||
if num_diffs != 0:
|
||||
df_predictions = df_predictions[train_labels]
|
||||
logger.warning(
|
||||
f"Removed {num_diffs} features from prediction features, "
|
||||
f"these were likely considered constant values during most recent training."
|
||||
)
|
||||
|
||||
return df_predictions
|
||||
|
||||
def principal_component_analysis(self) -> None:
|
||||
"""
|
||||
Performs Principal Component Analysis on the data for dimensionality reduction
|
||||
|
@@ -200,16 +200,15 @@ class IFreqaiModel(ABC):
|
||||
(_, trained_timestamp, _) = self.dd.get_pair_dict_info(pair)
|
||||
|
||||
dk = FreqaiDataKitchen(self.config, self.live, pair)
|
||||
dk.set_paths(pair, trained_timestamp)
|
||||
(
|
||||
retrain,
|
||||
new_trained_timerange,
|
||||
data_load_timerange,
|
||||
) = dk.check_if_new_training_required(trained_timestamp)
|
||||
dk.set_paths(pair, new_trained_timerange.stopts)
|
||||
|
||||
if retrain:
|
||||
self.train_timer('start')
|
||||
dk.set_paths(pair, new_trained_timerange.stopts)
|
||||
try:
|
||||
self.extract_data_and_train_model(
|
||||
new_trained_timerange, pair, strategy, dk, data_load_timerange
|
||||
@@ -290,9 +289,7 @@ class IFreqaiModel(ABC):
|
||||
if dk.backtest_live_models:
|
||||
timestamp_model_id = int(tr_backtest.startts)
|
||||
|
||||
dk.data_path = Path(
|
||||
dk.full_path / f"sub-train-{pair.split('/')[0]}_{timestamp_model_id}"
|
||||
)
|
||||
dk.set_paths(pair, timestamp_model_id)
|
||||
|
||||
dk.set_new_model_names(pair, timestamp_model_id)
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from catboost import CatBoostClassifier, Pool
|
||||
@@ -31,8 +32,9 @@ class CatboostClassifier(BaseClassifierModel):
|
||||
)
|
||||
|
||||
cbr = CatBoostClassifier(
|
||||
allow_writing_files=False,
|
||||
allow_writing_files=True,
|
||||
loss_function='MultiClass',
|
||||
train_dir=Path(dk.data_path),
|
||||
**self.model_training_parameters,
|
||||
)
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from catboost import CatBoostRegressor, Pool
|
||||
@@ -41,7 +42,8 @@ class CatboostRegressor(BaseRegressionModel):
|
||||
init_model = self.get_init_model(dk.pair)
|
||||
|
||||
model = CatBoostRegressor(
|
||||
allow_writing_files=False,
|
||||
allow_writing_files=True,
|
||||
train_dir=Path(dk.data_path),
|
||||
**self.model_training_parameters,
|
||||
)
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from catboost import CatBoostRegressor, Pool
|
||||
@@ -26,7 +27,8 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
|
||||
cbr = CatBoostRegressor(
|
||||
allow_writing_files=False,
|
||||
allow_writing_files=True,
|
||||
train_dir=Path(dk.data_path),
|
||||
**self.model_training_parameters,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user