Merge remote-tracking branch 'origin/develop' into update-freqai-tf-handling
This commit is contained in:
@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=self.eval_callback
|
||||
callback=[self.eval_callback, self.tensorboard_callback]
|
||||
)
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
@@ -100,17 +100,24 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
self.custom_info["Invalid"] += 1
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
if (action == Actions.Long_enter.value
|
||||
and self._position == Positions.Neutral):
|
||||
self.custom_info[f"{Actions.Long_enter.name}"] += 1
|
||||
return 25
|
||||
if (action == Actions.Short_enter.value
|
||||
and self._position == Positions.Neutral):
|
||||
self.custom_info[f"{Actions.Short_enter.name}"] += 1
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
self.custom_info[f"{Actions.Neutral.name}"] += 1
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
@@ -124,18 +131,22 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
# discourage sitting in position
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
self.custom_info["Hold"] += 1
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
self.custom_info[f"{Actions.Long_exit.name}"] += 1
|
||||
return float(pnl * factor)
|
||||
|
||||
# close short
|
||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
self.custom_info[f"{Actions.Short_exit.name}"] += 1
|
||||
return float(pnl * factor)
|
||||
|
||||
self.custom_info["Unknown"] += 1
|
||||
return 0.
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict # , Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
# import numpy.typing as npt
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
@@ -9,6 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,3 +49,6 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
actions = self.train_env.env_method("get_actions")[0]
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
Reference in New Issue
Block a user