diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 5e1f3a344..848fb20eb 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -503,7 +503,7 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif self.model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif self.model_type in ['stable_baselines3', 'sb3_contrib']: + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -589,7 +589,7 @@ class FreqaiDataDrawer: elif self.model_type == 'keras': from tensorflow import keras model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") - elif self.model_type in ['stable_baselines3', 'sb3_contrib']: + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: mod = importlib.import_module( self.model_type, self.freqai_info['rl_config']['model_type']) MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])