From 396e666e9b46c4447907c9c093bef67931b09087 Mon Sep 17 00:00:00 2001 From: Emre Date: Thu, 1 Dec 2022 11:03:51 +0300 Subject: [PATCH] Keep old behavior of model loading --- freqtrade/freqai/data_drawer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 5e1f3a344..848fb20eb 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -503,7 +503,7 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif self.model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif self.model_type in ['stable_baselines3', 'sb3_contrib']: + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -589,7 +589,7 @@ class FreqaiDataDrawer: elif self.model_type == 'keras': from tensorflow import keras model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") - elif self.model_type in ['stable_baselines3', 'sb3_contrib']: + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: mod = importlib.import_module( self.model_type, self.freqai_info['rl_config']['model_type']) MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])