automatically handle model_save_type for user
This commit is contained in:
@@ -92,6 +92,12 @@ class FreqaiDataDrawer:
|
||||
"model_filename": "", "trained_timestamp": 0,
|
||||
"data_path": "", "extras": {}}
|
||||
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
|
||||
if 'rl_config' in self.freqai_info:
|
||||
self.model_type = 'stable_baselines'
|
||||
logger.warning('User indicated rl_config, FreqAI will now use stable_baselines3'
|
||||
' to save models.')
|
||||
else:
|
||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
|
||||
def load_drawer_from_disk(self):
|
||||
"""
|
||||
@@ -414,12 +420,11 @@ class FreqaiDataDrawer:
|
||||
save_path = Path(dk.data_path)
|
||||
|
||||
# Save the trained model
|
||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
if model_type == 'joblib':
|
||||
if self.model_type == 'joblib':
|
||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
||||
elif model_type == 'keras':
|
||||
elif self.model_type == 'keras':
|
||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||
elif 'stable_baselines' in model_type:
|
||||
elif 'stable_baselines' in self.model_type:
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if dk.svm_model is not None:
|
||||
@@ -496,16 +501,15 @@ class FreqaiDataDrawer:
|
||||
dk.data_path / f"{dk.model_filename}_trained_df.pkl"
|
||||
)
|
||||
|
||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
|
||||
model = self.model_dictionary[coin]
|
||||
elif model_type == 'joblib':
|
||||
elif self.model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
elif model_type == 'keras':
|
||||
elif self.model_type == 'keras':
|
||||
from tensorflow import keras
|
||||
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
|
||||
elif model_type == 'stable_baselines':
|
||||
elif self.model_type == 'stable_baselines':
|
||||
mod = __import__('stable_baselines3', fromlist=[
|
||||
self.freqai_info['rl_config']['model_type']])
|
||||
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
||||
|
Reference in New Issue
Block a user