automatically handle model_save_type for user

This commit is contained in:
robcaulk
2022-10-03 18:42:20 +02:00
parent cf882fa84e
commit 292d72d593
3 changed files with 13 additions and 10 deletions

View File

@@ -92,6 +92,12 @@ class FreqaiDataDrawer:
"model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}}
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
if 'rl_config' in self.freqai_info:
self.model_type = 'stable_baselines'
logger.warning('User indicated rl_config, FreqAI will now use stable_baselines3'
' to save models.')
else:
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
def load_drawer_from_disk(self):
"""
@@ -414,12 +420,11 @@ class FreqaiDataDrawer:
save_path = Path(dk.data_path)
# Save the trained model
model_type = self.freqai_info.get('model_save_type', 'joblib')
if model_type == 'joblib':
if self.model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib")
elif model_type == 'keras':
elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in model_type:
elif 'stable_baselines' in self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None:
@@ -496,16 +501,15 @@ class FreqaiDataDrawer:
dk.data_path / f"{dk.model_filename}_trained_df.pkl"
)
model_type = self.freqai_info.get('model_save_type', 'joblib')
# try to access model in memory instead of loading object from disk to save time
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
model = self.model_dictionary[coin]
elif model_type == 'joblib':
elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
elif model_type == 'keras':
elif self.model_type == 'keras':
from tensorflow import keras
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
elif model_type == 'stable_baselines':
elif self.model_type == 'stable_baselines':
mod = __import__('stable_baselines3', fromlist=[
self.freqai_info['rl_config']['model_type']])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])