This commit is contained in:
initrv 2023-04-02 01:08:42 +00:00 committed by GitHub
commit 71ecd79fda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 1 deletions

View File

@ -85,6 +85,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each. | `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`. | `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`.
| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases) <br> **Datatype:** Boolean. <br> **Default:** `False` | `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases) <br> **Datatype:** Boolean. <br> **Default:** `False`
| `progress_bar` | Display a progress bar with the current progress, elapsed time and estimated remaining time. <br> **Datatype:** Boolean. <br> Default: `False`.
### Additional parameters ### Additional parameters

View File

@ -599,6 +599,7 @@ CONF_SCHEMA = {
"policy_type": {"type": "string", "default": "MlpPolicy"}, "policy_type": {"type": "string", "default": "MlpPolicy"},
"net_arch": {"type": "array", "default": [128, 128]}, "net_arch": {"type": "array", "default": [128, 128]},
"randomize_startinng_position": {"type": "boolean", "default": False}, "randomize_startinng_position": {"type": "boolean", "default": False},
"progress_bar": {"type": "boolean", "default": False},
"model_reward_parameters": { "model_reward_parameters": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@ -71,7 +71,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
model.learn( model.learn(
total_timesteps=int(total_timesteps), total_timesteps=int(total_timesteps),
callback=[self.eval_callback, self.tensorboard_callback] callback=[self.eval_callback, self.tensorboard_callback],
progress_bar=self.rl_config.get('progress_bar', False)
) )
if Path(dk.data_path / "best_model.zip").is_file(): if Path(dk.data_path / "best_model.zip").is_file():

View File

@ -8,3 +8,6 @@ sb3-contrib==1.7.0; python_version < '3.11'
# Gym is forced to this version by stable-baselines3. # Gym is forced to this version by stable-baselines3.
setuptools==65.5.1 # Should be removed when gym is fixed. setuptools==65.5.1 # Should be removed when gym is fixed.
gym==0.21; python_version < '3.11' gym==0.21; python_version < '3.11'
# Progress bar for stable-baselines3 and sb3-contrib
tqdm==4.65.0; python_version < '3.11'
rich==13.3.3; python_version < '3.11'