Working base for reinforcement learning model

This commit is contained in:
robcaulk
2022-08-08 15:41:16 +02:00
parent a6d78a8615
commit 05ed1b544f
9 changed files with 748 additions and 12 deletions

View File

@@ -390,10 +390,13 @@ class FreqaiDataDrawer:
save_path = Path(dk.data_path)
# Save the trained model
if not dk.keras:
model_type = self.freqai_info.get('model_save_type', 'joblib')
if model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib")
else:
elif model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5")
elif model_type == 'stable_baselines':
model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None:
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
@@ -459,15 +462,18 @@ class FreqaiDataDrawer:
dk.data_path / f"{dk.model_filename}_trained_df.pkl"
)
model_type = self.freqai_info.get('model_save_type', 'joblib')
# try to access model in memory instead of loading object from disk to save time
if dk.live and coin in self.model_dictionary:
model = self.model_dictionary[coin]
elif not dk.keras:
elif model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
else:
elif model_type == 'keras':
from tensorflow import keras
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
elif model_type == 'stable_baselines':
from stable_baselines3.ppo.ppo import PPO
model = PPO.load(dk.data_path / f"{dk.model_filename}_model.zip")
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")