Add ground work for TensorFlow models, add protections from common mistakes

This commit is contained in:
robcaulk 2022-07-12 18:09:17 +02:00
parent fea63fba12
commit ef409dd345
4 changed files with 44 additions and 21 deletions

View File

@ -57,6 +57,7 @@ class FreqaiDataKitchen:
self.live = live
self.pair = pair
self.svm_model: linear_model.SGDOneClassSVM = None
self.keras = self.freqai_config.get("keras", False)
self.set_all_pairs()
if not self.live:
self.full_timerange = self.create_fulltimerange(
@ -92,7 +93,7 @@ class FreqaiDataKitchen:
return
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None:
def save_data(self, model: Any, coin: str = "", label=None) -> None:
"""
Saves all data associated with a model for a single sub-train time range
:params:
@ -106,7 +107,7 @@ class FreqaiDataKitchen:
save_path = Path(self.data_path)
# Save the trained model
if not keras_model:
if not self.keras:
dump(model, save_path / f"{self.model_filename}_model.joblib")
else:
model.save(save_path / f"{self.model_filename}_model.h5")
@ -140,7 +141,7 @@ class FreqaiDataKitchen:
return
def load_data(self, coin: str = "", keras_model=False) -> Any:
def load_data(self, coin: str = "") -> Any:
"""
loads all data required to make a prediction on a sub-train time range
:returns:
@ -174,7 +175,7 @@ class FreqaiDataKitchen:
# try to access model in memory instead of loading object from disk to save time
if self.live and self.model_filename in self.dd.model_dictionary:
model = self.dd.model_dictionary[self.model_filename]
elif not keras_model:
elif not self.keras:
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
else:
from tensorflow import keras
@ -559,6 +560,13 @@ class FreqaiDataKitchen:
predict: bool = If true, inference an existing SVM model, else construct one
"""
if self.keras:
logger.warning("SVM outlier removal not currently supported for Keras based models. "
"Skipping user requested function.")
if predict:
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
return
if predict:
assert self.svm_model, "No svm model available for outlier removal"
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])

View File

@ -69,6 +69,9 @@ class IFreqaiModel(ABC):
self.ready_to_scan = False
self.first = True
self.keras = self.freqai_info.get("keras", False)
if self.keras and self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
self.freqai_info["feature_parameters"]["DI_threshold"] = 0
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
self.CONV_WIDTH = self.freqai_info.get("conv_width", 2)
def assert_config(self, config: Dict[str, Any]) -> None:
@ -197,9 +200,9 @@ class IFreqaiModel(ABC):
self.model = self.train(dataframe_train, metadata["pair"], dk)
self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts
dk.set_new_model_names(metadata["pair"], trained_timestamp)
dk.save_data(self.model, metadata["pair"], keras_model=self.keras)
dk.save_data(self.model, metadata["pair"])
else:
self.model = dk.load_data(metadata["pair"], keras_model=self.keras)
self.model = dk.load_data(metadata["pair"])
self.check_if_feature_list_matches_strategy(dataframe_train, dk)
@ -276,7 +279,7 @@ class IFreqaiModel(ABC):
)
# load the model and associated data into the data kitchen
self.model = dk.load_data(coin=metadata["pair"], keras_model=self.keras)
self.model = dk.load_data(coin=metadata["pair"])
if not self.model:
logger.warning(
@ -353,13 +356,15 @@ class IFreqaiModel(ABC):
of how outlier data points are dropped from the dataframe used for training.
"""
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
if self.freqai_info.get("feature_parameters", {}).get(
"principal_component_analysis", False
):
dk.principal_component_analysis()
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
dk.use_SVM_to_remove_outliers(predict=False)
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
dk.data["avg_mean_dist"] = dk.compute_distances()
# if self.feature_parameters["determine_statistical_distributions"]:
@ -378,13 +383,15 @@ class IFreqaiModel(ABC):
of how the do_predict vector is modified. do_predict is ultimately passed back to strategy
for buy signals.
"""
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
if self.freqai_info.get("feature_parameters", {}).get(
"principal_component_analysis", False
):
dk.pca_transform(dataframe)
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
dk.use_SVM_to_remove_outliers(predict=True)
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
dk.check_if_pred_in_training_spaces()
# if self.feature_parameters["determine_statistical_distributions"]:
@ -479,14 +486,15 @@ class IFreqaiModel(ABC):
if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning:
with self.lock:
self.dd.pair_to_end_of_training_queue(pair)
dk.save_data(model, coin=pair, keras_model=self.keras)
dk.save_data(model, coin=pair)
if self.freqai_info.get("purge_old_models", False):
self.dd.purge_old_models()
# self.retrain = False
def set_initial_historic_predictions(self, df: DataFrame, model: Any,
dk: FreqaiDataKitchen, pair: str) -> None:
def set_initial_historic_predictions(
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
) -> None:
trained_predictions = model.predict(df)
pred_df = DataFrame(trained_predictions, columns=dk.label_list)
for label in dk.label_list:

View File

@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BaseRegressionModel(IFreqaiModel):
"""
User created prediction model. The class needs to override three necessary
functions, predict(), train(), fit(). The class inherits ModelHandler which
has its own DataHandler where data is held, saved, loaded, and managed.
Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
User *must* inherit from this class and set fit() and predict(). See example scripts
such as prediction_models/CatboostPredictionModel.py for guidance.
"""
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:

View File

@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
object_type = IFreqaiModel
object_type_str = "FreqaiModel"
user_subdir = USERPATH_FREQAIMODELS
initial_search_path = Path(__file__).parent.parent.joinpath(
"freqai/prediction_models").resolve()
initial_search_path = (
Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
)
@staticmethod
def load_freqaimodel(config: Dict) -> IFreqaiModel:
@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
Load the custom class from config parameter
:param config: configuration dictionary
"""
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
freqaimodel_name = config.get("freqaimodel")
if not freqaimodel_name:
@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
"No freqaimodel set. Please use `--freqaimodel` to "
"specify the FreqaiModel class to use.\n"
)
if freqaimodel_name in disallowed_models:
raise OperationalException(
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
"an existing child class or inherit from this baseclass.\n"
)
freqaimodel = FreqaiModelResolver.load_object(
freqaimodel_name,
config,