Add sb3_contrib models to the available agents. Include sb3_contrib in requirements.
This commit is contained in:
parent
8b3a8234ac
commit
d88a0dbf82
@ -223,12 +223,10 @@ class Base5ActionRLEnv(gym.Env):
|
|||||||
(action == Actions.Neutral.value and self._position == Positions.Long) or
|
(action == Actions.Neutral.value and self._position == Positions.Long) or
|
||||||
(action == Actions.Short_enter.value and self._position == Positions.Short) or
|
(action == Actions.Short_enter.value and self._position == Positions.Short) or
|
||||||
(action == Actions.Short_enter.value and self._position == Positions.Long) or
|
(action == Actions.Short_enter.value and self._position == Positions.Long) or
|
||||||
# (action == Actions.Short_exit.value and self._position == Positions.Short) or
|
|
||||||
(action == Actions.Short_exit.value and self._position == Positions.Long) or
|
(action == Actions.Short_exit.value and self._position == Positions.Long) or
|
||||||
(action == Actions.Short_exit.value and self._position == Positions.Neutral) or
|
(action == Actions.Short_exit.value and self._position == Positions.Neutral) or
|
||||||
(action == Actions.Long_enter.value and self._position == Positions.Long) or
|
(action == Actions.Long_enter.value and self._position == Positions.Long) or
|
||||||
(action == Actions.Long_enter.value and self._position == Positions.Short) or
|
(action == Actions.Long_enter.value and self._position == Positions.Short) or
|
||||||
# (action == Actions.Long_exit.value and self._position == Positions.Long) or
|
|
||||||
(action == Actions.Long_exit.value and self._position == Positions.Short) or
|
(action == Actions.Long_exit.value and self._position == Positions.Short) or
|
||||||
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
|
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import numpy.typing as npt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
|
||||||
@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||||
|
|
||||||
|
SB3_MODELS = ['PPO', 'A2C', 'DQN', 'TD3', 'SAC']
|
||||||
|
SB3_CONTRIB_MODELS = ['TRPO', 'ARS']
|
||||||
|
|
||||||
|
|
||||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||||
"""
|
"""
|
||||||
@ -34,9 +38,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.train_env: Base5ActionRLEnv = None
|
self.train_env: Base5ActionRLEnv = None
|
||||||
self.eval_env: Base5ActionRLEnv = None
|
self.eval_env: Base5ActionRLEnv = None
|
||||||
self.eval_callback: EvalCallback = None
|
self.eval_callback: EvalCallback = None
|
||||||
mod = __import__('stable_baselines3', fromlist=[
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.freqai_info['rl_config']['model_type']])
|
if self.model_type in SB3_MODELS:
|
||||||
self.MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
import_str = 'stable_baselines3'
|
||||||
|
elif self.model_type in SB3_CONTRIB_MODELS:
|
||||||
|
import_str = 'sb3_contrib'
|
||||||
|
else:
|
||||||
|
raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
|
||||||
|
f'sb3_contrib. please choose one of {SB3_MODELS} or '
|
||||||
|
f'{SB3_CONTRIB_MODELS}')
|
||||||
|
|
||||||
|
mod = __import__(import_str, fromlist=[
|
||||||
|
self.model_type])
|
||||||
|
self.MODELCLASS = getattr(mod, self.model_type)
|
||||||
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -223,6 +237,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||||
reward_params, window_size, monitor=False) -> Callable:
|
reward_params, window_size, monitor=False) -> Callable:
|
||||||
"""
|
"""
|
||||||
@ -244,6 +259,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
|||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
return _init
|
return _init
|
||||||
|
|
||||||
|
|
||||||
class MyRLEnv(Base5ActionRLEnv):
|
class MyRLEnv(Base5ActionRLEnv):
|
||||||
"""
|
"""
|
||||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
@ -257,26 +273,20 @@ class MyRLEnv(Base5ActionRLEnv):
|
|||||||
|
|
||||||
# close long
|
# close long
|
||||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||||
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
|
||||||
return float(np.log(current_price) - np.log(last_trade_price))
|
factor = 1
|
||||||
|
if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
||||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
factor = 2
|
||||||
if self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
return float((np.log(current_price) - np.log(last_trade_price)) * factor)
|
||||||
last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
|
|
||||||
current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
|
|
||||||
return float((np.log(current_price) - np.log(last_trade_price)) * 2)
|
|
||||||
|
|
||||||
# close short
|
# close short
|
||||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||||
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||||
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
|
||||||
return float(np.log(last_trade_price) - np.log(current_price))
|
factor = 1
|
||||||
|
if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
||||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
factor = 2
|
||||||
if self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
return float(np.log(last_trade_price) - np.log(current_price) * factor)
|
||||||
last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
|
|
||||||
current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
|
|
||||||
return float((np.log(last_trade_price) - np.log(current_price)) * 2)
|
|
||||||
|
|
||||||
return 0.
|
return 0.
|
||||||
|
@ -10,3 +10,5 @@ torch==1.12.1
|
|||||||
stable-baselines3==1.6.0
|
stable-baselines3==1.6.0
|
||||||
gym==0.21.0
|
gym==0.21.0
|
||||||
tensorboard==2.9.1
|
tensorboard==2.9.1
|
||||||
|
optuna==2.10.1
|
||||||
|
sb3-contrib==1.6.0
|
Loading…
Reference in New Issue
Block a user