From 8dbfd2cacfcd3dcabf2e4e5b3eddf84269e850f9 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Sat, 26 Nov 2022 11:51:08 +0100
Subject: [PATCH] improve docstring clarity about how to inherit from
 ReinforcementLearner, demonstrate inheritance with
 ReinforcementLearner_multiproc

---
 .../prediction_models/ReinforcementLearner.py | 27 ++++++++++-
 .../ReinforcementLearner_multiproc.py         | 45 ++-----------------
 2 files changed, 30 insertions(+), 42 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 063af5ff5..dcf7cf54b 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -14,7 +14,32 @@ logger = logging.getLogger(__name__)
 
 class ReinforcementLearner(BaseReinforcementLearningModel):
     """
-    User created Reinforcement Learning Model prediction model.
+    Reinforcement Learning Model prediction model.
+
+    Users can inherit from this class to make their own RL model with custom
+    environment/training controls. Define the file as follows:
+
+    ```
+    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
+
+    class MyCoolRLModel(ReinforcementLearner):
+    ```
+
+    Save the file to `user_data/freqaimodels`, then run it with:
+
+    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
+
+    Here the user can override any of the functions
+    available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
+    is where the user overrides `MyRLEnv` (see below) to define a custom
+    `calculate_reward()` function, or to override any other part of the environment.
+
+    This class also allows the user to override any other part of the IFreqaiModel tree.
+    For example, the user can override `def fit()`, `def train()` or `def predict()`
+    to take fine-tuned control over these processes.
+
+    Another common override is `def data_cleaning_predict()`, where the user can
+    take fine-tuned control over the data handling pipeline.
     """
 
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index baba16066..56636c1f6 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -1,61 +1,24 @@
 import logging
-from pathlib import Path
 from typing import Any, Dict  # , Tuple
 
 # import numpy.typing as npt
-import torch as th
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
 
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
-from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
-                                                                make_env)
+from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
 
 
 logger = logging.getLogger(__name__)
 
 
-class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
+class ReinforcementLearner_multiproc(ReinforcementLearner):
     """
-    User created Reinforcement Learning Model prediction model.
+    Demonstration of how to build vectorized environments
     """
 
-    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
-
-        train_df = data_dictionary["train_features"]
-        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
-
-        # model arch
-        policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=self.net_arch)
-
-        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
-            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
-                                    tensorboard_log=Path(
-                                        dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
-                                    **self.freqai_info['model_training_parameters']
-                                    )
-        else:
-            logger.info('Continual learning activated - starting training from previously '
-                        'trained agent.')
-            model = self.dd.model_dictionary[dk.pair]
-            model.set_env(self.train_env)
-
-        model.learn(
-            total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
-        )
-
-        if Path(dk.data_path / "best_model.zip").is_file():
-            logger.info('Callback found a best model.')
-            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
-            return best_model
-
-        logger.info('Couldnt find best model, using final model instead.')
-
-        return model
-
     def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
                                         prices_train: DataFrame, prices_test: DataFrame,
                                         dk: FreqaiDataKitchen):
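
For reference, a user model following the inheritance pattern described in the new docstring could look like the sketch below. This is a hypothetical example and not part of the patch: the class name `MyCoolRLModel` and the reward weights are placeholders, and it assumes the `Base5ActionRLEnv` interface (`Actions`, `Positions`, `_is_valid()`, `get_unrealized_profit()`) exposed by `freqtrade.freqai.RL` around the time of this change.

```python
# user_data/freqaimodels/MyCoolRLModel.py -- illustrative sketch only
import logging

from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions


logger = logging.getLogger(__name__)


class MyCoolRLModel(ReinforcementLearner):
    """
    Example subclass: fit()/train()/predict() are inherited unchanged;
    only the environment's reward logic is customized.
    """

    class MyRLEnv(Base5ActionRLEnv):
        """
        Custom environment with a user-defined calculate_reward().
        The numeric weights below are placeholders, not tuned values.
        """

        def calculate_reward(self, action: int) -> float:
            # Penalize actions that are invalid for the current position.
            if not self._is_valid(action):
                return -2.0

            pnl = self.get_unrealized_profit()

            # Discourage sitting flat and doing nothing.
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1.0

            # Reward closing a position in proportion to its unrealized profit.
            if action in (Actions.Long_exit.value, Actions.Short_exit.value) \
                    and self._position in (Positions.Long, Positions.Short):
                return float(pnl * 100)

            return 0.0
```

Saved to `user_data/freqaimodels/`, such a model would be selected with `freqtrade trade --freqaimodel MyCoolRLModel ...` as the docstring describes; `ReinforcementLearner_multiproc` can be subclassed the same way when vectorized environments are wanted.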