reduce code for base use-case, ensure multiproc inherits custom env, add ability to limit ram use.
This commit is contained in:
parent
05ccebf9a1
commit
3199eb453b
@ -58,6 +58,7 @@
|
||||
"model_save_type": "stable_baselines",
|
||||
"conv_width": 4,
|
||||
"purge_old_models": true,
|
||||
"limit_ram_usage": false,
|
||||
"train_period_days": 5,
|
||||
"backtest_period_days": 2,
|
||||
"identifier": "unique-id",
|
||||
|
@ -19,6 +19,7 @@ from typing import Callable
|
||||
from datetime import datetime, timezone
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
import gym
|
||||
from pathlib import Path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
@ -110,9 +111,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
||||
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params, config=self.config)
|
||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params, config=self.config))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
@ -126,7 +127,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
go in here. Abstract method, so this function must be overridden by
|
||||
user class.
|
||||
"""
|
||||
|
||||
return
|
||||
|
||||
def get_state_info(self, pair: str):
|
||||
@ -232,38 +232,22 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
return prices_train, prices_test
|
||||
|
||||
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
|
||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||
# leave it.
|
||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||
return
|
||||
|
||||
|
||||
def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
:param env_id: (str) the environment ID
|
||||
:param num_env: (int) the number of environment you wish to have in subprocesses
|
||||
:param seed: (int) the inital seed for RNG
|
||||
:param rank: (int) index of the subprocess
|
||||
:return: (Callable)
|
||||
Can be used by user if they are trying to limit_ram_usage *and*
|
||||
perform continual learning.
|
||||
For now, this is unused.
|
||||
"""
|
||||
def _init() -> gym.Env:
|
||||
exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
|
||||
if exists:
|
||||
model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
else:
|
||||
logger.info('No model file on disk to continue learning from.')
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
return model
|
||||
|
||||
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
# Nested class which can be overridden by user to customize further
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
"""
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
@ -296,7 +280,8 @@ class MyRLEnv(Base5ActionRLEnv):
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
|
||||
if self._position in (Positions.Short, Positions.Long) and \
|
||||
action == Actions.Neutral.value:
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
@ -312,3 +297,35 @@ class MyRLEnv(Base5ActionRLEnv):
|
||||
return float(rew * factor)
|
||||
|
||||
return 0.
|
||||
|
||||
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
|
||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||
# leave it.
|
||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||
return
|
||||
|
||||
|
||||
def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
:param env_id: (str) the environment ID
|
||||
:param num_env: (int) the number of environment you wish to have in subprocesses
|
||||
:param seed: (int) the inital seed for RNG
|
||||
:param rank: (int) index of the subprocess
|
||||
:return: (Callable)
|
||||
"""
|
||||
|
||||
def _init() -> gym.Env:
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
|
@ -90,6 +90,7 @@ class FreqaiDataDrawer:
|
||||
self.empty_pair_dict: pair_info = {
|
||||
"model_filename": "", "trained_timestamp": 0,
|
||||
"priority": 1, "first": True, "data_path": "", "extras": {}}
|
||||
self.limit_ram_use = self.freqai_info.get('limit_ram_usage', False)
|
||||
|
||||
def load_drawer_from_disk(self):
|
||||
"""
|
||||
@ -423,7 +424,7 @@ class FreqaiDataDrawer:
|
||||
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
|
||||
)
|
||||
|
||||
# if self.live:
|
||||
if not self.limit_ram_use:
|
||||
self.model_dictionary[coin] = model
|
||||
self.pair_dict[coin]["model_filename"] = dk.model_filename
|
||||
self.pair_dict[coin]["data_path"] = str(dk.data_path)
|
||||
@ -464,7 +465,7 @@ class FreqaiDataDrawer:
|
||||
|
||||
model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary:
|
||||
if dk.live and coin in self.model_dictionary and not self.limit_ram_use:
|
||||
model = self.model_dictionary[coin]
|
||||
elif model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
@ -486,7 +487,7 @@ class FreqaiDataDrawer:
|
||||
)
|
||||
|
||||
# load it into ram if it was loaded from disk
|
||||
if coin not in self.model_dictionary:
|
||||
if coin not in self.model_dictionary and not self.limit_ram_use:
|
||||
self.model_dictionary[coin] = model
|
||||
|
||||
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
||||
|
@ -3,12 +3,12 @@ from typing import Any, Dict
|
||||
|
||||
import torch as th
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Positions
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||
from pathlib import Path
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
# from pandas import DataFrame
|
||||
# from stable_baselines3.common.callbacks import EvalCallback
|
||||
# from stable_baselines3.common.monitor import Monitor
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,26 +53,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
|
||||
return model
|
||||
|
||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
|
||||
prices_train: DataFrame, prices_test: DataFrame,
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params, config=self.config)
|
||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params, config=self.config))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
class MyRLEnv(BaseReinforcementLearningModel.MyRLEnv):
|
||||
"""
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
@ -105,7 +86,8 @@ class MyRLEnv(Base5ActionRLEnv):
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
|
||||
if self._position in (Positions.Short, Positions.Long) and \
|
||||
action == Actions.Neutral.value:
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
|
@ -34,7 +34,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
||||
**self.freqai_info['model_training_parameters']
|
||||
)
|
||||
else:
|
||||
logger.info('Continual training activated - starting training from previously '
|
||||
logger.info('Continual learning activated - starting training from previously '
|
||||
'trained agent.')
|
||||
model = self.dd.model_dictionary[dk.pair]
|
||||
model.tensorboard_log = Path(dk.data_path / "tensorboard")
|
||||
@ -65,13 +65,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
||||
|
||||
env_id = "train_env"
|
||||
num_cpu = int(self.freqai_info["rl_config"]["thread_count"] / 2)
|
||||
self.train_env = SubprocVecEnv([make_env(env_id, i, 1, train_df, prices_train,
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||
self.reward_params, self.CONV_WIDTH,
|
||||
config=self.config) for i
|
||||
in range(num_cpu)])
|
||||
|
||||
eval_env_id = 'eval_env'
|
||||
self.eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, prices_test,
|
||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||
test_df, prices_test,
|
||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||
config=self.config) for i
|
||||
in range(num_cpu)])
|
||||
|
Loading…
Reference in New Issue
Block a user