Directly set model_type in base RL model
This commit is contained in:
parent
1cdf5e0cfd
commit
9cbfa12011
@ -64,6 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
||||||
self.unset_outlier_removal()
|
self.unset_outlier_removal()
|
||||||
self.net_arch = self.rl_config.get('net_arch', [128, 128])
|
self.net_arch = self.rl_config.get('net_arch', [128, 128])
|
||||||
|
self.dd.model_type = "stable_baselines"
|
||||||
|
|
||||||
def unset_outlier_removal(self):
|
def unset_outlier_removal(self):
|
||||||
"""
|
"""
|
||||||
|
@ -99,12 +99,7 @@ class FreqaiDataDrawer:
|
|||||||
self.empty_pair_dict: pair_info = {
|
self.empty_pair_dict: pair_info = {
|
||||||
"model_filename": "", "trained_timestamp": 0,
|
"model_filename": "", "trained_timestamp": 0,
|
||||||
"data_path": "", "extras": {}}
|
"data_path": "", "extras": {}}
|
||||||
if 'Reinforcement' in self.config['freqaimodel']:
|
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||||
self.model_type = 'stable_baselines'
|
|
||||||
logger.warning('User passed a ReinforcementLearner model, FreqAI will '
|
|
||||||
'now use stable_baselines3 to save models.')
|
|
||||||
else:
|
|
||||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
|
||||||
|
|
||||||
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user