Merge remote-tracking branch 'origin/develop' into feat/convolutional-neural-net
@@ -61,7 +61,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
                                     tensorboard_log=Path(
                                         dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
-                                    **self.freqai_info['model_training_parameters']
+                                    **self.freqai_info.get('model_training_parameters', {})
                                     )
         else:
             logger.info('Continual training activated - starting training from previously '
@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
 
         model.learn(
             total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
+            callback=[self.eval_callback, self.tensorboard_callback]
         )
 
         if Path(dk.data_path / "best_model.zip").is_file():
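
Note on the callback change above: stable-baselines3's learn() accepts either a single callback or a list of callbacks (a list is wrapped into a CallbackList internally), which is what lets the new tensorboard_callback run alongside the existing eval_callback. A minimal, self-contained sketch of that pattern outside freqtrade (generic PPO on CartPole, purely for illustration):

# Illustrative sketch only (generic PPO on CartPole, not freqtrade code):
# learn() takes a single callback or a list; a list is wrapped into a CallbackList.
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, EvalCallback

model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
eval_callback = EvalCallback(model.get_env(), eval_freq=500)

model.learn(total_timesteps=1_000, callback=[eval_callback])
# Equivalent explicit form:
# model.learn(total_timesteps=1_000, callback=CallbackList([eval_callback]))
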
@@ -100,13 +100,17 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
+                self.tensorboard_log("is_valid")
                 return -2
 
             pnl = self.get_unrealized_profit()
             factor = 100.
 
             # reward agent for entering trades
-            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
+            if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
                 return 25
+            if (action == Actions.Short_enter.value
+                    and self._position == Positions.Neutral):
+                return 25
             # discourage agent from not entering trades
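
The diff above references a new TensorboardCallback and a tensorboard_log() helper inside the reward function, but their implementations are not part of these hunks. A rough, hypothetical sketch of how such a callback could forward per-environment counters to TensorBoard in stable-baselines3 (assuming the environment keeps them in a "tensorboard_metrics" dict, which is an assumption for this sketch, not necessarily the freqtrade attribute name):

# Hypothetical sketch, not the actual freqtrade TensorboardCallback:
# a custom stable-baselines3 callback that records per-environment counters.
from stable_baselines3.common.callbacks import BaseCallback


class MetricLoggingCallback(BaseCallback):
    """Forward counters collected by the environment(s) to TensorBoard."""

    def _on_step(self) -> bool:
        # get_attr() returns one value per (sub-process) environment;
        # "tensorboard_metrics" is an assumed attribute name for this sketch.
        for metrics in self.training_env.get_attr("tensorboard_metrics"):
            for name, value in metrics.items():
                self.logger.record(f"custom/{name}", value)
        return True  # returning False would stop training early

Such a callback would then be passed to model.learn() in the callback list shown earlier.
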
@@ -1,7 +1,6 @@
 import logging
-from typing import Any, Dict  # , Tuple
+from typing import Any, Dict
 
-# import numpy.typing as npt
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
@@ -9,6 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 
 
 logger = logging.getLogger(__name__)
@@ -34,18 +34,24 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
+        env_info = self.pack_env_dict()
+
         env_id = "train_env"
-        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
-                                        self.reward_params, self.CONV_WIDTH, monitor=True,
-                                        config=self.config) for i
+        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
+                                        train_df, prices_train,
+                                        monitor=True,
+                                        env_info=env_info) for i
                                         in range(self.max_threads)])
 
         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
-                                                self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                config=self.config) for i
+                                                monitor=True,
+                                                env_info=env_info) for i
                                        in range(self.max_threads)])
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+        actions = self.train_env.env_method("get_actions")[0]
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
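
The multiproc refactor above replaces the individually threaded-through reward_params, CONV_WIDTH, and config arguments with a single env_info dict built by pack_env_dict(). A simplified illustration of that pattern follows (hypothetical names and signatures, not freqtrade's actual make_env):

# Simplified illustration of the env_info pattern (hypothetical, not the exact
# freqtrade signatures): bundle shared kwargs once, unpack them in each worker.
from typing import Any, Callable, Dict

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(env_cls: Callable[..., Any], rank: int,
             env_info: Dict[str, Any]) -> Callable[[], Any]:
    def _init():
        # Each sub-process builds its own environment from the same kwargs.
        return Monitor(env_cls(id=rank, **env_info))
    return _init


# Usage sketch:
# env_info = {"reward_kwargs": {...}, "window_size": 10, "config": {...}}
# train_env = SubprocVecEnv([make_env(MyRLEnv, i, env_info) for i in range(4)])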