diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 9d2fae583..81f8edfc4 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -64,7 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.policy_type = self.freqai_info['rl_config']['policy_type'] self.unset_outlier_removal() self.net_arch = self.rl_config.get('net_arch', [128, 128]) - self.dd.model_type = "stable_baselines" + self.dd.model_type = import_str def unset_outlier_removal(self): """ diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 99e3686b3..848fb20eb 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -503,7 +503,7 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif self.model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif 'stable_baselines' in self.model_type: + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -589,9 +589,9 @@ class FreqaiDataDrawer: elif self.model_type == 'keras': from tensorflow import keras model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") - elif self.model_type == 'stable_baselines': + elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: mod = importlib.import_module( - 'stable_baselines3', self.freqai_info['rl_config']['model_type']) + self.model_type, self.freqai_info['rl_config']['model_type']) MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type']) model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")