From 9cbfa1201113afeb143fb22b3b9ee4be125c5263 Mon Sep 17 00:00:00 2001 From: Emre Date: Mon, 28 Nov 2022 16:02:17 +0300 Subject: [PATCH] Directly set model_type in base RL model --- freqtrade/freqai/RL/BaseReinforcementLearningModel.py | 1 + freqtrade/freqai/data_drawer.py | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 709ded048..e1381ab62 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -64,6 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.policy_type = self.freqai_info['rl_config']['policy_type'] self.unset_outlier_removal() self.net_arch = self.rl_config.get('net_arch', [128, 128]) + self.dd.model_type = "stable_baselines" def unset_outlier_removal(self): """ diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 3b9352efe..ab41240e9 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -99,12 +99,7 @@ class FreqaiDataDrawer: self.empty_pair_dict: pair_info = { "model_filename": "", "trained_timestamp": 0, "data_path": "", "extras": {}} - if 'Reinforcement' in self.config['freqaimodel']: - self.model_type = 'stable_baselines' - logger.warning('User passed a ReinforcementLearner model, FreqAI will ' - 'now use stable_baselines3 to save models.') - else: - self.model_type = self.freqai_info.get('model_save_type', 'joblib') + self.model_type = self.freqai_info.get('model_save_type', 'joblib') def update_metric_tracker(self, metric: str, value: float, pair: str) -> None: """