reduce code for base use-case, ensure multiproc inherits custom env, add ability to limit ram use.

This commit is contained in:
robcaulk
2022-08-25 19:05:51 +02:00
parent 05ccebf9a1
commit 3199eb453b
5 changed files with 125 additions and 123 deletions

View File

@@ -19,6 +19,7 @@ from typing import Callable
from datetime import datetime, timezone
from stable_baselines3.common.utils import set_random_seed
import gym
from pathlib import Path
logger = logging.getLogger(__name__)
torch.multiprocessing.set_sharing_strategy('file_system')
@@ -110,9 +111,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params, config=self.config)
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params, config=self.config)
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params, config=self.config))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
@@ -126,7 +127,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
go in here. Abstract method, so this function must be overridden by
user class.
"""
return
def get_state_info(self, pair: str):
@@ -232,6 +232,72 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return prices_train, prices_test
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
"""
Can be used by user if they are trying to limit_ram_usage *and*
perform continual learning.
For now, this is unused.
"""
exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
if exists:
model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
else:
logger.info('No model file on disk to continue learning from.')
return model
# Nested class which can be overridden by user to customize further
class MyRLEnv(Base5ActionRLEnv):
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
rew = np.sign(pnl) * (pnl + 1)
factor = 100
# reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
and self._position == Positions.Neutral:
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
if trade_duration <= max_trade_duration:
factor *= 1.5
elif trade_duration > max_trade_duration:
factor *= 0.5
# discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and \
action == Actions.Neutral.value:
return -1 * trade_duration / max_trade_duration
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(rew * factor)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(rew * factor)
return 0.
# TODO take care of this appendage. Right now it needs to be called because FreqAI enforces it.
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
# all the other existing fit() functions to include dk argument. For now we instantiate and
@@ -240,7 +306,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return
def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
def make_env(MyRLEnv: Base5ActionRLEnv, env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame,
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
config: Dict[str, Any] = {}) -> Callable:
"""
@@ -252,6 +319,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: Data
:param rank: (int) index of the subprocess
:return: (Callable)
"""
def _init() -> gym.Env:
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
@@ -261,54 +329,3 @@ def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: Data
return env
set_random_seed(seed)
return _init
class MyRLEnv(Base5ActionRLEnv):
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
rew = np.sign(pnl) * (pnl + 1)
factor = 100
# reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
and self._position == Positions.Neutral:
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
if trade_duration <= max_trade_duration:
factor *= 1.5
elif trade_duration > max_trade_duration:
factor *= 0.5
# discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value:
return -1 * trade_duration / max_trade_duration
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(rew * factor)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(rew * factor)
return 0.