Merge branch 'develop' of github.com:lolongcovas/freqtrade into strategies
This commit is contained in:
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
125
freqtrade/freqai/RL/Base3ActionRLEnv.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
Neutral = 0
|
||||
Buy = 1
|
||||
Sell = 2
|
||||
|
||||
|
||||
class Base3ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 3 action environment
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.actions = Actions
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
if self._current_tick == self._end_tick:
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_log(self.actions._member_names_[action])
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
if action == Actions.Buy.value:
|
||||
if self._position == Positions.Short:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Long
|
||||
trade_type = "long"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Sell.value and self.can_short:
|
||||
if self._position == Positions.Long:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Short
|
||||
trade_type = "short"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Sell.value and not self.can_short:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
else:
|
||||
print("case not defined")
|
||||
|
||||
if trade_type is not None:
|
||||
self.trade_history.append(
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if (self._total_profit < self.max_drawdown or
|
||||
self._total_unrealized_profit < self.max_drawdown):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
action=action,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value,
|
||||
trade_duration=self.get_trade_duration(),
|
||||
current_profit_pct=self.get_unrealized_profit()
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Buy while it is in a Positions.short
|
||||
"""
|
||||
return (
|
||||
(action == Actions.Buy.value and self._position == Positions.Neutral)
|
||||
or (action == Actions.Sell.value and self._position == Positions.Long)
|
||||
or (action == Actions.Sell.value and self._position == Positions.Neutral
|
||||
and self.can_short)
|
||||
or (action == Actions.Buy.value and self._position == Positions.Short
|
||||
and self.can_short)
|
||||
)
|
||||
|
||||
def _is_valid(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Sell while it is in a Positions.Long
|
||||
"""
|
||||
if self.can_short:
|
||||
return action in [Actions.Buy.value, Actions.Sell.value, Actions.Neutral.value]
|
||||
else:
|
||||
if action == Actions.Sell.value and self._position != Positions.Long:
|
||||
return False
|
||||
return True
|
@@ -20,6 +20,9 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 4 action environment
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.actions = Actions
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
@@ -43,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_log(self.actions._member_names_[action])
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
@@ -85,16 +88,20 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if self._total_profit < 1 - self.rl_config.get('max_training_drawdown_pct', 0.8):
|
||||
if (self._total_profit < self.max_drawdown or
|
||||
self._total_unrealized_profit < self.max_drawdown):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
action=action,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
position=self._position.value,
|
||||
trade_duration=self.get_trade_duration(),
|
||||
current_profit_pct=self.get_unrealized_profit()
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
@@ -21,6 +21,9 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 5 action environment
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.actions = Actions
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
@@ -46,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
self._update_unrealized_total_profit()
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_log(self.actions._member_names_[action])
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
@@ -98,9 +102,12 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
action=action,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
position=self._position.value,
|
||||
trade_duration=self.get_trade_duration(),
|
||||
current_profit_pct=self.get_unrealized_profit()
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
@@ -2,7 +2,7 @@ import logging
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -11,12 +11,21 @@ from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseActions(Enum):
|
||||
"""
|
||||
Default action space, mostly used for type handling.
|
||||
"""
|
||||
Neutral = 0
|
||||
Long_enter = 1
|
||||
Long_exit = 2
|
||||
Short_enter = 3
|
||||
Short_exit = 4
|
||||
|
||||
|
||||
class Positions(Enum):
|
||||
Short = 0
|
||||
Long = 1
|
||||
@@ -35,8 +44,8 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {},
|
||||
dp: Optional[DataProvider] = None):
|
||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||
fee: float = 0.0015, can_short: bool = False):
|
||||
"""
|
||||
Initializes the training/eval environment.
|
||||
:param df: dataframe of features
|
||||
@@ -47,22 +56,31 @@ class BaseEnvironment(gym.Env):
|
||||
:param id: string id of the environment (used in backend for multiprocessed env)
|
||||
:param seed: Sets the seed of the environment higher in the gym.Env object
|
||||
:param config: Typical user configuration file
|
||||
:param dp: dataprovider from freqtrade
|
||||
:param live: Whether or not this environment is active in dry/live/backtesting
|
||||
:param fee: The fee to use for environmental interactions.
|
||||
:param can_short: Whether or not the environment can short
|
||||
"""
|
||||
self.config = config
|
||||
self.rl_config = config['freqai']['rl_config']
|
||||
self.add_state_info = self.rl_config.get('add_state_info', False)
|
||||
self.id = id
|
||||
self.seed(seed)
|
||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||
self.compound_trades = config['stake_amount'] == 'unlimited'
|
||||
if self.config.get('fee', None) is not None:
|
||||
self.fee = self.config['fee']
|
||||
elif dp is not None:
|
||||
self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0]) # type: ignore
|
||||
else:
|
||||
self.fee = 0.0015
|
||||
self.fee = fee
|
||||
|
||||
# set here to default 5Ac, but all children envs can override this
|
||||
self.actions: Type[Enum] = BaseActions
|
||||
self.tensorboard_metrics: dict = {}
|
||||
self.can_short = can_short
|
||||
self.live = live
|
||||
if not self.live and self.add_state_info:
|
||||
self.add_state_info = False
|
||||
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
||||
self.seed(seed)
|
||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||
|
||||
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
|
||||
reward_kwargs: dict, starting_point=True):
|
||||
@@ -117,7 +135,38 @@ class BaseEnvironment(gym.Env):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
|
||||
"""
|
||||
Function builds the tensorboard_metrics dictionary
|
||||
to be parsed by the TensorboardCallback. This
|
||||
function is designed for tracking incremented objects,
|
||||
events, actions inside the training environment.
|
||||
For example, a user can call this to track the
|
||||
frequency of occurence of an `is_valid` call in
|
||||
their `calculate_reward()`:
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
if not self._is_valid(action):
|
||||
self.tensorboard_log("is_valid")
|
||||
return -2
|
||||
|
||||
:param metric: metric to be tracked and incremented
|
||||
:param value: value to increment `metric` by
|
||||
:param inc: sets whether the `value` is incremented or not
|
||||
"""
|
||||
if not inc or metric not in self.tensorboard_metrics:
|
||||
self.tensorboard_metrics[metric] = value
|
||||
else:
|
||||
self.tensorboard_metrics[metric] += value
|
||||
|
||||
def reset_tensorboard_log(self):
|
||||
self.tensorboard_metrics = {}
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset is called at the beginning of every episode
|
||||
"""
|
||||
self.reset_tensorboard_log()
|
||||
|
||||
self._done = False
|
||||
|
||||
@@ -271,6 +320,13 @@ class BaseEnvironment(gym.Env):
|
||||
def current_price(self) -> float:
|
||||
return self.prices.iloc[self._current_tick].open
|
||||
|
||||
def get_actions(self) -> Type[Enum]:
|
||||
"""
|
||||
Used by SubprocVecEnv to get actions from
|
||||
initialized env for tensorboard callback
|
||||
"""
|
||||
return self.actions
|
||||
|
||||
# Keeping around incase we want to start building more complex environment
|
||||
# templates in the future.
|
||||
# def most_recent_return(self):
|
||||
|
@@ -21,7 +21,8 @@ from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import Positions
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
|
||||
@@ -44,8 +45,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
||||
th.set_num_threads(self.max_threads)
|
||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
self.train_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
||||
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
||||
self.eval_callback: Optional[EvalCallback] = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
@@ -65,6 +66,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
self.unset_outlier_removal()
|
||||
self.net_arch = self.rl_config.get('net_arch', [128, 128])
|
||||
self.dd.model_type = import_str
|
||||
self.tensorboard_callback: TensorboardCallback = \
|
||||
TensorboardCallback(verbose=1, actions=BaseActions)
|
||||
|
||||
def unset_outlier_removal(self):
|
||||
"""
|
||||
@@ -140,22 +143,36 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_info = self.pack_env_dict()
|
||||
|
||||
self.train_env = self.MyRLEnv(df=train_df,
|
||||
prices=prices_train,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider)
|
||||
**env_info)
|
||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
|
||||
prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider))
|
||||
**env_info))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
actions = self.train_env.get_actions()
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
||||
def pack_env_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Create dictionary of environment arguments
|
||||
"""
|
||||
env_info = {"window_size": self.CONV_WIDTH,
|
||||
"reward_kwargs": self.reward_params,
|
||||
"config": self.config,
|
||||
"live": self.live,
|
||||
"can_short": self.can_short}
|
||||
if self.data_provider:
|
||||
env_info["fee"] = self.data_provider._exchange \
|
||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||
|
||||
return env_info
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
"""
|
||||
@@ -263,26 +280,36 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
# %-raw_volume_gen_shift-2_ETH/USDT_1h
|
||||
# price data for model training and evaluation
|
||||
tf = self.config['timeframe']
|
||||
ohlc_list = [f'%-{pair}raw_open_{tf}', f'%-{pair}raw_low_{tf}',
|
||||
f'%-{pair}raw_high_{tf}', f'%-{pair}raw_close_{tf}']
|
||||
rename_dict = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
|
||||
f'%-{pair}raw_high_{tf}': ' high', f'%-{pair}raw_close_{tf}': 'close'}
|
||||
rename_dict = {'%-raw_open': 'open', '%-raw_low': 'low',
|
||||
'%-raw_high': ' high', '%-raw_close': 'close'}
|
||||
rename_dict_old = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
|
||||
f'%-{pair}raw_high_{tf}': ' high', f'%-{pair}raw_close_{tf}': 'close'}
|
||||
|
||||
prices_train = train_df.filter(rename_dict.keys(), axis=1)
|
||||
prices_train_old = train_df.filter(rename_dict_old.keys(), axis=1)
|
||||
if prices_train.empty or not prices_train_old.empty:
|
||||
if not prices_train_old.empty:
|
||||
prices_train = prices_train_old
|
||||
rename_dict = rename_dict_old
|
||||
logger.warning('Reinforcement learning module didnt find the correct raw prices '
|
||||
'assigned in feature_engineering_standard(). '
|
||||
'Please assign them with:\n'
|
||||
'dataframe["%-raw_close"] = dataframe["close"]\n'
|
||||
'dataframe["%-raw_open"] = dataframe["open"]\n'
|
||||
'dataframe["%-raw_high"] = dataframe["high"]\n'
|
||||
'dataframe["%-raw_low"] = dataframe["low"]\n'
|
||||
'inside `feature_engineering_standard()')
|
||||
elif prices_train.empty:
|
||||
raise OperationalException("No prices found, please follow log warning "
|
||||
"instructions to correct the strategy.")
|
||||
|
||||
prices_train = train_df.filter(ohlc_list, axis=1)
|
||||
if prices_train.empty:
|
||||
raise OperationalException('Reinforcement learning module didnt find the raw prices '
|
||||
'assigned in populate_any_indicators. Please assign them '
|
||||
'with:\n'
|
||||
'informative[f"%-{pair}raw_close"] = informative["close"]\n'
|
||||
'informative[f"%-{pair}raw_open"] = informative["open"]\n'
|
||||
'informative[f"%-{pair}raw_high"] = informative["high"]\n'
|
||||
'informative[f"%-{pair}raw_low"] = informative["low"]\n')
|
||||
prices_train.rename(columns=rename_dict, inplace=True)
|
||||
prices_train.reset_index(drop=True)
|
||||
|
||||
prices_test = test_df.filter(ohlc_list, axis=1)
|
||||
prices_test = test_df.filter(rename_dict.keys(), axis=1)
|
||||
prices_test.rename(columns=rename_dict, inplace=True)
|
||||
prices_test.reset_index(drop=True)
|
||||
|
||||
@@ -377,8 +404,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
monitor: bool = False,
|
||||
env_info: Dict[str, Any] = {}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
@@ -386,13 +413,14 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
:param num_env: (int) the number of environment you wish to have in subprocesses
|
||||
:param seed: (int) the inital seed for RNG
|
||||
:param rank: (int) index of the subprocess
|
||||
:param env_info: (dict) all required arguments to instantiate the environment.
|
||||
:return: (Callable)
|
||||
"""
|
||||
|
||||
def _init() -> gym.Env:
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
||||
env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
|
||||
**env_info)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
return env
|
||||
|
59
freqtrade/freqai/RL/TensorboardCallback.py
Normal file
59
freqtrade/freqai/RL/TensorboardCallback.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Type, Union
|
||||
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.logger import HParam
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
||||
|
||||
|
||||
class TensorboardCallback(BaseCallback):
|
||||
"""
|
||||
Custom callback for plotting additional values in tensorboard and
|
||||
episodic summary reports.
|
||||
"""
|
||||
def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
|
||||
super(TensorboardCallback, self).__init__(verbose)
|
||||
self.model: Any = None
|
||||
self.logger = None # type: Any
|
||||
self.training_env: BaseEnvironment = None # type: ignore
|
||||
self.actions: Type[Enum] = actions
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
hparam_dict = {
|
||||
"algorithm": self.model.__class__.__name__,
|
||||
"learning_rate": self.model.learning_rate,
|
||||
# "gamma": self.model.gamma,
|
||||
# "gae_lambda": self.model.gae_lambda,
|
||||
# "batch_size": self.model.batch_size,
|
||||
# "n_steps": self.model.n_steps,
|
||||
}
|
||||
metric_dict: Dict[str, Union[float, int]] = {
|
||||
"eval/mean_reward": 0,
|
||||
"rollout/ep_rew_mean": 0,
|
||||
"rollout/ep_len_mean": 0,
|
||||
"train/value_loss": 0,
|
||||
"train/explained_variance": 0,
|
||||
}
|
||||
self.logger.record(
|
||||
"hparams",
|
||||
HParam(hparam_dict, metric_dict),
|
||||
exclude=("stdout", "log", "json", "csv"),
|
||||
)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
||||
local_info = self.locals["infos"][0]
|
||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||
|
||||
for info in local_info:
|
||||
if info not in ["episode", "terminal_observation"]:
|
||||
self.logger.record(f"_info/{info}", local_info[info])
|
||||
|
||||
for info in tensorboard_metrics:
|
||||
if info in [action.name for action in self.actions]:
|
||||
self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
|
||||
else:
|
||||
self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
|
||||
|
||||
return True
|
@@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel):
|
||||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_)))
|
||||
pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)
|
||||
|
||||
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
|
||||
|
@@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel):
|
||||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
pred_df = dk.denormalize_labels_from_metadata(pred_df)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
@@ -23,6 +24,7 @@ from freqtrade.constants import Config
|
||||
from freqtrade.data.converter import reduce_dataframe_footprint
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.exchange import timeframe_to_seconds
|
||||
from freqtrade.strategy import merge_informative_pair
|
||||
from freqtrade.strategy.interface import IStrategy
|
||||
|
||||
|
||||
@@ -1159,9 +1161,9 @@ class FreqaiDataKitchen:
|
||||
|
||||
for pair in pairs:
|
||||
pair = pair.replace(':', '') # lightgbm doesnt like colons
|
||||
valid_strs = [f"%-{pair}", f"%{pair}", f"%_{pair}"]
|
||||
pair_cols = [col for col in dataframe.columns if
|
||||
any(substr in col for substr in valid_strs)]
|
||||
pair_cols = [col for col in dataframe.columns if col.startswith("%")
|
||||
and f"{pair}_" in col]
|
||||
|
||||
if pair_cols:
|
||||
pair_cols.insert(0, 'date')
|
||||
corr_dataframes[pair] = dataframe.filter(pair_cols, axis=1)
|
||||
@@ -1190,6 +1192,103 @@ class FreqaiDataKitchen:
|
||||
|
||||
return dataframe
|
||||
|
||||
def get_pair_data_for_features(self,
|
||||
pair: str,
|
||||
tf: str,
|
||||
strategy: IStrategy,
|
||||
corr_dataframes: dict = {},
|
||||
base_dataframes: dict = {},
|
||||
is_corr_pairs: bool = False) -> DataFrame:
|
||||
"""
|
||||
Get the data for the pair. If it's not in the dictionary, get it from the data provider
|
||||
:param pair: str = pair to get data for
|
||||
:param tf: str = timeframe to get data for
|
||||
:param strategy: IStrategy = user defined strategy object
|
||||
:param corr_dataframes: dict = dict containing the df pair dataframes
|
||||
(for user defined timeframes)
|
||||
:param base_dataframes: dict = dict containing the current pair dataframes
|
||||
(for user defined timeframes)
|
||||
:param is_corr_pairs: bool = whether the pair is a corr pair or not
|
||||
:return: dataframe = dataframe containing the pair data
|
||||
"""
|
||||
if is_corr_pairs:
|
||||
dataframe = corr_dataframes[pair][tf]
|
||||
if not dataframe.empty:
|
||||
return dataframe
|
||||
else:
|
||||
dataframe = strategy.dp.get_pair_dataframe(pair=pair, timeframe=tf)
|
||||
return dataframe
|
||||
else:
|
||||
dataframe = base_dataframes[tf]
|
||||
if not dataframe.empty:
|
||||
return dataframe
|
||||
else:
|
||||
dataframe = strategy.dp.get_pair_dataframe(pair=pair, timeframe=tf)
|
||||
return dataframe
|
||||
|
||||
def merge_features(self, df_main: DataFrame, df_to_merge: DataFrame,
|
||||
tf: str, timeframe_inf: str, suffix: str) -> DataFrame:
|
||||
"""
|
||||
Merge the features of the dataframe and remove HLCV and date added columns
|
||||
:param df_main: DataFrame = main dataframe
|
||||
:param df_to_merge: DataFrame = dataframe to merge
|
||||
:param tf: str = timeframe of the main dataframe
|
||||
:param timeframe_inf: str = timeframe of the dataframe to merge
|
||||
:param suffix: str = suffix to add to the columns of the dataframe to merge
|
||||
:return: dataframe = merged dataframe
|
||||
"""
|
||||
dataframe = merge_informative_pair(df_main, df_to_merge, tf, timeframe_inf=timeframe_inf,
|
||||
append_timeframe=False, suffix=suffix, ffill=True)
|
||||
skip_columns = [
|
||||
(f"{s}_{suffix}") for s in ["date", "open", "high", "low", "close", "volume"]
|
||||
]
|
||||
dataframe = dataframe.drop(columns=skip_columns)
|
||||
return dataframe
|
||||
|
||||
def populate_features(self, dataframe: DataFrame, pair: str, strategy: IStrategy,
|
||||
corr_dataframes: dict, base_dataframes: dict,
|
||||
is_corr_pairs: bool = False) -> DataFrame:
|
||||
"""
|
||||
Use the user defined strategy functions for populating features
|
||||
:param dataframe: DataFrame = dataframe to populate
|
||||
:param pair: str = pair to populate
|
||||
:param strategy: IStrategy = user defined strategy object
|
||||
:param corr_dataframes: dict = dict containing the df pair dataframes
|
||||
:param base_dataframes: dict = dict containing the current pair dataframes
|
||||
:param is_corr_pairs: bool = whether the pair is a corr pair or not
|
||||
:return: dataframe = populated dataframe
|
||||
"""
|
||||
tfs: List[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
||||
|
||||
for tf in tfs:
|
||||
informative_df = self.get_pair_data_for_features(
|
||||
pair, tf, strategy, corr_dataframes, base_dataframes, is_corr_pairs)
|
||||
informative_copy = informative_df.copy()
|
||||
|
||||
for t in self.freqai_config["feature_parameters"]["indicator_periods_candles"]:
|
||||
df_features = strategy.feature_engineering_expand_all(
|
||||
informative_copy.copy(), t)
|
||||
suffix = f"{t}"
|
||||
informative_df = self.merge_features(informative_df, df_features, tf, tf, suffix)
|
||||
|
||||
generic_df = strategy.feature_engineering_expand_basic(informative_copy.copy())
|
||||
suffix = "gen"
|
||||
|
||||
informative_df = self.merge_features(informative_df, generic_df, tf, tf, suffix)
|
||||
|
||||
indicators = [col for col in informative_df if col.startswith("%")]
|
||||
for n in range(self.freqai_config["feature_parameters"]["include_shifted_candles"] + 1):
|
||||
if n == 0:
|
||||
continue
|
||||
df_shift = informative_df[indicators].shift(n)
|
||||
df_shift = df_shift.add_suffix("_shift-" + str(n))
|
||||
informative_df = pd.concat((informative_df, df_shift), axis=1)
|
||||
|
||||
dataframe = self.merge_features(dataframe.copy(), informative_df,
|
||||
self.config["timeframe"], tf, f'{pair}_{tf}')
|
||||
|
||||
return dataframe
|
||||
|
||||
def use_strategy_to_populate_indicators(
|
||||
self,
|
||||
strategy: IStrategy,
|
||||
@@ -1202,7 +1301,87 @@ class FreqaiDataKitchen:
|
||||
"""
|
||||
Use the user defined strategy for populating indicators during retrain
|
||||
:param strategy: IStrategy = user defined strategy object
|
||||
:param corr_dataframes: dict = dict containing the informative pair dataframes
|
||||
:param corr_dataframes: dict = dict containing the df pair dataframes
|
||||
(for user defined timeframes)
|
||||
:param base_dataframes: dict = dict containing the current pair dataframes
|
||||
(for user defined timeframes)
|
||||
:param pair: str = pair to populate
|
||||
:param prediction_dataframe: DataFrame = dataframe containing the pair data
|
||||
used for prediction
|
||||
:param do_corr_pairs: bool = whether to populate corr pairs or not
|
||||
:return:
|
||||
dataframe: DataFrame = dataframe containing populated indicators
|
||||
"""
|
||||
|
||||
# this is a hack to check if the user is using the populate_any_indicators function
|
||||
new_version = inspect.getsource(strategy.populate_any_indicators) == (
|
||||
inspect.getsource(IStrategy.populate_any_indicators))
|
||||
|
||||
if new_version:
|
||||
tfs: List[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
||||
pairs: List[str] = self.freqai_config["feature_parameters"].get(
|
||||
"include_corr_pairlist", [])
|
||||
|
||||
for tf in tfs:
|
||||
if tf not in base_dataframes:
|
||||
base_dataframes[tf] = pd.DataFrame()
|
||||
for p in pairs:
|
||||
if p not in corr_dataframes:
|
||||
corr_dataframes[p] = {}
|
||||
if tf not in corr_dataframes[p]:
|
||||
corr_dataframes[p][tf] = pd.DataFrame()
|
||||
|
||||
if not prediction_dataframe.empty:
|
||||
dataframe = prediction_dataframe.copy()
|
||||
else:
|
||||
dataframe = base_dataframes[self.config["timeframe"]].copy()
|
||||
|
||||
corr_pairs: List[str] = self.freqai_config["feature_parameters"].get(
|
||||
"include_corr_pairlist", [])
|
||||
dataframe = self.populate_features(dataframe.copy(), pair, strategy,
|
||||
corr_dataframes, base_dataframes)
|
||||
|
||||
dataframe = strategy.feature_engineering_standard(dataframe.copy())
|
||||
# ensure corr pairs are always last
|
||||
for corr_pair in corr_pairs:
|
||||
if pair == corr_pair:
|
||||
continue # dont repeat anything from whitelist
|
||||
if corr_pairs and do_corr_pairs:
|
||||
dataframe = self.populate_features(dataframe.copy(), corr_pair, strategy,
|
||||
corr_dataframes, base_dataframes, True)
|
||||
|
||||
dataframe = strategy.set_freqai_targets(dataframe.copy())
|
||||
|
||||
self.get_unique_classes_from_labels(dataframe)
|
||||
|
||||
dataframe = self.remove_special_chars_from_feature_names(dataframe)
|
||||
|
||||
if self.config.get('reduce_df_footprint', False):
|
||||
dataframe = reduce_dataframe_footprint(dataframe)
|
||||
|
||||
return dataframe
|
||||
|
||||
else:
|
||||
# the user is using the populate_any_indicators functions which is deprecated
|
||||
|
||||
df = self.use_strategy_to_populate_indicators_old_version(
|
||||
strategy, corr_dataframes, base_dataframes, pair,
|
||||
prediction_dataframe, do_corr_pairs)
|
||||
return df
|
||||
|
||||
def use_strategy_to_populate_indicators_old_version(
|
||||
self,
|
||||
strategy: IStrategy,
|
||||
corr_dataframes: dict = {},
|
||||
base_dataframes: dict = {},
|
||||
pair: str = "",
|
||||
prediction_dataframe: DataFrame = pd.DataFrame(),
|
||||
do_corr_pairs: bool = True,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Use the user defined strategy for populating indicators during retrain
|
||||
:param strategy: IStrategy = user defined strategy object
|
||||
:param corr_dataframes: dict = dict containing the df pair dataframes
|
||||
(for user defined timeframes)
|
||||
:param base_dataframes: dict = dict containing the current pair dataframes
|
||||
(for user defined timeframes)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -104,6 +105,9 @@ class IFreqaiModel(ABC):
|
||||
self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
||||
self.data_provider: Optional[DataProvider] = None
|
||||
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||
self.can_short = True # overridden in start() with strategy.can_short
|
||||
|
||||
self.warned_deprecated_populate_any_indicators = False
|
||||
|
||||
record_params(config, self.full_path)
|
||||
|
||||
@@ -133,6 +137,10 @@ class IFreqaiModel(ABC):
|
||||
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||
self.dd.set_pair_dict_info(metadata)
|
||||
self.data_provider = strategy.dp
|
||||
self.can_short = strategy.can_short
|
||||
|
||||
# check if the strategy has deprecated populate_any_indicators function
|
||||
self.check_deprecated_populate_any_indicators(strategy)
|
||||
|
||||
if self.live:
|
||||
self.inference_timer('start')
|
||||
@@ -147,12 +155,9 @@ class IFreqaiModel(ABC):
|
||||
# the concatenated results for the full backtesting period back to the strategy.
|
||||
elif not self.follow_mode:
|
||||
self.dk = FreqaiDataKitchen(self.config, self.live, metadata["pair"])
|
||||
dataframe = self.dk.use_strategy_to_populate_indicators(
|
||||
strategy, prediction_dataframe=dataframe, pair=metadata["pair"]
|
||||
)
|
||||
if not self.config.get("freqai_backtest_live_models", False):
|
||||
logger.info(f"Training {len(self.dk.training_timeranges)} timeranges")
|
||||
dk = self.start_backtesting(dataframe, metadata, self.dk)
|
||||
dk = self.start_backtesting(dataframe, metadata, self.dk, strategy)
|
||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -253,7 +258,7 @@ class IFreqaiModel(ABC):
|
||||
self.dd.save_metric_tracker_to_disk()
|
||||
|
||||
def start_backtesting(
|
||||
self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen
|
||||
self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen, strategy: IStrategy
|
||||
) -> FreqaiDataKitchen:
|
||||
"""
|
||||
The main broad execution for backtesting. For backtesting, each pair enters and then gets
|
||||
@@ -265,19 +270,22 @@ class IFreqaiModel(ABC):
|
||||
:param dataframe: DataFrame = strategy passed dataframe
|
||||
:param metadata: Dict = pair metadata
|
||||
:param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
|
||||
:param strategy: Strategy to train on
|
||||
:return:
|
||||
FreqaiDataKitchen = Data management/analysis tool associated to present pair only
|
||||
"""
|
||||
|
||||
self.pair_it += 1
|
||||
train_it = 0
|
||||
pair = metadata["pair"]
|
||||
populate_indicators = True
|
||||
check_features = True
|
||||
# Loop enforcing the sliding window training/backtesting paradigm
|
||||
# tr_train is the training time range e.g. 1 historical month
|
||||
# tr_backtest is the backtesting time range e.g. the week directly
|
||||
# following tr_train. Both of these windows slide through the
|
||||
# entire backtest
|
||||
for tr_train, tr_backtest in zip(dk.training_timeranges, dk.backtesting_timeranges):
|
||||
pair = metadata["pair"]
|
||||
(_, _, _) = self.dd.get_pair_dict_info(pair)
|
||||
train_it += 1
|
||||
total_trains = len(dk.backtesting_timeranges)
|
||||
@@ -299,18 +307,42 @@ class IFreqaiModel(ABC):
|
||||
dk.set_new_model_names(pair, timestamp_model_id)
|
||||
|
||||
if dk.check_if_backtest_prediction_is_valid(len_backtest_df):
|
||||
self.dd.load_metadata(dk)
|
||||
dk.find_features(dataframe)
|
||||
self.check_if_feature_list_matches_strategy(dk)
|
||||
if check_features:
|
||||
self.dd.load_metadata(dk)
|
||||
dataframe_dummy_features = self.dk.use_strategy_to_populate_indicators(
|
||||
strategy, prediction_dataframe=dataframe.tail(1), pair=metadata["pair"]
|
||||
)
|
||||
dk.find_features(dataframe_dummy_features)
|
||||
self.check_if_feature_list_matches_strategy(dk)
|
||||
check_features = False
|
||||
append_df = dk.get_backtesting_prediction()
|
||||
dk.append_predictions(append_df)
|
||||
else:
|
||||
dataframe_train = dk.slice_dataframe(tr_train, dataframe)
|
||||
dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe)
|
||||
if populate_indicators:
|
||||
dataframe = self.dk.use_strategy_to_populate_indicators(
|
||||
strategy, prediction_dataframe=dataframe, pair=metadata["pair"]
|
||||
)
|
||||
populate_indicators = False
|
||||
|
||||
dataframe_base_train = dataframe.loc[dataframe["date"] < tr_train.stopdt, :]
|
||||
dataframe_base_train = strategy.set_freqai_targets(dataframe_base_train)
|
||||
dataframe_base_backtest = dataframe.loc[dataframe["date"] < tr_backtest.stopdt, :]
|
||||
dataframe_base_backtest = strategy.set_freqai_targets(dataframe_base_backtest)
|
||||
|
||||
dataframe_train = dk.slice_dataframe(tr_train, dataframe_base_train)
|
||||
dataframe_backtest = dk.slice_dataframe(tr_backtest, dataframe_base_backtest)
|
||||
|
||||
if not self.model_exists(dk):
|
||||
dk.find_features(dataframe_train)
|
||||
dk.find_labels(dataframe_train)
|
||||
self.model = self.train(dataframe_train, pair, dk)
|
||||
|
||||
try:
|
||||
self.model = self.train(dataframe_train, pair, dk)
|
||||
except Exception as msg:
|
||||
logger.warning(
|
||||
f"Training {pair} raised exception {msg.__class__.__name__}. "
|
||||
f"Message: {msg}, skipping.")
|
||||
|
||||
self.dd.pair_dict[pair]["trained_timestamp"] = int(
|
||||
tr_train.stopts)
|
||||
if self.plot_features:
|
||||
@@ -347,7 +379,6 @@ class IFreqaiModel(ABC):
|
||||
:returns:
|
||||
dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
|
||||
"""
|
||||
|
||||
# update follower
|
||||
if self.follow_mode:
|
||||
self.dd.update_follower_metadata()
|
||||
@@ -911,9 +942,28 @@ class IFreqaiModel(ABC):
|
||||
dk.return_dataframe = dk.return_dataframe.drop(columns=list(columns_to_drop))
|
||||
dk.return_dataframe = pd.merge(
|
||||
dk.return_dataframe, saved_dataframe, how='left', left_on='date', right_on="date_pred")
|
||||
# dk.return_dataframe = dk.return_dataframe[saved_dataframe.columns].fillna(0)
|
||||
return dk
|
||||
|
||||
def check_deprecated_populate_any_indicators(self, strategy: IStrategy):
|
||||
"""
|
||||
Check and warn if the deprecated populate_any_indicators function is used.
|
||||
:param strategy: strategy object
|
||||
"""
|
||||
|
||||
if not self.warned_deprecated_populate_any_indicators:
|
||||
self.warned_deprecated_populate_any_indicators = True
|
||||
old_version = inspect.getsource(strategy.populate_any_indicators) != (
|
||||
inspect.getsource(IStrategy.populate_any_indicators))
|
||||
|
||||
if old_version:
|
||||
logger.warning("DEPRECATION WARNING: "
|
||||
"You are using the deprecated populate_any_indicators function. "
|
||||
"This function will raise an error on March 1 2023. "
|
||||
"Please update your strategy by using "
|
||||
"the new feature_engineering functions. See \n"
|
||||
"https://www.freqtrade.io/en/latest/freqai-feature-engineering/"
|
||||
"for details.")
|
||||
|
||||
# Following methods which are overridden by user made prediction models.
|
||||
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
|
||||
|
||||
|
@@ -61,7 +61,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=Path(
|
||||
dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
|
||||
**self.freqai_info['model_training_parameters']
|
||||
**self.freqai_info.get('model_training_parameters', {})
|
||||
)
|
||||
else:
|
||||
logger.info('Continual training activated - starting training from previously '
|
||||
@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=self.eval_callback
|
||||
callback=[self.eval_callback, self.tensorboard_callback]
|
||||
)
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
@@ -100,13 +100,17 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
self.tensorboard_log("is_valid")
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
if (action == Actions.Long_enter.value
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
if (action == Actions.Short_enter.value
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict # , Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
# import numpy.typing as npt
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
@@ -9,6 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -34,18 +34,24 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_info = self.pack_env_dict()
|
||||
|
||||
env_id = "train_env"
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||
config=self.config) for i
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||
train_df, prices_train,
|
||||
monitor=True,
|
||||
env_info=env_info) for i
|
||||
in range(self.max_threads)])
|
||||
|
||||
eval_env_id = 'eval_env'
|
||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||
test_df, prices_test,
|
||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||
config=self.config) for i
|
||||
monitor=True,
|
||||
env_info=env_info) for i
|
||||
in range(self.max_threads)])
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
actions = self.train_env.env_method("get_actions")[0]
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
Reference in New Issue
Block a user