Add sb3 learn progress bar
This commit is contained in:
parent
dc7e834911
commit
cab82e8e60
@ -73,7 +73,8 @@
|
|||||||
10,
|
10,
|
||||||
20
|
20
|
||||||
],
|
],
|
||||||
"plot_feature_importances": 0
|
"plot_feature_importances": 0,
|
||||||
|
"progress_bar": false
|
||||||
},
|
},
|
||||||
"data_split_parameters": {
|
"data_split_parameters": {
|
||||||
"test_size": 0.33,
|
"test_size": 0.33,
|
||||||
|
@ -85,6 +85,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
|
|||||||
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
|
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
|
||||||
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`.
|
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`.
|
||||||
| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases) <br> **Datatype:** Boolean. <br> **Default:** `False`
|
| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases) <br> **Datatype:** Boolean. <br> **Default:** `False`
|
||||||
|
| `progress_bar` | Display a progress bar with the current progress, elapsed time and estimated remaining time. <br> **Datatype:** Boolean. <br> Default: `False`.
|
||||||
|
|
||||||
### Additional parameters
|
### Additional parameters
|
||||||
|
|
||||||
|
@ -599,6 +599,7 @@ CONF_SCHEMA = {
|
|||||||
"policy_type": {"type": "string", "default": "MlpPolicy"},
|
"policy_type": {"type": "string", "default": "MlpPolicy"},
|
||||||
"net_arch": {"type": "array", "default": [128, 128]},
|
"net_arch": {"type": "array", "default": [128, 128]},
|
||||||
"randomize_startinng_position": {"type": "boolean", "default": False},
|
"randomize_startinng_position": {"type": "boolean", "default": False},
|
||||||
|
"progress_bar": {"type": "boolean", "default": False},
|
||||||
"model_reward_parameters": {
|
"model_reward_parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -71,7 +71,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
|
|
||||||
model.learn(
|
model.learn(
|
||||||
total_timesteps=int(total_timesteps),
|
total_timesteps=int(total_timesteps),
|
||||||
callback=[self.eval_callback, self.tensorboard_callback]
|
callback=[self.eval_callback, self.tensorboard_callback],
|
||||||
|
progress_bar=self.freqai_info["rl_config"]["progress_bar"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if Path(dk.data_path / "best_model.zip").is_file():
|
if Path(dk.data_path / "best_model.zip").is_file():
|
||||||
|
@ -8,3 +8,6 @@ sb3-contrib==1.7.0; python_version < '3.11'
|
|||||||
# Gym is forced to this version by stable-baselines3.
|
# Gym is forced to this version by stable-baselines3.
|
||||||
setuptools==65.5.1 # Should be removed when gym is fixed.
|
setuptools==65.5.1 # Should be removed when gym is fixed.
|
||||||
gym==0.21; python_version < '3.11'
|
gym==0.21; python_version < '3.11'
|
||||||
|
# Progress bar for stable-baselines3 and sb3-contrib
|
||||||
|
tqdm==4.65.0; python_version < '3.11'
|
||||||
|
rich==13.3.3; python_version < '3.11'
|
||||||
|
Loading…
Reference in New Issue
Block a user