reorganize/generalize tensorboard callback
Commit: 24766928ba · Parent: b2edc58089
freqtrade/freqai/RL/Base4ActionRLEnv.py
@@ -20,6 +20,9 @@ class Base4ActionRLEnv(BaseEnvironment):
     """
     Base class for a 4 action environment
     """
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.actions = Actions

     def set_action_space(self):
         self.action_space = spaces.Discrete(len(Actions))
@@ -92,9 +95,12 @@ class Base4ActionRLEnv(BaseEnvironment):

         info = dict(
             tick=self._current_tick,
+            action=action,
             total_reward=self.total_reward,
             total_profit=self._total_profit,
-            position=self._position.value
+            position=self._position.value,
+            trade_duration=self.get_trade_duration(),
+            current_profit_pct=self.get_unrealized_profit()
         )

         observation = self._get_observation()
freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -21,6 +21,9 @@ class Base5ActionRLEnv(BaseEnvironment):
     """
     Base class for a 5 action environment
     """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.actions = Actions

     def set_action_space(self):
         self.action_space = spaces.Discrete(len(Actions))
@@ -98,9 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):

         info = dict(
             tick=self._current_tick,
+            action=action,
             total_reward=self.total_reward,
             total_profit=self._total_profit,
-            position=self._position.value
+            position=self._position.value,
+            trade_duration=self.get_trade_duration(),
+            current_profit_pct=self.get_unrealized_profit()
         )

         observation = self._get_observation()
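The extra `action`, `trade_duration`, and `current_profit_pct` entries added to both environments are what the tensorboard callback later reads from stable-baselines3's `self.locals["infos"]`. A rough sketch of what a single step now reports, assuming an already-configured environment instance `env` (illustrative only, not part of the diff):

    # Illustrative sketch: inspect the per-step info dict of a configured env.
    env.reset()
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    print(info["tick"], info["action"], info["position"])        # state of this step
    print(info["total_reward"], info["total_profit"])            # running totals
    print(info["trade_duration"], info["current_profit_pct"])    # newly exposed fields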
freqtrade/freqai/RL/BaseEnvironment.py
@@ -2,7 +2,7 @@ import logging
 import random
 from abc import abstractmethod
 from enum import Enum
-from typing import Optional
+from typing import Optional, Type

 import gym
 import numpy as np
@@ -17,6 +17,17 @@ from freqtrade.data.dataprovider import DataProvider
 logger = logging.getLogger(__name__)


+class BaseActions(Enum):
+    """
+    Default action space, mostly used for type handling.
+    """
+    Neutral = 0
+    Long_enter = 1
+    Long_exit = 2
+    Short_enter = 3
+    Short_exit = 4
+
+
 class Positions(Enum):
     Short = 0
     Long = 1
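`BaseActions` mirrors the five-action layout of `Base5ActionRLEnv` and mainly exists so other code can be typed against `Type[Enum]` without importing a concrete environment. A small sketch of that type-handling role (illustrative, not from the commit; the helper name is made up):

    from enum import Enum
    from typing import Type

    from gym import spaces

    from freqtrade.freqai.RL.BaseEnvironment import BaseActions


    def build_action_space(actions: Type[Enum] = BaseActions) -> spaces.Discrete:
        # One discrete action per enum member (Neutral, Long_enter, ...).
        return spaces.Discrete(len(actions))


    # Enum member names double as tensorboard keys, e.g. "_actions/Long_enter".
    action_names = [action.name for action in BaseActions]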
@@ -64,6 +75,9 @@ class BaseEnvironment(gym.Env):
         else:
             self.fee = 0.0015

+        # set here to default 5Ac, but all children envs can overwrite this
+        self.actions: Type[Enum] = BaseActions
+
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
         """
@@ -106,6 +120,7 @@ class BaseEnvironment(gym.Env):
         self._total_unrealized_profit: float = 1
         self.history: dict = {}
         self.trade_history: list = []
+        self.custom_info: dict = {}

     @abstractmethod
     def set_action_space(self):
@@ -118,6 +133,19 @@ class BaseEnvironment(gym.Env):
         return [seed]

     def reset(self):
+        """
+        Reset is called at the beginning of every episode
+        """
+        # custom_info is used for episodic reports and tensorboard logging
+        self.custom_info["Invalid"] = 0
+        self.custom_info["Hold"] = 0
+        self.custom_info["Unknown"] = 0
+        self.custom_info["pnl_factor"] = 0
+        self.custom_info["duration_factor"] = 0
+        self.custom_info["reward_exit"] = 0
+        self.custom_info["reward_hold"] = 0
+        for action in self.actions:
+            self.custom_info[f"{action.name}"] = 0
         self._done = False
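`reset()` only zeroes these counters; the places that increment them live in the child environments' reward logic and are not shown in this commit. A hedged sketch of how a child env might feed them, with the increment sites being illustrative assumptions rather than freqtrade code:

    # Hypothetical increments inside a child environment's calculate_reward();
    # the real increment locations are not part of this diff.
    def calculate_reward(self, action: int) -> float:
        if not self._is_valid(action):       # validity helper defined in the base envs
            self.custom_info["Invalid"] += 1
            return -2.0
        if action == self.actions.Neutral.value:
            self.custom_info["Hold"] += 1
        self.custom_info[self.actions(action).name] += 1   # per-action counter read by the callback
        return 0.0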
@@ -271,6 +299,13 @@ class BaseEnvironment(gym.Env):
     def current_price(self) -> float:
         return self.prices.iloc[self._current_tick].open

+    def get_actions(self) -> Type[Enum]:
+        """
+        Used by SubprocVecEnv to get actions from
+        initialized env for tensorboard callback
+        """
+        return self.actions
+
     # Keeping around incase we want to start building more complex environment
     # templates in the future.
     # def most_recent_return(self):
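`get_actions()` exists because, once the environments are wrapped in a `SubprocVecEnv`, their attributes live in worker processes and can only be reached through the vectorized-env API. A sketch of the two access paths used by the learners further down (assuming `env` and `vec_env` are already built):

    from stable_baselines3.common.vec_env import SubprocVecEnv

    # Single-process environment: call the method directly.
    actions = env.get_actions()

    # SubprocVecEnv: invoke the method inside worker 0 and fetch its custom_info dict.
    def inspect_workers(vec_env: SubprocVecEnv):
        actions = vec_env.env_method("get_actions")[0]
        counters = vec_env.get_attr("custom_info")[0]
        return actions, counters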
freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -12,8 +12,7 @@ import pandas as pd
 import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
-from stable_baselines3.common.logger import HParam
+from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv
@@ -22,7 +21,8 @@ from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
-from freqtrade.freqai.RL.BaseEnvironment import Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade


@@ -45,8 +45,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Union[SubprocVecEnv, gym.Env] = None
-        self.eval_env: Union[SubprocVecEnv, gym.Env] = None
+        self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
+        self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
         self.eval_callback: Optional[EvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
@@ -66,6 +66,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.unset_outlier_removal()
         self.net_arch = self.rl_config.get('net_arch', [128, 128])
         self.dd.model_type = import_str
+        self.tensorboard_callback: TensorboardCallback = \
+            TensorboardCallback(verbose=1, actions=BaseActions)

     def unset_outlier_removal(self):
         """
@@ -157,7 +159,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))

-        self.tensorboard_callback = TensorboardCallback()
+        actions = self.train_env.get_actions()
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

     @abstractmethod
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
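The callback built here only takes effect once it is handed to the model's `learn()` call, and that wiring lives in the concrete learners' `fit()` implementations rather than in this hunk. A hedged, simplified sketch of that side (the real `fit()` also handles `policy_kwargs`, `net_arch`, and model reuse; the log path shown is an assumption):

    from stable_baselines3 import PPO

    # Sketch of a ReinforcementLearner-style fit() body, simplified.
    def fit(self, data_dictionary, dk, **kwargs):
        train_df = data_dictionary["train_features"]
        total_timesteps = int(self.freqai_info["rl_config"]["train_cycles"] * len(train_df))
        model = PPO("MlpPolicy", self.train_env,
                    tensorboard_log=str(dk.data_path / "tensorboard"))
        model.learn(total_timesteps=total_timesteps,
                    callback=[self.eval_callback, self.tensorboard_callback])
        return model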
@@ -401,51 +404,3 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
         return env
     set_random_seed(seed)
     return _init
-
-
-class TensorboardCallback(BaseCallback):
-    """
-    Custom callback for plotting additional values in tensorboard.
-    """
-    def __init__(self, verbose=1):
-        super(TensorboardCallback, self).__init__(verbose)
-
-    def _on_training_start(self) -> None:
-        hparam_dict = {
-            "algorithm": self.model.__class__.__name__,
-            "learning_rate": self.model.learning_rate,
-            "gamma": self.model.gamma,
-            "gae_lambda": self.model.gae_lambda,
-            "batch_size": self.model.batch_size,
-            "n_steps": self.model.n_steps,
-        }
-        metric_dict = {
-            "eval/mean_reward": 0,
-            "rollout/ep_rew_mean": 0,
-            "rollout/ep_len_mean": 0,
-            "train/value_loss": 0,
-            "train/explained_variance": 0,
-        }
-        self.logger.record(
-            "hparams",
-            HParam(hparam_dict, metric_dict),
-            exclude=("stdout", "log", "json", "csv"),
-        )
-
-    def _on_step(self) -> bool:
-        custom_info = self.training_env.get_attr("custom_info")[0]
-        self.logger.record("_state/position", self.locals["infos"][0]["position"])
-        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
-        self.logger.record("_state/current_profit_pct", self.locals["infos"]
-                           [0]["current_profit_pct"])
-        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
-        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
-        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
-                                [0]["trade_duration"])
-        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
-        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
-        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
-        self.logger.record("_actions/Hold", custom_info["Hold"])
-        for action in Actions:
-            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
-        return True
freqtrade/freqai/RL/TensorboardCallback.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+from enum import Enum
+from typing import Any, Dict, Type, Union
+
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.logger import HParam
+
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions
+
+
+class TensorboardCallback(BaseCallback):
+    """
+    Custom callback for plotting additional values in tensorboard and
+    episodic summary reports.
+    """
+    def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
+        super(TensorboardCallback, self).__init__(verbose)
+        self.model: Any = None
+        # An alias for self.model.get_env(), the environment used for training
+        self.logger = None  # type: Any
+        # self.training_env = None  # type: Union[gym.Env, VecEnv]
+        self.actions: Type[Enum] = actions
+
+    def _on_training_start(self) -> None:
+        hparam_dict = {
+            "algorithm": self.model.__class__.__name__,
+            "learning_rate": self.model.learning_rate,
+            # "gamma": self.model.gamma,
+            # "gae_lambda": self.model.gae_lambda,
+            # "batch_size": self.model.batch_size,
+            # "n_steps": self.model.n_steps,
+        }
+        metric_dict: Dict[str, Union[float, int]] = {
+            "eval/mean_reward": 0,
+            "rollout/ep_rew_mean": 0,
+            "rollout/ep_len_mean": 0,
+            "train/value_loss": 0,
+            "train/explained_variance": 0,
+        }
+        self.logger.record(
+            "hparams",
+            HParam(hparam_dict, metric_dict),
+            exclude=("stdout", "log", "json", "csv"),
+        )
+
+    def _on_step(self) -> bool:
+        custom_info = self.training_env.get_attr("custom_info")[0]  # type: ignore
+        self.logger.record("_state/position", self.locals["infos"][0]["position"])
+        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
+        self.logger.record("_state/current_profit_pct", self.locals["infos"]
+                           [0]["current_profit_pct"])
+        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
+        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
+        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
+                                [0]["trade_duration"])
+        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
+        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
+        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
+        self.logger.record("_actions/Hold", custom_info["Hold"])
+        for action in self.actions:
+            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
+        return True
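Outside of freqtrade, the relocated callback can be exercised against any stable-baselines3 model whose environment maintains a `custom_info` dict and returns the info keys logged in `_on_step`. The snippet below is an illustrative usage sketch; `my_env` is a placeholder for an already-configured freqtrade RL environment and is not part of the commit:

    from stable_baselines3 import PPO
    from stable_baselines3.common.monitor import Monitor

    from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback

    # my_env must populate self.custom_info and return the info keys used in _on_step().
    train_env = Monitor(my_env)
    callback = TensorboardCallback(verbose=1, actions=my_env.get_actions())

    model = PPO("MlpPolicy", train_env, tensorboard_log="./tb_logs")
    model.learn(total_timesteps=10_000, callback=callback)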
freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -88,33 +88,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         User can override any function in BaseRLEnv and gym.Env. Here the user
         sets a custom reward based on profit and trade duration.
         """
-        def reset(self):
-
-            # Reset custom info
-            self.custom_info = {}
-            self.custom_info["Invalid"] = 0
-            self.custom_info["Hold"] = 0
-            self.custom_info["Unknown"] = 0
-            self.custom_info["pnl_factor"] = 0
-            self.custom_info["duration_factor"] = 0
-            self.custom_info["reward_exit"] = 0
-            self.custom_info["reward_hold"] = 0
-            for action in Actions:
-                self.custom_info[f"{action.name}"] = 0
-            return super().reset()
-
-        def step(self, action: int):
-            observation, step_reward, done, info = super().step(action)
-            info = dict(
-                tick=self._current_tick,
-                action=action,
-                total_reward=self.total_reward,
-                total_profit=self._total_profit,
-                position=self._position.value,
-                trade_duration=self.get_trade_duration(),
-                current_profit_pct=self.get_unrealized_profit()
-            )
-            return observation, step_reward, done, info
-
         def calculate_reward(self, action: int) -> float:
             """
freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -1,14 +1,14 @@
 import logging
-from typing import Any, Dict  # , Tuple
+from typing import Any, Dict

-# import numpy.typing as npt
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
-from freqtrade.freqai.RL.BaseReinforcementLearningModel import TensorboardCallback, make_env
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback


 logger = logging.getLogger(__name__)
@@ -50,4 +50,5 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))

-        self.tensorboard_callback = TensorboardCallback()
+        actions = self.train_env.env_method("get_actions")[0]
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
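With either learner, the values recorded in `_on_step` show up in TensorBoard grouped under `_state/`, `_reward/`, and `_actions/<action name>`, with one `_actions` counter per member of whatever enum the environment reports via `get_actions()`. Assuming the model was created with stable-baselines3's `tensorboard_log` argument, the run can then be inspected with the standard `tensorboard --logdir <that directory>` invocation.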