Directly set model_type in base RL model

This commit is contained in:
Emre 2022-11-28 16:02:17 +03:00
parent 1cdf5e0cfd
commit 9cbfa12011
No known key found for this signature in database
GPG Key ID: 0EAD2EE11B666ABA
2 changed files with 2 additions and 6 deletions

View File

@ -64,6 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.policy_type = self.freqai_info['rl_config']['policy_type'] self.policy_type = self.freqai_info['rl_config']['policy_type']
self.unset_outlier_removal() self.unset_outlier_removal()
self.net_arch = self.rl_config.get('net_arch', [128, 128]) self.net_arch = self.rl_config.get('net_arch', [128, 128])
self.dd.model_type = "stable_baselines"
def unset_outlier_removal(self): def unset_outlier_removal(self):
""" """

View File

@ -99,12 +99,7 @@ class FreqaiDataDrawer:
self.empty_pair_dict: pair_info = { self.empty_pair_dict: pair_info = {
"model_filename": "", "trained_timestamp": 0, "model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}} "data_path": "", "extras": {}}
if 'Reinforcement' in self.config['freqaimodel']: self.model_type = self.freqai_info.get('model_save_type', 'joblib')
self.model_type = 'stable_baselines'
logger.warning('User passed a ReinforcementLearner model, FreqAI will '
'now use stable_baselines3 to save models.')
else:
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None: def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
""" """