Add sb3 learn progress bar
This commit is contained in:
parent
dc7e834911
commit
cab82e8e60
@ -73,7 +73,8 @@
|
||||
10,
|
||||
20
|
||||
],
|
||||
"plot_feature_importances": 0
|
||||
"plot_feature_importances": 0,
|
||||
"progress_bar": false
|
||||
},
|
||||
"data_split_parameters": {
|
||||
"test_size": 0.33,
|
||||
|
@ -85,6 +85,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
|
||||
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
|
||||
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** Boolean. <br> Default: `False`.
|
||||
| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases). <br> **Datatype:** Boolean. <br> Default: `False`.
|
||||
| `progress_bar` | Display a progress bar with the current progress, elapsed time and estimated remaining time. <br> **Datatype:** Boolean. <br> Default: `False`.
|
||||
|
||||
### Additional parameters
|
||||
|
||||
|
@ -599,6 +599,7 @@ CONF_SCHEMA = {
|
||||
"policy_type": {"type": "string", "default": "MlpPolicy"},
|
||||
"net_arch": {"type": "array", "default": [128, 128]},
|
||||
"randomize_starting_position": {"type": "boolean", "default": False},
|
||||
"progress_bar": {"type": "boolean", "default": False},
|
||||
"model_reward_parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -71,7 +71,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=[self.eval_callback, self.tensorboard_callback]
|
||||
callback=[self.eval_callback, self.tensorboard_callback],
|
||||
progress_bar=self.freqai_info["rl_config"]["progress_bar"]
|
||||
)
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
|
@ -8,3 +8,6 @@ sb3-contrib==1.7.0; python_version < '3.11'
|
||||
# Gym is forced to this version by stable-baselines3.
|
||||
setuptools==65.5.1 # Should be removed when gym is fixed.
|
||||
gym==0.21; python_version < '3.11'
|
||||
# Progress bar for stable-baselines3 and sb3-contrib
|
||||
tqdm==4.65.0; python_version < '3.11'
|
||||
rich==13.3.3; python_version < '3.11'
|
||||
|
Loading…
Reference in New Issue
Block a user