Merge pull request #7843 from smarmau/develop

freqai RL agent info during training
Robert Caulk committed b9f6911a6a on 2022-12-06 20:06:41 +01:00 (via GitHub)
8 changed files with 137 additions and 11 deletions

freqtrade/freqai/RL/Base4ActionRLEnv.py

@@ -20,6 +20,9 @@ class Base4ActionRLEnv(BaseEnvironment):
     """
     Base class for a 4 action environment
     """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.actions = Actions

     def set_action_space(self):
         self.action_space = spaces.Discrete(len(Actions))
@@ -92,9 +95,12 @@ class Base4ActionRLEnv(BaseEnvironment):
         info = dict(
             tick=self._current_tick,
+            action=action,
             total_reward=self.total_reward,
             total_profit=self._total_profit,
-            position=self._position.value
+            position=self._position.value,
+            trade_duration=self.get_trade_duration(),
+            current_profit_pct=self.get_unrealized_profit()
         )

         observation = self._get_observation()

freqtrade/freqai/RL/Base5ActionRLEnv.py

@@ -21,6 +21,9 @@ class Base5ActionRLEnv(BaseEnvironment):
     """
     Base class for a 5 action environment
     """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.actions = Actions

     def set_action_space(self):
         self.action_space = spaces.Discrete(len(Actions))
@@ -98,9 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
         info = dict(
             tick=self._current_tick,
+            action=action,
             total_reward=self.total_reward,
             total_profit=self._total_profit,
-            position=self._position.value
+            position=self._position.value,
+            trade_duration=self.get_trade_duration(),
+            current_profit_pct=self.get_unrealized_profit()
         )

         observation = self._get_observation()
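Both the 4- and 5-action environments now surface the chosen action, the trade duration, and the unrealized profit through the info dict returned by step(), which is what the new TensorboardCallback reads further down. A minimal sketch of inspecting these fields by hand, assuming env is an already-initialized Base4ActionRLEnv/Base5ActionRLEnv instance (the construction arguments are omitted and the random policy is purely illustrative):

    # `env` is assumed to be a fully initialized freqai RL environment;
    # building one requires the usual df/prices/reward_kwargs setup not shown here.
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # random policy, illustration only
        obs, reward, done, info = env.step(action)
        # keys added in this commit, alongside the existing totals
        print(info["action"], info["trade_duration"], info["current_profit_pct"])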

freqtrade/freqai/RL/BaseEnvironment.py

@@ -2,7 +2,7 @@ import logging
 import random
 from abc import abstractmethod
 from enum import Enum
-from typing import Optional
+from typing import Optional, Type

 import gym
 import numpy as np
@@ -17,6 +17,17 @@ from freqtrade.data.dataprovider import DataProvider
 logger = logging.getLogger(__name__)


+class BaseActions(Enum):
+    """
+    Default action space, mostly used for type handling.
+    """
+    Neutral = 0
+    Long_enter = 1
+    Long_exit = 2
+    Short_enter = 3
+    Short_exit = 4
+
+
 class Positions(Enum):
     Short = 0
     Long = 1
@@ -64,6 +75,10 @@ class BaseEnvironment(gym.Env):
         else:
             self.fee = 0.0015

+        # set here to default 5Ac, but all children envs can override this
+        self.actions: Type[Enum] = BaseActions
+        self.custom_info: dict = {}
+
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
         """
@@ -118,6 +133,19 @@ class BaseEnvironment(gym.Env):
         return [seed]

     def reset(self):
+        """
+        Reset is called at the beginning of every episode
+        """
+        # custom_info is used for episodic reports and tensorboard logging
+        self.custom_info["Invalid"] = 0
+        self.custom_info["Hold"] = 0
+        self.custom_info["Unknown"] = 0
+        self.custom_info["pnl_factor"] = 0
+        self.custom_info["duration_factor"] = 0
+        self.custom_info["reward_exit"] = 0
+        self.custom_info["reward_hold"] = 0
+        for action in self.actions:
+            self.custom_info[f"{action.name}"] = 0

         self._done = False
@@ -271,6 +299,13 @@ class BaseEnvironment(gym.Env):
     def current_price(self) -> float:
         return self.prices.iloc[self._current_tick].open

+    def get_actions(self) -> Type[Enum]:
+        """
+        Used by SubprocVecEnv to get actions from
+        initialized env for tensorboard callback
+        """
+        return self.actions
+
     # Keeping around incase we want to start building more complex environment
     # templates in the future.
     # def most_recent_return(self):
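The new plumbing is designed so that a child environment only has to reassign self.actions; the per-action counters seeded in reset() and the get_actions() accessor then follow whatever enum the child uses. A hedged sketch of that pattern with an invented 3-action enum (not part of this commit; the abstract reward/step methods of BaseEnvironment would still need real implementations):

    from enum import Enum
    from typing import Type

    from gym import spaces

    from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment


    class ThreeActions(Enum):
        # hypothetical reduced action space, purely for illustration
        Neutral = 0
        Enter = 1
        Exit = 2


    class ThreeActionEnv(BaseEnvironment):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # overrides the BaseActions default set in BaseEnvironment.__init__
            self.actions: Type[Enum] = ThreeActions

        def set_action_space(self):
            self.action_space = spaces.Discrete(len(ThreeActions))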

freqtrade/freqai/RL/BaseReinforcementLearningModel.py

@@ -21,7 +21,8 @@ from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
-from freqtrade.freqai.RL.BaseEnvironment import Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
@@ -44,8 +45,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Union[SubprocVecEnv, gym.Env] = None
-        self.eval_env: Union[SubprocVecEnv, gym.Env] = None
+        self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
+        self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
         self.eval_callback: Optional[EvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
@@ -65,6 +66,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             self.unset_outlier_removal()
         self.net_arch = self.rl_config.get('net_arch', [128, 128])
         self.dd.model_type = import_str
+        self.tensorboard_callback: TensorboardCallback = \
+            TensorboardCallback(verbose=1, actions=BaseActions)

     def unset_outlier_removal(self):
         """
@@ -156,6 +159,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))

+        actions = self.train_env.get_actions()
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
+
     @abstractmethod
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
         """

freqtrade/freqai/RL/TensorboardCallback.py (new file)

@@ -0,0 +1,60 @@
+from enum import Enum
+from typing import Any, Dict, Type, Union
+
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.logger import HParam
+
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
+
+
+class TensorboardCallback(BaseCallback):
+    """
+    Custom callback for plotting additional values in tensorboard and
+    episodic summary reports.
+    """
+    def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
+        super(TensorboardCallback, self).__init__(verbose)
+        self.model: Any = None
+        self.logger = None  # type: Any
+        self.training_env: BaseEnvironment = None  # type: ignore
+        self.actions: Type[Enum] = actions
+
+    def _on_training_start(self) -> None:
+        hparam_dict = {
+            "algorithm": self.model.__class__.__name__,
+            "learning_rate": self.model.learning_rate,
+            # "gamma": self.model.gamma,
+            # "gae_lambda": self.model.gae_lambda,
+            # "batch_size": self.model.batch_size,
+            # "n_steps": self.model.n_steps,
+        }
+        metric_dict: Dict[str, Union[float, int]] = {
+            "eval/mean_reward": 0,
+            "rollout/ep_rew_mean": 0,
+            "rollout/ep_len_mean": 0,
+            "train/value_loss": 0,
+            "train/explained_variance": 0,
+        }
+        self.logger.record(
+            "hparams",
+            HParam(hparam_dict, metric_dict),
+            exclude=("stdout", "log", "json", "csv"),
+        )
+
+    def _on_step(self) -> bool:
+        custom_info = self.training_env.get_attr("custom_info")[0]
+        self.logger.record("_state/position", self.locals["infos"][0]["position"])
+        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
+        self.logger.record("_state/current_profit_pct", self.locals["infos"]
+                           [0]["current_profit_pct"])
+        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
+        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
+        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
+                                [0]["trade_duration"])
+        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
+        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
+        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
+        self.logger.record("_actions/Hold", custom_info["Hold"])
+        for action in self.actions:
+            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
+        return True
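TensorboardCallback is a plain stable-baselines3 BaseCallback, so outside of freqtrade's learners it would be wired into training like any other callback. The sketch below is illustrative only: the PPO choice, the log directory, and the train_env/eval_env variables are assumptions, since in practice the ReinforcementLearner builds the model and environments from rl_config:

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import EvalCallback

    from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback

    # `train_env` and `eval_env` are assumed to be already-built freqai RL
    # environments (e.g. Base5ActionRLEnv instances); toy gym envs would not
    # provide the custom_info attribute or the info keys the callback reads.
    tensorboard_callback = TensorboardCallback(verbose=1, actions=train_env.get_actions())
    eval_callback = EvalCallback(eval_env, deterministic=True, render=False, eval_freq=1000)

    model = PPO("MlpPolicy", train_env, tensorboard_log="./tb_logs", verbose=1)
    model.learn(total_timesteps=10_000, callback=[eval_callback, tensorboard_callback])

The scalars then appear under the _state/, _reward/, and _actions/ tags when tensorboard --logdir is pointed at the chosen log directory.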

freqtrade/freqai/prediction_models/ReinforcementLearner.py

@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         model.learn(
             total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
+            callback=[self.eval_callback, self.tensorboard_callback]
         )

         if Path(dk.data_path / "best_model.zip").is_file():
@@ -100,17 +100,24 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
+                self.custom_info["Invalid"] += 1
                 return -2

             pnl = self.get_unrealized_profit()
             factor = 100.

             # reward agent for entering trades
-            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
-                    and self._position == Positions.Neutral):
+            if (action == Actions.Long_enter.value
+                    and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Long_enter.name}"] += 1
+                return 25
+
+            if (action == Actions.Short_enter.value
+                    and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Short_enter.name}"] += 1
                 return 25

             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
+                self.custom_info[f"{Actions.Neutral.name}"] += 1
                 return -1

             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
@@ -124,18 +131,22 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
+                self.custom_info["Hold"] += 1
                 return -1 * trade_duration / max_trade_duration

             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Long_exit.name}"] += 1
                 return float(pnl * factor)

             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Short_exit.name}"] += 1
                 return float(pnl * factor)

+            self.custom_info["Unknown"] += 1
             return 0.
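As a quick sanity check of the exit branch above, the numbers below walk through one winning long exit; the pnl, profit_aim, and rr values are illustrative assumptions, while factor and win_reward_factor follow the defaults used in the diff:

    # Worked example of the "close long" reward path (values are assumptions)
    pnl = 0.03                    # 3% unrealized profit at the moment of exit
    profit_aim, rr = 0.02, 1.0    # win threshold is profit_aim * rr = 0.02
    factor = 100.
    win_reward_factor = 2         # model_reward_parameters.get('win_reward_factor', 2)

    if pnl > profit_aim * rr:     # 0.03 > 0.02, so the win bonus applies
        factor *= win_reward_factor
    reward = float(pnl * factor)  # 0.03 * 200 = 6.0
    print(reward)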

freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py

@@ -1,7 +1,6 @@
 import logging
-from typing import Any, Dict  # , Tuple
+from typing import Any, Dict

-# import numpy.typing as npt
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
@@ -9,6 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback


 logger = logging.getLogger(__name__)
@@ -49,3 +49,6 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+        actions = self.train_env.env_method("get_actions")[0]
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

tests/freqai/test_freqai_interface.py

@@ -237,7 +237,6 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
     df = freqai.cache_corr_pairlist_dfs(df, freqai.dk)
     for i in range(5):
         df[f'%-constant_{i}'] = i
-        # df.loc[:, f'%-constant_{i}'] = i

     metadata = {"pair": "LTC/BTC"}
     freqai.start_backtesting(df, metadata, freqai.dk)