From 4a9982f86bdc340441cd9c7fff1259f9813a715d Mon Sep 17 00:00:00 2001 From: Emre Date: Thu, 1 Dec 2022 10:08:42 +0300 Subject: [PATCH] Fix sb3_contrib loading issue --- freqtrade/freqai/RL/BaseReinforcementLearningModel.py | 2 +- freqtrade/freqai/data_drawer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 9d2fae583..81f8edfc4 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -64,7 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.policy_type = self.freqai_info['rl_config']['policy_type'] self.unset_outlier_removal() self.net_arch = self.rl_config.get('net_arch', [128, 128]) - self.dd.model_type = "stable_baselines" + self.dd.model_type = import_str def unset_outlier_removal(self): """ diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 99e3686b3..5e1f3a344 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -503,7 +503,7 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif self.model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif 'stable_baselines' in self.model_type: + elif self.model_type in ['stable_baselines3', 'sb3_contrib']: model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -589,9 +589,9 @@ class FreqaiDataDrawer: elif self.model_type == 'keras': from tensorflow import keras model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") - elif self.model_type == 'stable_baselines': + elif self.model_type in ['stable_baselines3', 'sb3_contrib']: mod = importlib.import_module( - 'stable_baselines3', self.freqai_info['rl_config']['model_type']) + self.model_type, self.freqai_info['rl_config']['model_type']) MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type']) model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")