add sb3_contrib models to the available agents. include sb3_contrib in requirements.

robcaulk 2022-08-21 19:58:36 +02:00
parent 8b3a8234ac
commit d88a0dbf82
3 changed files with 35 additions and 25 deletions
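
With this change, the RL agent class is resolved by name from either stable_baselines3 or sb3_contrib. A minimal sketch of the relevant configuration, written here as a Python dict (only the rl_config keys below appear in the diff; the surrounding structure is assumed):

    # hypothetical excerpt of a freqai configuration
    freqai_config = {
        'rl_config': {
            # any of ['PPO', 'A2C', 'DQN', 'TD3', 'SAC'] (stable_baselines3)
            # or ['TRPO', 'ARS'] (sb3_contrib)
            'model_type': 'TRPO',
            'policy_type': 'MlpPolicy',  # passed to the resolved agent class
        }
    }

Any other model_type now raises an OperationalException instead of failing inside __import__.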


@@ -223,12 +223,10 @@ class Base5ActionRLEnv(gym.Env):
             (action == Actions.Neutral.value and self._position == Positions.Long) or
             (action == Actions.Short_enter.value and self._position == Positions.Short) or
             (action == Actions.Short_enter.value and self._position == Positions.Long) or
-            # (action == Actions.Short_exit.value and self._position == Positions.Short) or
             (action == Actions.Short_exit.value and self._position == Positions.Long) or
             (action == Actions.Short_exit.value and self._position == Positions.Neutral) or
             (action == Actions.Long_enter.value and self._position == Positions.Long) or
             (action == Actions.Long_enter.value and self._position == Positions.Short) or
-            # (action == Actions.Long_exit.value and self._position == Positions.Long) or
             (action == Actions.Long_exit.value and self._position == Positions.Short) or
             (action == Actions.Long_exit.value and self._position == Positions.Neutral))
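
The chain above enumerates the disallowed (action, position) pairs one clause at a time; the two already commented-out clauses are deleted entirely. For readability only, a sketch of an equivalent set-based check covering the clauses visible in this hunk (not part of the commit):

    # same truth table as the boolean chain above, for the pairs shown in this hunk
    _INVALID_PAIRS = {
        (Actions.Neutral.value, Positions.Long),
        (Actions.Short_enter.value, Positions.Short),
        (Actions.Short_enter.value, Positions.Long),
        (Actions.Short_exit.value, Positions.Long),
        (Actions.Short_exit.value, Positions.Neutral),
        (Actions.Long_enter.value, Positions.Long),
        (Actions.Long_enter.value, Positions.Short),
        (Actions.Long_exit.value, Positions.Short),
        (Actions.Long_exit.value, Positions.Neutral),
    }
    is_invalid = (action, self._position) in _INVALID_PAIRS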


@@ -6,6 +6,7 @@ import numpy.typing as npt
 import pandas as pd
 from pandas import DataFrame
 from abc import abstractmethod
+from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
@@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)

 torch.multiprocessing.set_sharing_strategy('file_system')

+SB3_MODELS = ['PPO', 'A2C', 'DQN', 'TD3', 'SAC']
+SB3_CONTRIB_MODELS = ['TRPO', 'ARS']
+

 class BaseReinforcementLearningModel(IFreqaiModel):
     """
@@ -34,9 +38,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.train_env: Base5ActionRLEnv = None
         self.eval_env: Base5ActionRLEnv = None
         self.eval_callback: EvalCallback = None
-        mod = __import__('stable_baselines3', fromlist=[
-            self.freqai_info['rl_config']['model_type']])
-        self.MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
+        self.model_type = self.freqai_info['rl_config']['model_type']
+        if self.model_type in SB3_MODELS:
+            import_str = 'stable_baselines3'
+        elif self.model_type in SB3_CONTRIB_MODELS:
+            import_str = 'sb3_contrib'
+        else:
+            raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
+                                       f'sb3_contrib. please choose one of {SB3_MODELS} or '
+                                       f'{SB3_CONTRIB_MODELS}')
+        mod = __import__(import_str, fromlist=[
+            self.model_type])
+        self.MODELCLASS = getattr(mod, self.model_type)
         self.policy_type = self.freqai_info['rl_config']['policy_type']

     def train(
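
For reference, the two-step lookup above is the standard dynamic-import idiom; a standalone sketch (assuming sb3_contrib 1.6.0 is installed, which exports TRPO at package level):

    # resolve an agent class by name at runtime
    model_type = 'TRPO'
    mod = __import__('sb3_contrib', fromlist=[model_type])  # returns the sb3_contrib package
    MODELCLASS = getattr(mod, model_type)                   # the TRPO class itself
    # importlib.import_module('sb3_contrib') would be the more idiomatic equivalent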
@@ -137,7 +151,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             current_profit = current_value / openrate - 1

         total_profit = 0
-        closed_trades = Trade.get_trades_proxy(pair = pair, is_open=False)
+        closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
         for trade in closed_trades:
             total_profit += trade.close_profit
@@ -223,6 +237,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
     def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
         return

+
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
              reward_params, window_size, monitor=False) -> Callable:
     """
@@ -244,6 +259,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
         set_random_seed(seed)
     return _init

+
 class MyRLEnv(Base5ActionRLEnv):
     """
     User can override any function in BaseRLEnv and gym.Env. Here the user
@@ -257,26 +273,20 @@ class MyRLEnv(Base5ActionRLEnv):
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(current_price) - np.log(last_trade_price))
-
-        if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(current_price) - np.log(last_trade_price)) * 2)
+            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
+            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
+            factor = 1
+            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+                factor = 2
+            return float((np.log(current_price) - np.log(last_trade_price)) * factor)

         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(last_trade_price) - np.log(current_price))
-
-        if action == Actions.Short_exit.value and self._position == Positions.Short:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(last_trade_price) - np.log(current_price)) * 2)
+            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
+            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
+            factor = 1
+            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+                factor = 2
+            return float((np.log(last_trade_price) - np.log(current_price)) * factor)

         return 0.
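
The rewrite folds the duplicated exit branches into one path: the base reward is the log return of the closed trade, and factor doubles it once the last realized profit beats profit_aim * rr (the added truthiness guard on close_trade_profit also avoids indexing an empty list). A worked example with illustrative numbers only:

    import numpy as np

    last_trade_price = 100.0  # fee-adjusted entry price
    current_price = 103.0     # fee-adjusted exit price
    base = np.log(current_price) - np.log(last_trade_price)  # ~0.02956 log return
    reward = base * 2  # ~0.05912 when close_trade_profit[-1] > profit_aim * rr, else base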


@@ -10,3 +10,5 @@ torch==1.12.1
 stable-baselines3==1.6.0
 gym==0.21.0
 tensorboard==2.9.1
+optuna==2.10.1
+sb3-contrib==1.6.0
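
sb3-contrib version numbers track stable-baselines3 releases, so the two pins are kept in lockstep at 1.6.0.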