automatically handle model_save_type for user

This commit is contained in:
robcaulk 2022-10-03 18:42:20 +02:00
parent cf882fa84e
commit 292d72d593
3 changed files with 13 additions and 10 deletions

View File

@ -53,7 +53,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
| `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()` <br> **Datatype:** int. | `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()` <br> **Datatype:** int.
| `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string. | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string.
| `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string. | `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string.
| `continual_learning` | If true, the agent will start new trainings from the model selected during the previous training. If false, a new agent is trained from scratch for each training. <br> **Datatype:** Bool. | `max_training_drawdown_pct` | The maximum drawdown that the agent is allowed to experience during training. <br> **Datatype:** float. <br> Default: 0.8
| `cpu_count` | Number of threads/cpus to dedicate to the Reinforcement Learning training process (depending on if `ReinforcementLearning_multiproc` is selected or not). <br> **Datatype:** int. | `cpu_count` | Number of threads/cpus to dedicate to the Reinforcement Learning training process (depending on if `ReinforcementLearning_multiproc` is selected or not). <br> **Datatype:** int.
| `model_reward_parameters` | Parameters used inside the user customizable `calculate_reward()` function in `ReinforcementLearner.py` <br> **Datatype:** int. | `model_reward_parameters` | Parameters used inside the user customizable `calculate_reward()` function in `ReinforcementLearner.py` <br> **Datatype:** int.
| | **Extraneous parameters** | | **Extraneous parameters**

View File

@ -118,7 +118,6 @@ In order to configure the `Reinforcement Learner` the following dictionary to th
"cpu_count": 8, "cpu_count": 8,
"model_type": "PPO", "model_type": "PPO",
"policy_type": "MlpPolicy", "policy_type": "MlpPolicy",
"continual_learning": false,
"model_reward_parameters": { "model_reward_parameters": {
"rr": 1, "rr": 1,
"profit_aim": 0.025 "profit_aim": 0.025

View File

@ -92,6 +92,12 @@ class FreqaiDataDrawer:
"model_filename": "", "trained_timestamp": 0, "model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}} "data_path": "", "extras": {}}
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False) self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
if 'rl_config' in self.freqai_info:
self.model_type = 'stable_baselines'
logger.warning('User indicated rl_config, FreqAI will now use stable_baselines3'
' to save models.')
else:
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
def load_drawer_from_disk(self): def load_drawer_from_disk(self):
""" """
@ -414,12 +420,11 @@ class FreqaiDataDrawer:
save_path = Path(dk.data_path) save_path = Path(dk.data_path)
# Save the trained model # Save the trained model
model_type = self.freqai_info.get('model_save_type', 'joblib') if self.model_type == 'joblib':
if model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib") dump(model, save_path / f"{dk.model_filename}_model.joblib")
elif model_type == 'keras': elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5") model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in model_type: elif 'stable_baselines' in self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip") model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None: if dk.svm_model is not None:
@ -496,16 +501,15 @@ class FreqaiDataDrawer:
dk.data_path / f"{dk.model_filename}_trained_df.pkl" dk.data_path / f"{dk.model_filename}_trained_df.pkl"
) )
model_type = self.freqai_info.get('model_save_type', 'joblib')
# try to access model in memory instead of loading object from disk to save time # try to access model in memory instead of loading object from disk to save time
if dk.live and coin in self.model_dictionary and not self.limit_ram_use: if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
model = self.model_dictionary[coin] model = self.model_dictionary[coin]
elif model_type == 'joblib': elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib") model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
elif model_type == 'keras': elif self.model_type == 'keras':
from tensorflow import keras from tensorflow import keras
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
elif model_type == 'stable_baselines': elif self.model_type == 'stable_baselines':
mod = __import__('stable_baselines3', fromlist=[ mod = __import__('stable_baselines3', fromlist=[
self.freqai_info['rl_config']['model_type']]) self.freqai_info['rl_config']['model_type']])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type']) MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])