Merge remote-tracking branch 'origin/develop' into feat/convolutional-neural-net
This commit is contained in:
		| @@ -20,6 +20,9 @@ class Base4ActionRLEnv(BaseEnvironment): | ||||
|     """ | ||||
|     Base class for a 4 action environment | ||||
|     """ | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         self.actions = Actions | ||||
|  | ||||
|     def set_action_space(self): | ||||
|         self.action_space = spaces.Discrete(len(Actions)) | ||||
| @@ -43,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment): | ||||
|             self._done = True | ||||
|  | ||||
|         self._update_unrealized_total_profit() | ||||
|  | ||||
|         step_reward = self.calculate_reward(action) | ||||
|         self.total_reward += step_reward | ||||
|         self.tensorboard_log(self.actions._member_names_[action]) | ||||
|  | ||||
|         trade_type = None | ||||
|         if self.is_tradesignal(action): | ||||
| @@ -92,9 +95,12 @@ class Base4ActionRLEnv(BaseEnvironment): | ||||
|  | ||||
|         info = dict( | ||||
|             tick=self._current_tick, | ||||
|             action=action, | ||||
|             total_reward=self.total_reward, | ||||
|             total_profit=self._total_profit, | ||||
|             position=self._position.value | ||||
|             position=self._position.value, | ||||
|             trade_duration=self.get_trade_duration(), | ||||
|             current_profit_pct=self.get_unrealized_profit() | ||||
|         ) | ||||
|  | ||||
|         observation = self._get_observation() | ||||
|   | ||||
| @@ -21,6 +21,9 @@ class Base5ActionRLEnv(BaseEnvironment): | ||||
|     """ | ||||
|     Base class for a 5 action environment | ||||
|     """ | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         self.actions = Actions | ||||
|  | ||||
|     def set_action_space(self): | ||||
|         self.action_space = spaces.Discrete(len(Actions)) | ||||
| @@ -46,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment): | ||||
|         self._update_unrealized_total_profit() | ||||
|         step_reward = self.calculate_reward(action) | ||||
|         self.total_reward += step_reward | ||||
|         self.tensorboard_log(self.actions._member_names_[action]) | ||||
|  | ||||
|         trade_type = None | ||||
|         if self.is_tradesignal(action): | ||||
| @@ -98,9 +102,12 @@ class Base5ActionRLEnv(BaseEnvironment): | ||||
|  | ||||
|         info = dict( | ||||
|             tick=self._current_tick, | ||||
|             action=action, | ||||
|             total_reward=self.total_reward, | ||||
|             total_profit=self._total_profit, | ||||
|             position=self._position.value | ||||
|             position=self._position.value, | ||||
|             trade_duration=self.get_trade_duration(), | ||||
|             current_profit_pct=self.get_unrealized_profit() | ||||
|         ) | ||||
|  | ||||
|         observation = self._get_observation() | ||||
|   | ||||
| @@ -2,7 +2,7 @@ import logging | ||||
| import random | ||||
| from abc import abstractmethod | ||||
| from enum import Enum | ||||
| from typing import Optional | ||||
| from typing import Optional, Type, Union | ||||
|  | ||||
| import gym | ||||
| import numpy as np | ||||
| @@ -11,12 +11,21 @@ from gym import spaces | ||||
| from gym.utils import seeding | ||||
| from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.data.dataprovider import DataProvider | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class BaseActions(Enum): | ||||
|     """ | ||||
|     Default action space, mostly used for type handling. | ||||
|     """ | ||||
|     Neutral = 0 | ||||
|     Long_enter = 1 | ||||
|     Long_exit = 2 | ||||
|     Short_enter = 3 | ||||
|     Short_exit = 4 | ||||
|  | ||||
|  | ||||
| class Positions(Enum): | ||||
|     Short = 0 | ||||
|     Long = 1 | ||||
| @@ -35,8 +44,8 @@ class BaseEnvironment(gym.Env): | ||||
|  | ||||
|     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(), | ||||
|                  reward_kwargs: dict = {}, window_size=10, starting_point=True, | ||||
|                  id: str = 'baseenv-1', seed: int = 1, config: dict = {}, | ||||
|                  dp: Optional[DataProvider] = None): | ||||
|                  id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False, | ||||
|                  fee: float = 0.0015): | ||||
|         """ | ||||
|         Initializes the training/eval environment. | ||||
|         :param df: dataframe of features | ||||
| @@ -47,22 +56,29 @@ class BaseEnvironment(gym.Env): | ||||
|         :param id: string id of the environment (used in backend for multiprocessed env) | ||||
|         :param seed: Sets the seed of the environment higher in the gym.Env object | ||||
|         :param config: Typical user configuration file | ||||
|         :param dp: dataprovider from freqtrade | ||||
|         :param live: Whether or not this environment is active in dry/live/backtesting | ||||
|         :param fee: The fee to use for environmental interactions. | ||||
|         """ | ||||
|         self.config = config | ||||
|         self.rl_config = config['freqai']['rl_config'] | ||||
|         self.add_state_info = self.rl_config.get('add_state_info', False) | ||||
|         self.id = id | ||||
|         self.seed(seed) | ||||
|         self.reset_env(df, prices, window_size, reward_kwargs, starting_point) | ||||
|         self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8) | ||||
|         self.compound_trades = config['stake_amount'] == 'unlimited' | ||||
|         if self.config.get('fee', None) is not None: | ||||
|             self.fee = self.config['fee'] | ||||
|         elif dp is not None: | ||||
|             self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore | ||||
|         else: | ||||
|             self.fee = 0.0015 | ||||
|             self.fee = fee | ||||
|  | ||||
|         # set here to default 5Ac, but all children envs can override this | ||||
|         self.actions: Type[Enum] = BaseActions | ||||
|         self.tensorboard_metrics: dict = {} | ||||
|         self.live = live | ||||
|         if not self.live and self.add_state_info: | ||||
|             self.add_state_info = False | ||||
|             logger.warning("add_state_info is not available in backtesting. Deactivating.") | ||||
|         self.seed(seed) | ||||
|         self.reset_env(df, prices, window_size, reward_kwargs, starting_point) | ||||
|  | ||||
|     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int, | ||||
|                   reward_kwargs: dict, starting_point=True): | ||||
| @@ -117,7 +133,38 @@ class BaseEnvironment(gym.Env): | ||||
|         self.np_random, seed = seeding.np_random(seed) | ||||
|         return [seed] | ||||
|  | ||||
|     def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True): | ||||
|         """ | ||||
|         Function builds the tensorboard_metrics dictionary | ||||
|         to be parsed by the TensorboardCallback. This | ||||
|         function is designed for tracking incremented objects, | ||||
|         events, actions inside the training environment. | ||||
|         For example, a user can call this to track the | ||||
|         frequency of occurence of an `is_valid` call in | ||||
|         their `calculate_reward()`: | ||||
|  | ||||
|         def calculate_reward(self, action: int) -> float: | ||||
|             if not self._is_valid(action): | ||||
|                 self.tensorboard_log("is_valid") | ||||
|                 return -2 | ||||
|  | ||||
|         :param metric: metric to be tracked and incremented | ||||
|         :param value: value to increment `metric` by | ||||
|         :param inc: sets whether the `value` is incremented or not | ||||
|         """ | ||||
|         if not inc or metric not in self.tensorboard_metrics: | ||||
|             self.tensorboard_metrics[metric] = value | ||||
|         else: | ||||
|             self.tensorboard_metrics[metric] += value | ||||
|  | ||||
|     def reset_tensorboard_log(self): | ||||
|         self.tensorboard_metrics = {} | ||||
|  | ||||
|     def reset(self): | ||||
|         """ | ||||
|         Reset is called at the beginning of every episode | ||||
|         """ | ||||
|         self.reset_tensorboard_log() | ||||
|  | ||||
|         self._done = False | ||||
|  | ||||
| @@ -271,6 +318,13 @@ class BaseEnvironment(gym.Env): | ||||
|     def current_price(self) -> float: | ||||
|         return self.prices.iloc[self._current_tick].open | ||||
|  | ||||
|     def get_actions(self) -> Type[Enum]: | ||||
|         """ | ||||
|         Used by SubprocVecEnv to get actions from | ||||
|         initialized env for tensorboard callback | ||||
|         """ | ||||
|         return self.actions | ||||
|  | ||||
|     # Keeping around incase we want to start building more complex environment | ||||
|     # templates in the future. | ||||
|     # def most_recent_return(self): | ||||
|   | ||||
| @@ -21,7 +21,8 @@ from freqtrade.exceptions import OperationalException | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.freqai_interface import IFreqaiModel | ||||
| from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv | ||||
| from freqtrade.freqai.RL.BaseEnvironment import Positions | ||||
| from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions | ||||
| from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback | ||||
| from freqtrade.persistence import Trade | ||||
|  | ||||
|  | ||||
| @@ -44,8 +45,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): | ||||
|             'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) | ||||
|         th.set_num_threads(self.max_threads) | ||||
|         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] | ||||
|         self.train_env: Union[SubprocVecEnv, gym.Env] = None | ||||
|         self.eval_env: Union[SubprocVecEnv, gym.Env] = None | ||||
|         self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() | ||||
|         self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() | ||||
|         self.eval_callback: Optional[EvalCallback] = None | ||||
|         self.model_type = self.freqai_info['rl_config']['model_type'] | ||||
|         self.rl_config = self.freqai_info['rl_config'] | ||||
| @@ -65,6 +66,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): | ||||
|         self.unset_outlier_removal() | ||||
|         self.net_arch = self.rl_config.get('net_arch', [128, 128]) | ||||
|         self.dd.model_type = import_str | ||||
|         self.tensorboard_callback: TensorboardCallback = \ | ||||
|             TensorboardCallback(verbose=1, actions=BaseActions) | ||||
|  | ||||
|     def unset_outlier_removal(self): | ||||
|         """ | ||||
| @@ -140,22 +143,35 @@ class BaseReinforcementLearningModel(IFreqaiModel): | ||||
|         train_df = data_dictionary["train_features"] | ||||
|         test_df = data_dictionary["test_features"] | ||||
|  | ||||
|         env_info = self.pack_env_dict() | ||||
|  | ||||
|         self.train_env = self.MyRLEnv(df=train_df, | ||||
|                                       prices=prices_train, | ||||
|                                       window_size=self.CONV_WIDTH, | ||||
|                                       reward_kwargs=self.reward_params, | ||||
|                                       config=self.config, | ||||
|                                       dp=self.data_provider) | ||||
|                                       **env_info) | ||||
|         self.eval_env = Monitor(self.MyRLEnv(df=test_df, | ||||
|                                              prices=prices_test, | ||||
|                                              window_size=self.CONV_WIDTH, | ||||
|                                              reward_kwargs=self.reward_params, | ||||
|                                              config=self.config, | ||||
|                                              dp=self.data_provider)) | ||||
|                                              **env_info)) | ||||
|         self.eval_callback = EvalCallback(self.eval_env, deterministic=True, | ||||
|                                           render=False, eval_freq=len(train_df), | ||||
|                                           best_model_save_path=str(dk.data_path)) | ||||
|  | ||||
|         actions = self.train_env.get_actions() | ||||
|         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) | ||||
|  | ||||
|     def pack_env_dict(self) -> Dict[str, Any]: | ||||
|         """ | ||||
|         Create dictionary of environment arguments | ||||
|         """ | ||||
|         env_info = {"window_size": self.CONV_WIDTH, | ||||
|                     "reward_kwargs": self.reward_params, | ||||
|                     "config": self.config, | ||||
|                     "live": self.live} | ||||
|         if self.data_provider: | ||||
|             env_info["fee"] = self.data_provider._exchange \ | ||||
|                 .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore | ||||
|  | ||||
|         return env_info | ||||
|  | ||||
|     @abstractmethod | ||||
|     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs): | ||||
|         """ | ||||
| @@ -377,8 +393,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): | ||||
|  | ||||
| def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, | ||||
|              seed: int, train_df: DataFrame, price: DataFrame, | ||||
|              reward_params: Dict[str, int], window_size: int, monitor: bool = False, | ||||
|              config: Dict[str, Any] = {}) -> Callable: | ||||
|              monitor: bool = False, | ||||
|              env_info: Dict[str, Any] = {}) -> Callable: | ||||
|     """ | ||||
|     Utility function for multiprocessed env. | ||||
|  | ||||
| @@ -386,13 +402,14 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, | ||||
|     :param num_env: (int) the number of environment you wish to have in subprocesses | ||||
|     :param seed: (int) the inital seed for RNG | ||||
|     :param rank: (int) index of the subprocess | ||||
|     :param env_info: (dict) all required arguments to instantiate the environment. | ||||
|     :return: (Callable) | ||||
|     """ | ||||
|  | ||||
|     def _init() -> gym.Env: | ||||
|  | ||||
|         env = MyRLEnv(df=train_df, prices=price, window_size=window_size, | ||||
|                       reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config) | ||||
|         env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank, | ||||
|                       **env_info) | ||||
|         if monitor: | ||||
|             env = Monitor(env) | ||||
|         return env | ||||
|   | ||||
							
								
								
									
										59
									
								
								freqtrade/freqai/RL/TensorboardCallback.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								freqtrade/freqai/RL/TensorboardCallback.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| from enum import Enum | ||||
| from typing import Any, Dict, Type, Union | ||||
|  | ||||
| from stable_baselines3.common.callbacks import BaseCallback | ||||
| from stable_baselines3.common.logger import HParam | ||||
|  | ||||
| from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment | ||||
|  | ||||
|  | ||||
| class TensorboardCallback(BaseCallback): | ||||
|     """ | ||||
|     Custom callback for plotting additional values in tensorboard and | ||||
|     episodic summary reports. | ||||
|     """ | ||||
|     def __init__(self, verbose=1, actions: Type[Enum] = BaseActions): | ||||
|         super(TensorboardCallback, self).__init__(verbose) | ||||
|         self.model: Any = None | ||||
|         self.logger = None  # type: Any | ||||
|         self.training_env: BaseEnvironment = None  # type: ignore | ||||
|         self.actions: Type[Enum] = actions | ||||
|  | ||||
|     def _on_training_start(self) -> None: | ||||
|         hparam_dict = { | ||||
|             "algorithm": self.model.__class__.__name__, | ||||
|             "learning_rate": self.model.learning_rate, | ||||
|             # "gamma": self.model.gamma, | ||||
|             # "gae_lambda": self.model.gae_lambda, | ||||
|             # "batch_size": self.model.batch_size, | ||||
|             # "n_steps": self.model.n_steps, | ||||
|         } | ||||
|         metric_dict: Dict[str, Union[float, int]] = { | ||||
|             "eval/mean_reward": 0, | ||||
|             "rollout/ep_rew_mean": 0, | ||||
|             "rollout/ep_len_mean": 0, | ||||
|             "train/value_loss": 0, | ||||
|             "train/explained_variance": 0, | ||||
|         } | ||||
|         self.logger.record( | ||||
|             "hparams", | ||||
|             HParam(hparam_dict, metric_dict), | ||||
|             exclude=("stdout", "log", "json", "csv"), | ||||
|         ) | ||||
|  | ||||
|     def _on_step(self) -> bool: | ||||
|  | ||||
|         local_info = self.locals["infos"][0] | ||||
|         tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] | ||||
|  | ||||
|         for info in local_info: | ||||
|             if info not in ["episode", "terminal_observation"]: | ||||
|                 self.logger.record(f"_info/{info}", local_info[info]) | ||||
|  | ||||
|         for info in tensorboard_metrics: | ||||
|             if info in [action.name for action in self.actions]: | ||||
|                 self.logger.record(f"_actions/{info}", tensorboard_metrics[info]) | ||||
|             else: | ||||
|                 self.logger.record(f"_custom/{info}", tensorboard_metrics[info]) | ||||
|  | ||||
|         return True | ||||
| @@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel): | ||||
|         self.data_cleaning_predict(dk) | ||||
|  | ||||
|         predictions = self.model.predict(dk.data_dictionary["prediction_features"]) | ||||
|         if self.CONV_WIDTH == 1: | ||||
|             predictions = np.reshape(predictions, (-1, len(dk.label_list))) | ||||
|  | ||||
|         pred_df = DataFrame(predictions, columns=dk.label_list) | ||||
|  | ||||
|         predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) | ||||
|         if self.CONV_WIDTH == 1: | ||||
|             predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_))) | ||||
|         pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) | ||||
|  | ||||
|         pred_df = pd.concat([pred_df, pred_df_prob], axis=1) | ||||
|   | ||||
| @@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel): | ||||
|         self.data_cleaning_predict(dk) | ||||
|  | ||||
|         predictions = self.model.predict(dk.data_dictionary["prediction_features"]) | ||||
|         if self.CONV_WIDTH == 1: | ||||
|             predictions = np.reshape(predictions, (-1, len(dk.label_list))) | ||||
|  | ||||
|         pred_df = DataFrame(predictions, columns=dk.label_list) | ||||
|  | ||||
|         pred_df = dk.denormalize_labels_from_metadata(pred_df) | ||||
|   | ||||
| @@ -462,10 +462,10 @@ class FreqaiDataKitchen: | ||||
|         :param df: Dataframe containing all candles to run the entire backtest. Here | ||||
|                    it is sliced down to just the present training period. | ||||
|         """ | ||||
|  | ||||
|         df = df.loc[df["date"] >= timerange.startdt, :] | ||||
|         if not self.live: | ||||
|             df = df.loc[df["date"] < timerange.stopdt, :] | ||||
|             df = df.loc[(df["date"] >= timerange.startdt) & (df["date"] < timerange.stopdt), :] | ||||
|         else: | ||||
|             df = df.loc[df["date"] >= timerange.startdt, :] | ||||
|  | ||||
|         return df | ||||
|  | ||||
|   | ||||
| @@ -282,10 +282,10 @@ class IFreqaiModel(ABC): | ||||
|             train_it += 1 | ||||
|             total_trains = len(dk.backtesting_timeranges) | ||||
|             self.training_timerange = tr_train | ||||
|             dataframe_train = dk.slice_dataframe(tr_train, dataframe) | ||||
|             dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe) | ||||
|             len_backtest_df = len(dataframe.loc[(dataframe["date"] >= tr_backtest.startdt) & ( | ||||
|                                   dataframe["date"] < tr_backtest.stopdt), :]) | ||||
|  | ||||
|             if not self.ensure_data_exists(dataframe_backtest, tr_backtest, pair): | ||||
|             if not self.ensure_data_exists(len_backtest_df, tr_backtest, pair): | ||||
|                 continue | ||||
|  | ||||
|             self.log_backtesting_progress(tr_train, pair, train_it, total_trains) | ||||
| @@ -298,13 +298,15 @@ class IFreqaiModel(ABC): | ||||
|  | ||||
|             dk.set_new_model_names(pair, timestamp_model_id) | ||||
|  | ||||
|             if dk.check_if_backtest_prediction_is_valid(len(dataframe_backtest)): | ||||
|             if dk.check_if_backtest_prediction_is_valid(len_backtest_df): | ||||
|                 self.dd.load_metadata(dk) | ||||
|                 dk.find_features(dataframe_train) | ||||
|                 dk.find_features(dataframe) | ||||
|                 self.check_if_feature_list_matches_strategy(dk) | ||||
|                 append_df = dk.get_backtesting_prediction() | ||||
|                 dk.append_predictions(append_df) | ||||
|             else: | ||||
|                 dataframe_train = dk.slice_dataframe(tr_train, dataframe) | ||||
|                 dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe) | ||||
|                 if not self.model_exists(dk): | ||||
|                     dk.find_features(dataframe_train) | ||||
|                     dk.find_labels(dataframe_train) | ||||
| @@ -804,16 +806,16 @@ class IFreqaiModel(ABC): | ||||
|             self.pair_it = 1 | ||||
|             self.current_candle = self.dd.current_candle | ||||
|  | ||||
|     def ensure_data_exists(self, dataframe_backtest: DataFrame, | ||||
|     def ensure_data_exists(self, len_dataframe_backtest: int, | ||||
|                            tr_backtest: TimeRange, pair: str) -> bool: | ||||
|         """ | ||||
|         Check if the dataframe is empty, if not, report useful information to user. | ||||
|         :param dataframe_backtest: the backtesting dataframe, maybe empty. | ||||
|         :param len_dataframe_backtest: the len of backtesting dataframe | ||||
|         :param tr_backtest: current backtesting timerange. | ||||
|         :param pair: current pair | ||||
|         :return: if the data exists or not | ||||
|         """ | ||||
|         if self.config.get("freqai_backtest_live_models", False) and len(dataframe_backtest) == 0: | ||||
|         if self.config.get("freqai_backtest_live_models", False) and len_dataframe_backtest == 0: | ||||
|             logger.info(f"No data found for pair {pair} from " | ||||
|                         f"from { tr_backtest.start_fmt} to {tr_backtest.stop_fmt}. " | ||||
|                         "Probably more than one training within the same candle period.") | ||||
|   | ||||
| @@ -61,7 +61,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel): | ||||
|             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs, | ||||
|                                     tensorboard_log=Path( | ||||
|                                         dk.full_path / "tensorboard" / dk.pair.split('/')[0]), | ||||
|                                     **self.freqai_info['model_training_parameters'] | ||||
|                                     **self.freqai_info.get('model_training_parameters', {}) | ||||
|                                     ) | ||||
|         else: | ||||
|             logger.info('Continual training activated - starting training from previously ' | ||||
| @@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel): | ||||
|  | ||||
|         model.learn( | ||||
|             total_timesteps=int(total_timesteps), | ||||
|             callback=self.eval_callback | ||||
|             callback=[self.eval_callback, self.tensorboard_callback] | ||||
|         ) | ||||
|  | ||||
|         if Path(dk.data_path / "best_model.zip").is_file(): | ||||
| @@ -100,13 +100,17 @@ class ReinforcementLearner(BaseReinforcementLearningModel): | ||||
|             """ | ||||
|             # first, penalize if the action is not valid | ||||
|             if not self._is_valid(action): | ||||
|                 self.tensorboard_log("is_valid") | ||||
|                 return -2 | ||||
|  | ||||
|             pnl = self.get_unrealized_profit() | ||||
|             factor = 100. | ||||
|  | ||||
|             # reward agent for entering trades | ||||
|             if (action in (Actions.Long_enter.value, Actions.Short_enter.value) | ||||
|             if (action == Actions.Long_enter.value | ||||
|                     and self._position == Positions.Neutral): | ||||
|                 return 25 | ||||
|             if (action == Actions.Short_enter.value | ||||
|                     and self._position == Positions.Neutral): | ||||
|                 return 25 | ||||
|             # discourage agent from not entering trades | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| import logging | ||||
| from typing import Any, Dict  # , Tuple | ||||
| from typing import Any, Dict | ||||
|  | ||||
| # import numpy.typing as npt | ||||
| from pandas import DataFrame | ||||
| from stable_baselines3.common.callbacks import EvalCallback | ||||
| from stable_baselines3.common.vec_env import SubprocVecEnv | ||||
| @@ -9,6 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner | ||||
| from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env | ||||
| from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| @@ -34,18 +34,24 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): | ||||
|         train_df = data_dictionary["train_features"] | ||||
|         test_df = data_dictionary["test_features"] | ||||
|  | ||||
|         env_info = self.pack_env_dict() | ||||
|  | ||||
|         env_id = "train_env" | ||||
|         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train, | ||||
|                                         self.reward_params, self.CONV_WIDTH, monitor=True, | ||||
|                                         config=self.config) for i | ||||
|         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, | ||||
|                                         train_df, prices_train, | ||||
|                                         monitor=True, | ||||
|                                         env_info=env_info) for i | ||||
|                                         in range(self.max_threads)]) | ||||
|  | ||||
|         eval_env_id = 'eval_env' | ||||
|         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, | ||||
|                                                 test_df, prices_test, | ||||
|                                                 self.reward_params, self.CONV_WIDTH, monitor=True, | ||||
|                                                 config=self.config) for i | ||||
|                                                 monitor=True, | ||||
|                                                 env_info=env_info) for i | ||||
|                                        in range(self.max_threads)]) | ||||
|         self.eval_callback = EvalCallback(self.eval_env, deterministic=True, | ||||
|                                           render=False, eval_freq=len(train_df), | ||||
|                                           best_model_save_path=str(dk.data_path)) | ||||
|  | ||||
|         actions = self.train_env.env_method("get_actions")[0] | ||||
|         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user