reduce code for base use-case, ensure multiproc inherits custom env, add ability to limit ram use.
This commit is contained in:
parent
05ccebf9a1
commit
3199eb453b
@ -58,6 +58,7 @@
|
|||||||
"model_save_type": "stable_baselines",
|
"model_save_type": "stable_baselines",
|
||||||
"conv_width": 4,
|
"conv_width": 4,
|
||||||
"purge_old_models": true,
|
"purge_old_models": true,
|
||||||
|
"limit_ram_usage": false,
|
||||||
"train_period_days": 5,
|
"train_period_days": 5,
|
||||||
"backtest_period_days": 2,
|
"backtest_period_days": 2,
|
||||||
"identifier": "unique-id",
|
"identifier": "unique-id",
|
||||||
|
@ -19,6 +19,7 @@ from typing import Callable
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
import gym
|
import gym
|
||||||
|
from pathlib import Path
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||||
@ -110,9 +111,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
|
|
||||||
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
||||||
reward_kwargs=self.reward_params, config=self.config)
|
reward_kwargs=self.reward_params, config=self.config)
|
||||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
|
||||||
window_size=self.CONV_WIDTH,
|
window_size=self.CONV_WIDTH,
|
||||||
reward_kwargs=self.reward_params, config=self.config))
|
reward_kwargs=self.reward_params, config=self.config))
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
@ -126,7 +127,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
go in here. Abstract method, so this function must be overridden by
|
go in here. Abstract method, so this function must be overridden by
|
||||||
user class.
|
user class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_state_info(self, pair: str):
|
def get_state_info(self, pair: str):
|
||||||
@ -232,6 +232,72 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
return prices_train, prices_test
|
return prices_train, prices_test
|
||||||
|
|
||||||
|
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
|
||||||
|
"""
|
||||||
|
Can be used by user if they are trying to limit_ram_usage *and*
|
||||||
|
perform continual learning.
|
||||||
|
For now, this is unused.
|
||||||
|
"""
|
||||||
|
exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
|
||||||
|
if exists:
|
||||||
|
model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||||
|
else:
|
||||||
|
logger.info('No model file on disk to continue learning from.')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Nested class which can be overridden by user to customize further
|
||||||
|
class MyRLEnv(Base5ActionRLEnv):
|
||||||
|
"""
|
||||||
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
|
sets a custom reward based on profit and trade duration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def calculate_reward(self, action):
|
||||||
|
|
||||||
|
# first, penalize if the action is not valid
|
||||||
|
if not self._is_valid(action):
|
||||||
|
return -2
|
||||||
|
|
||||||
|
pnl = self.get_unrealized_profit()
|
||||||
|
rew = np.sign(pnl) * (pnl + 1)
|
||||||
|
factor = 100
|
||||||
|
|
||||||
|
# reward agent for entering trades
|
||||||
|
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||||
|
and self._position == Positions.Neutral:
|
||||||
|
return 25
|
||||||
|
# discourage agent from not entering trades
|
||||||
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||||
|
trade_duration = self._current_tick - self._last_trade_tick
|
||||||
|
|
||||||
|
if trade_duration <= max_trade_duration:
|
||||||
|
factor *= 1.5
|
||||||
|
elif trade_duration > max_trade_duration:
|
||||||
|
factor *= 0.5
|
||||||
|
|
||||||
|
# discourage sitting in position
|
||||||
|
if self._position in (Positions.Short, Positions.Long) and \
|
||||||
|
action == Actions.Neutral.value:
|
||||||
|
return -1 * trade_duration / max_trade_duration
|
||||||
|
|
||||||
|
# close long
|
||||||
|
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
|
# close short
|
||||||
|
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
|
return 0.
|
||||||
|
|
||||||
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
|
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
|
||||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||||
@ -240,7 +306,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
|
def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
|
||||||
|
seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||||
config: Dict[str, Any] = {}) -> Callable:
|
config: Dict[str, Any] = {}) -> Callable:
|
||||||
"""
|
"""
|
||||||
@ -252,6 +319,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: Data
|
|||||||
:param rank: (int) index of the subprocess
|
:param rank: (int) index of the subprocess
|
||||||
:return: (Callable)
|
:return: (Callable)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _init() -> gym.Env:
|
def _init() -> gym.Env:
|
||||||
|
|
||||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||||
@ -261,54 +329,3 @@ def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: Data
|
|||||||
return env
|
return env
|
||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
return _init
|
return _init
|
||||||
|
|
||||||
|
|
||||||
class MyRLEnv(Base5ActionRLEnv):
|
|
||||||
"""
|
|
||||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
|
||||||
sets a custom reward based on profit and trade duration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def calculate_reward(self, action):
|
|
||||||
|
|
||||||
# first, penalize if the action is not valid
|
|
||||||
if not self._is_valid(action):
|
|
||||||
return -2
|
|
||||||
|
|
||||||
pnl = self.get_unrealized_profit()
|
|
||||||
rew = np.sign(pnl) * (pnl + 1)
|
|
||||||
factor = 100
|
|
||||||
|
|
||||||
# reward agent for entering trades
|
|
||||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
|
||||||
and self._position == Positions.Neutral:
|
|
||||||
return 25
|
|
||||||
# discourage agent from not entering trades
|
|
||||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
|
||||||
trade_duration = self._current_tick - self._last_trade_tick
|
|
||||||
|
|
||||||
if trade_duration <= max_trade_duration:
|
|
||||||
factor *= 1.5
|
|
||||||
elif trade_duration > max_trade_duration:
|
|
||||||
factor *= 0.5
|
|
||||||
|
|
||||||
# discourage sitting in position
|
|
||||||
if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
|
|
||||||
return -1 * trade_duration / max_trade_duration
|
|
||||||
|
|
||||||
# close long
|
|
||||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
|
||||||
if pnl > self.profit_aim * self.rr:
|
|
||||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
|
||||||
return float(rew * factor)
|
|
||||||
|
|
||||||
# close short
|
|
||||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
|
||||||
if pnl > self.profit_aim * self.rr:
|
|
||||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
|
||||||
return float(rew * factor)
|
|
||||||
|
|
||||||
return 0.
|
|
||||||
|
@ -90,6 +90,7 @@ class FreqaiDataDrawer:
|
|||||||
self.empty_pair_dict: pair_info = {
|
self.empty_pair_dict: pair_info = {
|
||||||
"model_filename": "", "trained_timestamp": 0,
|
"model_filename": "", "trained_timestamp": 0,
|
||||||
"priority": 1, "first": True, "data_path": "", "extras": {}}
|
"priority": 1, "first": True, "data_path": "", "extras": {}}
|
||||||
|
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
|
||||||
|
|
||||||
def load_drawer_from_disk(self):
|
def load_drawer_from_disk(self):
|
||||||
"""
|
"""
|
||||||
@ -423,8 +424,8 @@ class FreqaiDataDrawer:
|
|||||||
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
|
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
|
||||||
)
|
)
|
||||||
|
|
||||||
# if self.live:
|
if not self.limit_ram_use:
|
||||||
self.model_dictionary[coin] = model
|
self.model_dictionary[coin] = model
|
||||||
self.pair_dict[coin]["model_filename"] = dk.model_filename
|
self.pair_dict[coin]["model_filename"] = dk.model_filename
|
||||||
self.pair_dict[coin]["data_path"] = str(dk.data_path)
|
self.pair_dict[coin]["data_path"] = str(dk.data_path)
|
||||||
self.save_drawer_to_disk()
|
self.save_drawer_to_disk()
|
||||||
@ -464,7 +465,7 @@ class FreqaiDataDrawer:
|
|||||||
|
|
||||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||||
# try to access model in memory instead of loading object from disk to save time
|
# try to access model in memory instead of loading object from disk to save time
|
||||||
if dk.live and coin in self.model_dictionary:
|
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
|
||||||
model = self.model_dictionary[coin]
|
model = self.model_dictionary[coin]
|
||||||
elif model_type == 'joblib':
|
elif model_type == 'joblib':
|
||||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||||
@ -486,7 +487,7 @@ class FreqaiDataDrawer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# load it into ram if it was loaded from disk
|
# load it into ram if it was loaded from disk
|
||||||
if coin not in self.model_dictionary:
|
if coin not in self.model_dictionary and not self.limit_ram_use:
|
||||||
self.model_dictionary[coin] = model
|
self.model_dictionary[coin] = model
|
||||||
|
|
||||||
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
||||||
|
@ -3,12 +3,12 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Positions
|
||||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pandas import DataFrame
|
# from pandas import DataFrame
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
# from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.monitor import Monitor
|
# from stable_baselines3.common.monitor import Monitor
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -53,71 +53,53 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
|
class MyRLEnv(BaseReinforcementLearningModel.MyRLEnv):
|
||||||
prices_train: DataFrame, prices_test: DataFrame,
|
|
||||||
dk: FreqaiDataKitchen):
|
|
||||||
"""
|
"""
|
||||||
User can override this if they are using a custom MyRLEnv
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
|
sets a custom reward based on profit and trade duration.
|
||||||
"""
|
"""
|
||||||
train_df = data_dictionary["train_features"]
|
|
||||||
test_df = data_dictionary["test_features"]
|
|
||||||
|
|
||||||
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
def calculate_reward(self, action):
|
||||||
reward_kwargs=self.reward_params, config=self.config)
|
|
||||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
|
||||||
window_size=self.CONV_WIDTH,
|
|
||||||
reward_kwargs=self.reward_params, config=self.config))
|
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
|
||||||
render=False, eval_freq=len(train_df),
|
|
||||||
best_model_save_path=str(dk.data_path))
|
|
||||||
|
|
||||||
|
# first, penalize if the action is not valid
|
||||||
|
if not self._is_valid(action):
|
||||||
|
return -2
|
||||||
|
|
||||||
class MyRLEnv(Base5ActionRLEnv):
|
pnl = self.get_unrealized_profit()
|
||||||
"""
|
rew = np.sign(pnl) * (pnl + 1)
|
||||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
factor = 100
|
||||||
sets a custom reward based on profit and trade duration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def calculate_reward(self, action):
|
# reward agent for entering trades
|
||||||
|
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||||
|
and self._position == Positions.Neutral:
|
||||||
|
return 25
|
||||||
|
# discourage agent from not entering trades
|
||||||
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
|
return -1
|
||||||
|
|
||||||
# first, penalize if the action is not valid
|
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||||
if not self._is_valid(action):
|
trade_duration = self._current_tick - self._last_trade_tick
|
||||||
return -2
|
|
||||||
|
|
||||||
pnl = self.get_unrealized_profit()
|
if trade_duration <= max_trade_duration:
|
||||||
rew = np.sign(pnl) * (pnl + 1)
|
factor *= 1.5
|
||||||
factor = 100
|
elif trade_duration > max_trade_duration:
|
||||||
|
factor *= 0.5
|
||||||
|
|
||||||
# reward agent for entering trades
|
# discourage sitting in position
|
||||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
if self._position in (Positions.Short, Positions.Long) and \
|
||||||
and self._position == Positions.Neutral:
|
action == Actions.Neutral.value:
|
||||||
return 25
|
return -1 * trade_duration / max_trade_duration
|
||||||
# discourage agent from not entering trades
|
|
||||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
# close long
|
||||||
trade_duration = self._current_tick - self._last_trade_tick
|
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
if trade_duration <= max_trade_duration:
|
# close short
|
||||||
factor *= 1.5
|
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||||
elif trade_duration > max_trade_duration:
|
if pnl > self.profit_aim * self.rr:
|
||||||
factor *= 0.5
|
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
# discourage sitting in position
|
return 0.
|
||||||
if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
|
|
||||||
return -1 * trade_duration / max_trade_duration
|
|
||||||
|
|
||||||
# close long
|
|
||||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
|
||||||
if pnl > self.profit_aim * self.rr:
|
|
||||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
|
||||||
return float(rew * factor)
|
|
||||||
|
|
||||||
# close short
|
|
||||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
|
||||||
if pnl > self.profit_aim * self.rr:
|
|
||||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
|
||||||
return float(rew * factor)
|
|
||||||
|
|
||||||
return 0.
|
|
||||||
|
@ -34,7 +34,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
|||||||
**self.freqai_info['model_training_parameters']
|
**self.freqai_info['model_training_parameters']
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info('Continual training activated - starting training from previously '
|
logger.info('Continual learning activated - starting training from previously '
|
||||||
'trained agent.')
|
'trained agent.')
|
||||||
model = self.dd.model_dictionary[dk.pair]
|
model = self.dd.model_dictionary[dk.pair]
|
||||||
model.tensorboard_log = Path(dk.data_path / "tensorboard")
|
model.tensorboard_log = Path(dk.data_path / "tensorboard")
|
||||||
@ -65,13 +65,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
|||||||
|
|
||||||
env_id = "train_env"
|
env_id = "train_env"
|
||||||
num_cpu = int(self.freqai_info["rl_config"]["thread_count"] / 2)
|
num_cpu = int(self.freqai_info["rl_config"]["thread_count"] / 2)
|
||||||
self.train_env = SubprocVecEnv([make_env(env_id, i, 1, train_df, prices_train,
|
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||||
self.reward_params, self.CONV_WIDTH,
|
self.reward_params, self.CONV_WIDTH,
|
||||||
config=self.config) for i
|
config=self.config) for i
|
||||||
in range(num_cpu)])
|
in range(num_cpu)])
|
||||||
|
|
||||||
eval_env_id = 'eval_env'
|
eval_env_id = 'eval_env'
|
||||||
self.eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, prices_test,
|
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||||
|
test_df, prices_test,
|
||||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||||
config=self.config) for i
|
config=self.config) for i
|
||||||
in range(num_cpu)])
|
in range(num_cpu)])
|
||||||
|
Loading…
Reference in New Issue
Block a user