Merge pull request #7809 from richardjozsa/develop

Improve the RL learning process
Matthias 2022-11-29 06:28:36 +01:00 committed by GitHub
commit c3daddc629
3 changed files with 7 additions and 0 deletions


@@ -82,6 +82,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
| `model_reward_parameters` | Parameters used inside the customizable `calculate_reward()` function in `ReinforcementLearner.py` <br> **Datatype:** dict.
| `add_state_info` | Tell FreqAI to include state information in the feature set for training and inference. The current state variables include trade duration, current profit, and trade position. This is only available in dry/live runs, and is automatically switched to false for backtesting. <br> **Datatype:** bool. <br> Default: `False`.
| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
| `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting. <br> **Datatype:** bool. <br> Default: `False`.
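
For orientation, here is a minimal sketch of how these options sit together in the `rl_config` section of a FreqAI configuration, written as a Python dict. The values are illustrative choices, not recommendations, and the `model_reward_parameters` keys are assumptions taken from the default `ReinforcementLearner` rather than from this PR.

```python
# Illustrative rl_config fragment; keys mirror the parameter table above.
rl_config = {
    "model_type": "PPO",
    "policy_type": "MlpPolicy",
    "net_arch": [128, 128],                # two shared hidden layers of 128 units
    "add_state_info": False,               # only honoured in dry/live runs
    "randomize_starting_position": True,   # the option added by this PR
    "model_reward_parameters": {"rr": 1, "profit_aim": 0.025},  # assumed reward keys
}
```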
### Additional parameters


@@ -591,6 +591,7 @@ CONF_SCHEMA = {
"model_type": {"type": "string", "default": "PPO"},
"policy_type": {"type": "string", "default": "MlpPolicy"},
"net_arch": {"type": "array", "default": [128, 128]},
"randomize_startinng_position": {"type": "boolean", "default": False},
"model_reward_parameters": {
"type": "object",
"properties": {


@@ -1,4 +1,5 @@
import logging
import random
from abc import abstractmethod
from enum import Enum
from typing import Optional
@@ -121,6 +122,10 @@ class BaseEnvironment(gym.Env):
        self._done = False
        if self.starting_point is True:
            if self.rl_config.get('randomize_starting_position', False):
                # Start each episode at a random tick within the first quarter
                # of the data (but after the observation window) to reduce
                # overfitting to a fixed starting point.
                length_of_data = int(self._end_tick / 4)
                start_tick = random.randint(self.window_size + 1, length_of_data)
                self._start_tick = start_tick
            self._position_history = (self._start_tick * [None]) + [self._position]
        else:
            self._position_history = (self.window_size * [None]) + [self._position]
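
With `randomize_starting_position` enabled, each episode therefore starts at a tick drawn uniformly from just after the observation window up to the end of the first quarter of the data, rather than always at the same point. The following is a standalone sketch of that sampling logic with hypothetical sizes (a 10-candle window over 2000 candles), independent of the environment class.

```python
import random

def sample_start_tick(window_size: int, end_tick: int) -> int:
    # Mirror the randomized-start branch above: pick an episode start between
    # the first tick after the observation window and end_tick / 4.
    # Assumes end_tick / 4 > window_size + 1, as the branch above does.
    length_of_data = int(end_tick / 4)
    return random.randint(window_size + 1, length_of_data)

# Hypothetical dimensions, not taken from the PR.
window_size, end_tick = 10, 2000
starts = [sample_start_tick(window_size, end_tick) for _ in range(5)]
print(starts)  # five random start ticks, each between 11 and 500
```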