reduce code for base use-case, ensure multiproc inherits custom env, add ability to limit RAM use.
@@ -58,6 +58,7 @@
         "model_save_type": "stable_baselines",
         "conv_width": 4,
         "purge_old_models": true,
+        "limit_ram_usage": false,
         "train_period_days": 5,
         "backtest_period_days": 2,
         "identifier": "unique-id",
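
The new flag defaults to false, and the data drawer later in this commit reads it with a safe default, so existing configs need no changes. A trivial illustration of that lookup (the dict literal stands in for the parsed freqai config):

    # .get() with a default means the key may be absent from older configs.
    freqai_info = {"purge_old_models": True}  # no "limit_ram_usage" key present
    limit_ram_use = freqai_info.get('limit_ram_usage', False)  # -> False
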
@@ -19,6 +19,7 @@ from typing import Callable
 from datetime import datetime, timezone
 from stable_baselines3.common.utils import set_random_seed
 import gym
+from pathlib import Path
 logger = logging.getLogger(__name__)

 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -110,9 +111,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]

-        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                 reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                      reward_kwargs=self.reward_params, config=self.config)
+        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
                                 window_size=self.CONV_WIDTH,
                                 reward_kwargs=self.reward_params, config=self.config))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
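
Routing env construction through self.MyRLEnv is what lets subclasses swap in a custom environment without overriding this method: ordinary attribute lookup resolves the nested class on the most-derived type first. A self-contained toy illustration of that mechanism (names here are illustrative, not freqtrade code):

    # self.MyRLEnv resolves to the subclass's nested class, if one exists.
    class Base:
        class MyRLEnv:
            name = "base env"

    class Custom(Base):
        class MyRLEnv(Base.MyRLEnv):
            name = "custom env"

    assert Custom().MyRLEnv.name == "custom env"  # override is picked up
    assert Base().MyRLEnv.name == "base env"      # default still intact
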
@@ -126,7 +127,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         go in here. Abstract method, so this function must be overridden by
         user class.
         """
-
         return

     def get_state_info(self, pair: str):
@@ -232,38 +232,22 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         return prices_train, prices_test

-    # TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
-    # But FreqaiRL needs more objects passed to fit() (like DK) and we don't want to go refactor
-    # all the other existing fit() functions to include a dk argument. For now we instantiate and
-    # leave it.
-    def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
-        return
-
-
-def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
-             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
-             config: Dict[str, Any] = {}) -> Callable:
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :return: (Callable)
-    """
-    def _init() -> gym.Env:
-
-        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
-                      reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
-        if monitor:
-            env = Monitor(env)
-        return env
-    set_random_seed(seed)
-    return _init
-
-
-class MyRLEnv(Base5ActionRLEnv):
-    """
-    User can override any function in BaseRLEnv and gym.Env. Here the user
-    sets a custom reward based on profit and trade duration.
+    def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
+        """
+        Can be used by the user if they are trying to limit_ram_usage *and*
+        perform continual learning.
+        For now, this is unused.
+        """
+        exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
+        if exists:
+            model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
+        else:
+            logger.info('No model file on disk to continue learning from.')
+
+        return model
+
+    # Nested class which can be overridden by user to customize further
+    class MyRLEnv(Base5ActionRLEnv):
+        """
+        User can override any function in BaseRLEnv and gym.Env. Here the user
+        sets a custom reward based on profit and trade duration.
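
A hedged sketch of how a user model might call the new helper when continuing training with limit_ram_usage enabled (the method name continue_or_start is illustrative; MODELCLASS and the path convention come from the helper above). Note the explicit existence check: the helper only binds model when a saved file is found, so calling it blindly would fail on a cold start:

    from pathlib import Path

    # Illustrative method on the user's model; mirrors the existence check
    # inside load_model_from_disk.
    def continue_or_start(self, dk):
        if Path(dk.data_path / f"{dk.model_filename}_model").is_file():
            return self.load_model_from_disk(dk)
        # Cold start: build a fresh stable-baselines3 agent instead.
        return self.MODELCLASS('MlpPolicy', self.train_env)
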
@@ -296,7 +280,8 @@ class MyRLEnv(Base5ActionRLEnv):
                 factor *= 0.5

             # discourage sitting in position
-        if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
+            if self._position in (Positions.Short, Positions.Long) and \
+               action == Actions.Neutral.value:
                 return -1 * trade_duration / max_trade_duration

             # close long
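
The idle penalty above scales linearly with how long the position has been open. A quick worked example (values illustrative):

    # 150 candles into a trade, with max_trade_duration of 300:
    trade_duration, max_trade_duration = 150, 300
    penalty = -1 * trade_duration / max_trade_duration  # -0.5, reaching -1.0 at the cap
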
@@ -312,3 +297,35 @@ class MyRLEnv(Base5ActionRLEnv):
                 return float(rew * factor)

             return 0.
+
+    # TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
+    # But FreqaiRL needs more objects passed to fit() (like DK) and we don't want to go refactor
+    # all the other existing fit() functions to include a dk argument. For now we instantiate and
+    # leave it.
+    def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
+        return
+
+
+def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
+             seed: int, train_df: DataFrame, price: DataFrame,
+             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
+             config: Dict[str, Any] = {}) -> Callable:
+    """
+    Utility function for multiprocessed env.
+
+    :param env_id: (str) the environment ID
+    :param num_env: (int) the number of environments you wish to have in subprocesses
+    :param seed: (int) the initial seed for RNG
+    :param rank: (int) index of the subprocess
+    :return: (Callable)
+    """
+
+    def _init() -> gym.Env:
+
+        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
+                      reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
+        if monitor:
+            env = Monitor(env)
+        return env
+    set_random_seed(seed)
+    return _init
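
Because the env class is now an explicit argument, subprocess workers build whatever MyRLEnv the calling model defines instead of a module-level default. A hedged usage sketch (the build_train_env helper is illustrative; the call shape matches the multiproc model later in this commit):

    from stable_baselines3.common.vec_env import SubprocVecEnv

    # `model` is a BaseReinforcementLearningModel instance; each rank is
    # seeded as seed + rank, so the vectorized copies explore differently.
    def build_train_env(model, train_df, prices_train, num_cpu=4):
        env_fns = [make_env(model.MyRLEnv, "train_env", rank, 1, train_df, prices_train,
                            model.reward_params, model.CONV_WIDTH, config=model.config)
                   for rank in range(num_cpu)]
        return SubprocVecEnv(env_fns)
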
@@ -90,6 +90,7 @@ class FreqaiDataDrawer:
         self.empty_pair_dict: pair_info = {
                 "model_filename": "", "trained_timestamp": 0,
                 "priority": 1, "first": True, "data_path": "", "extras": {}}
+        self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)

     def load_drawer_from_disk(self):
         """
@@ -423,7 +424,7 @@ class FreqaiDataDrawer:
                 dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
             )

-        # if self.live:
+        if not self.limit_ram_use:
             self.model_dictionary[coin] = model
         self.pair_dict[coin]["model_filename"] = dk.model_filename
         self.pair_dict[coin]["data_path"] = str(dk.data_path)
@@ -464,7 +465,7 @@ class FreqaiDataDrawer:

         model_type = self.freqai_info.get('model_save_type', 'joblib')
         # try to access model in memory instead of loading object from disk to save time
-        if dk.live and coin in self.model_dictionary:
+        if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
             model = self.model_dictionary[coin]
         elif model_type == 'joblib':
             model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
@@ -486,7 +487,7 @@ class FreqaiDataDrawer:
             )

         # load it into ram if it was loaded from disk
-        if coin not in self.model_dictionary:
+        if coin not in self.model_dictionary and not self.limit_ram_use:
             self.model_dictionary[coin] = model

         if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
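
Taken together, the three drawer changes implement a single policy: with limit_ram_usage enabled, model_dictionary is never populated, so every inference reloads the model from disk, trading a little latency for a flat memory footprint. A condensed sketch of that policy (an illustrative class, not the actual drawer code):

    class ModelCacheSketch:
        # Condensed view of the caching policy toggled by limit_ram_usage.
        def __init__(self, limit_ram_use: bool):
            self.limit_ram_use = limit_ram_use
            self.model_dictionary: dict = {}

        def save_model(self, coin, model):
            if not self.limit_ram_use:           # cache only when RAM is not limited
                self.model_dictionary[coin] = model

        def load_model(self, coin, load_from_disk):
            if coin in self.model_dictionary and not self.limit_ram_use:
                return self.model_dictionary[coin]   # fast path from memory
            model = load_from_disk(coin)             # otherwise always hit disk
            if coin not in self.model_dictionary and not self.limit_ram_use:
                self.model_dictionary[coin] = model  # warm the cache once
            return model
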
@@ -3,12 +3,12 @@ from typing import Any, Dict

 import torch as th
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
-from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
+from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Positions
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from pathlib import Path
-from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
-from stable_baselines3.common.monitor import Monitor
+# from pandas import DataFrame
+# from stable_baselines3.common.callbacks import EvalCallback
+# from stable_baselines3.common.monitor import Monitor
 import numpy as np

 logger = logging.getLogger(__name__)
@@ -53,26 +53,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):

         return model

-    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
-                                        prices_train: DataFrame, prices_test: DataFrame,
-                                        dk: FreqaiDataKitchen):
-        """
-        User can override this if they are using a custom MyRLEnv
-        """
-        train_df = data_dictionary["train_features"]
-        test_df = data_dictionary["test_features"]
-
-        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                 reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
-                                window_size=self.CONV_WIDTH,
-                                reward_kwargs=self.reward_params, config=self.config))
-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
-                                          best_model_save_path=str(dk.data_path))
-
-
-class MyRLEnv(Base5ActionRLEnv):
-    """
-    User can override any function in BaseRLEnv and gym.Env. Here the user
-    sets a custom reward based on profit and trade duration.
+    class MyRLEnv(BaseReinforcementLearningModel.MyRLEnv):
+        """
+        User can override any function in BaseRLEnv and gym.Env. Here the user
+        sets a custom reward based on profit and trade duration.
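
This is the payoff of the refactor: ReinforcementLearner no longer duplicates set_train_and_eval_environments; it only overrides the nested env, and both the single-process and multiproc paths pick the override up through self.MyRLEnv. A hedged sketch of a further user customization (class name and reward values are illustrative, import paths assumed from the repo layout, and calculate_reward is assumed to be the reward hook, as in the reward code above):

    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
    from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Positions

    class MyCoolRLModel(ReinforcementLearner):
        # Hypothetical user model: only the reward shaping is customized.
        class MyRLEnv(ReinforcementLearner.MyRLEnv):
            def calculate_reward(self, action: int) -> float:
                # Illustrative tweak: small penalty for churning while flat,
                # otherwise defer to the parent reward.
                if self._position == Positions.Neutral and action != Actions.Neutral.value:
                    return -0.1
                return super().calculate_reward(action)
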
@@ -105,7 +86,8 @@ class MyRLEnv(Base5ActionRLEnv):
                 factor *= 0.5

             # discourage sitting in position
-        if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
+            if self._position in (Positions.Short, Positions.Long) and \
+                    action == Actions.Neutral.value:
                 return -1 * trade_duration / max_trade_duration

             # close long

@@ -34,7 +34,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
                                     **self.freqai_info['model_training_parameters']
                                     )
         else:
-            logger.info('Continual training activated - starting training from previously '
+            logger.info('Continual learning activated - starting training from previously '
                         'trained agent.')
             model = self.dd.model_dictionary[dk.pair]
             model.tensorboard_log = Path(dk.data_path / "tensorboard")
@@ -65,13 +65,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):

         env_id = "train_env"
         num_cpu = int(self.freqai_info["rl_config"]["thread_count"] / 2)
-        self.train_env = SubprocVecEnv([make_env(env_id, i, 1, train_df, prices_train,
+        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                         self.reward_params, self.CONV_WIDTH,
                                         config=self.config) for i
                                         in range(num_cpu)])

         eval_env_id = 'eval_env'
-        self.eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, prices_test,
+        self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
+                                                test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
                                                 config=self.config) for i
                                        in range(num_cpu)])
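
A design note on num_cpu: the train and eval SubprocVecEnvs are alive at the same time, so each gets half of rl_config's thread_count. A quick illustration (key name from this hunk, value illustrative):

    rl_config = {"thread_count": 8}
    num_cpu = int(rl_config["thread_count"] / 2)  # 4 worker envs per vectorized env
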