improve docstring clarity about how to inherit from ReinforcementLearner, demonstrate inheritance with ReinforcementLearner_multiproc
This commit is contained in:
parent 9f13d99b99
commit 8dbfd2cacf
freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -14,7 +14,32 @@ logger = logging.getLogger(__name__)
 
 class ReinforcementLearner(BaseReinforcementLearningModel):
     """
-    User created Reinforcement Learning Model prediction model.
+    Reinforcement Learning Model prediction model.
+
+    Users can inherit from this class to make their own RL model with custom
+    environment/training controls. Define the file as follows:
+
+    ```
+    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
+
+    class MyCoolRLModel(ReinforcementLearner):
+    ```
+
+    Save the file to `user_data/freqaimodels`, then run it with:
+
+    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
+
+    Here, users can override any of the functions
+    available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
+    is where the user overrides `MyRLEnv` (see below) to define a custom
+    `calculate_reward()` function, or to override any other part of the environment.
+
+    This class also allows users to override any other part of the IFreqaiModel tree.
+    For example, the user can override `def fit()`, `def train()`, or `def predict()`
+    to take fine-tuned control over these processes.
+
+    Another common override may be `def data_cleaning_predict()`, where the user can
+    take fine-tuned control over the data handling pipeline.
     """
 
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
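To make the inheritance pattern described in this docstring concrete, here is a minimal sketch of such a user file. The class name, the nested `MyRLEnv` override, and the reward arithmetic are illustrative assumptions, not code from this commit:

```python
# user_data/freqaimodels/MyCoolRLModel.py -- illustrative sketch only.
# Assumes ReinforcementLearner exposes a nested MyRLEnv whose
# calculate_reward() hook can be overridden, as the docstring describes.
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner


class MyCoolRLModel(ReinforcementLearner):
    """RL model that only swaps in a custom reward function."""

    class MyRLEnv(ReinforcementLearner.MyRLEnv):
        def calculate_reward(self, action: int) -> float:
            # Hypothetical reward: unrealized profit minus a small
            # per-step penalty to discourage holding positions forever.
            pnl = self.get_unrealized_profit()
            return float(pnl) - 0.0005
```

Because only `calculate_reward()` is overridden, `fit()`, `train()`, and `predict()` are all inherited unchanged, which is exactly the granularity of control the docstring advertises.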
freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -1,61 +1,24 @@
 import logging
-from pathlib import Path
 from typing import Any, Dict  # , Tuple
 
 # import numpy.typing as npt
-import torch as th
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
 
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
-from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
-                                                                make_env)
+from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
 
 
 logger = logging.getLogger(__name__)
 
 
-class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
+class ReinforcementLearner_multiproc(ReinforcementLearner):
     """
-    User created Reinforcement Learning Model prediction model.
+    Demonstration of how to build vectorized environments
     """
 
-    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
-
-        train_df = data_dictionary["train_features"]
-        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
-
-        # model arch
-        policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=self.net_arch)
-
-        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
-            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
-                                    tensorboard_log=Path(
-                                        dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
-                                    **self.freqai_info['model_training_parameters']
-                                    )
-        else:
-            logger.info('Continual learning activated - starting training from previously '
-                        'trained agent.')
-            model = self.dd.model_dictionary[dk.pair]
-            model.set_env(self.train_env)
-
-        model.learn(
-            total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
-        )
-
-        if Path(dk.data_path / "best_model.zip").is_file():
-            logger.info('Callback found a best model.')
-            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
-            return best_model
-
-        logger.info('Couldnt find best model, using final model instead.')
-
-        return model
-
     def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
                                         prices_train: DataFrame, prices_test: DataFrame,
                                         dk: FreqaiDataKitchen):
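Because `ReinforcementLearner_multiproc` now inherits from `ReinforcementLearner`, the duplicated `fit()` body disappears and only the environment setup is overridden. For readers unfamiliar with the stable-baselines3 pattern that override builds on, here is a generic, self-contained sketch of vectorized environments using a toy gymnasium task; it deliberately does not reproduce freqtrade's `make_env` signature, which takes model-specific arguments:

```python
# Generic SubprocVecEnv illustration (toy environment, not freqtrade code).
import gymnasium as gym
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(env_id: str, rank: int, seed: int = 0):
    """Return a zero-argument factory that builds one monitored env."""
    def _init() -> gym.Env:
        env = gym.make(env_id)
        env.reset(seed=seed + rank)  # distinct seed per worker process
        return Monitor(env)  # records episode stats for callbacks like EvalCallback
    return _init


if __name__ == "__main__":
    n_procs = 4
    # Each factory is pickled into its own worker process, so SubprocVecEnv
    # takes callables rather than live environment objects. The __main__
    # guard matters on platforms that spawn rather than fork.
    train_env = SubprocVecEnv([make_env("CartPole-v1", i) for i in range(n_procs)])
    obs = train_env.reset()
    print(obs.shape)  # (4, 4): one stacked observation per worker
```

The `set_train_and_eval_environments` override above follows the same shape: it wraps the model's environment class in per-process factories via freqtrade's `make_env` helper and hands the resulting `SubprocVecEnv` instances to training and to the `EvalCallback` imported at the top of the file.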