reuse callback, allow user to acces all stable_baselines3 agents via config

This commit is contained in:
robcaulk
2022-08-20 16:35:29 +02:00
parent 4b9499e321
commit 3eb897c2f8
11 changed files with 295 additions and 587 deletions

View File

@@ -11,8 +11,12 @@ from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
from freqtrade.persistence import Trade
import torch.multiprocessing
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
import torch as th
from typing import Callable
from stable_baselines3.common.utils import set_random_seed
import gym
logger = logging.getLogger(__name__)
torch.multiprocessing.set_sharing_strategy('file_system')
@@ -25,9 +29,15 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def __init__(self, **kwargs):
super().__init__(config=kwargs['config'])
th.set_num_threads(self.freqai_info.get('data_kitchen_thread_count', 4))
th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Base5ActionRLEnv = None
self.eval_env: Base5ActionRLEnv = None
self.eval_callback: EvalCallback = None
mod = __import__('stable_baselines3', fromlist=[
self.freqai_info['rl_config']['model_type']])
self.MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
self.policy_type = self.freqai_info['rl_config']['policy_type']
def train(
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
@@ -67,7 +77,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
)
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test)
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
model = self.fit_rl(data_dictionary, dk)
@@ -75,13 +85,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return model
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test):
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
"""
User overrides this in their prediction model if they are custom a MyRLEnv. Othwerwise
leaving this will default to Base5ActEnv
User overrides this as shown here if they are using a custom MyRLEnv
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)
# environments
if not self.train_env:
@@ -90,11 +100,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params), ".")
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=eval_freq,
best_model_save_path=dk.data_path)
else:
self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
self.eval_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
self.train_env.reset()
self.eval_env.reset()
self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
self.eval_callback.__init__(self.eval_env, deterministic=True,
render=False, eval_freq=eval_freq,
best_model_save_path=dk.data_path)
@abstractmethod
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
@@ -206,16 +222,28 @@ class BaseReinforcementLearningModel(IFreqaiModel):
# all the other existing fit() functions to include dk argument. For now we instantiate and
# leave it.
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
"""
Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data
management will be properly handled by Freqai.
:param data_dictionary: Dict = the dictionary constructed by DataHandler to hold
all the training and test data/labels.
"""
return
def make_env(env_id: str, rank: int, seed: int, train_df, price,
reward_params, window_size, monitor=False) -> Callable:
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environment you wish to have in subprocesses
:param seed: (int) the inital seed for RNG
:param rank: (int) index of the subprocess
:return: (Callable)
"""
def _init() -> gym.Env:
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
reward_kwargs=reward_params, id=env_id, seed=seed + rank)
if monitor:
env = Monitor(env, ".")
return env
set_random_seed(seed)
return _init
class MyRLEnv(Base5ActionRLEnv):
"""
@@ -229,24 +257,24 @@ class MyRLEnv(Base5ActionRLEnv):
return 0.
# close long
if action == Actions.Long_sell.value and self._position == Positions.Long:
if action == Actions.Long_exit.value and self._position == Positions.Long:
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
return float(np.log(current_price) - np.log(last_trade_price))
if action == Actions.Long_sell.value and self._position == Positions.Long:
if action == Actions.Long_exit.value and self._position == Positions.Long:
if self.close_trade_profit[-1] > self.profit_aim * self.rr:
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
return float((np.log(current_price) - np.log(last_trade_price)) * 2)
# close short
if action == Actions.Short_buy.value and self._position == Positions.Short:
if action == Actions.Short_exit.value and self._position == Positions.Short:
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
return float(np.log(last_trade_price) - np.log(current_price))
if action == Actions.Short_buy.value and self._position == Positions.Short:
if action == Actions.Short_exit.value and self._position == Positions.Short:
if self.close_trade_profit[-1] > self.profit_aim * self.rr:
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)