add tests. add guardrails.
This commit is contained in:
@@ -18,13 +18,13 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
"""
|
||||
|
||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
|
||||
train_df = data_dictionary["train_features"]
|
||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||
|
||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=[512, 512, 256])
|
||||
net_arch=[128, 128])
|
||||
|
||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||
@@ -69,8 +69,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
factor = 100
|
||||
|
||||
# reward agent for entering trades
|
||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||
and self._position == Positions.Neutral:
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
@@ -85,8 +85,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and \
|
||||
action == Actions.Neutral.value:
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
|
@@ -20,14 +20,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
"""
|
||||
|
||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
|
||||
train_df = data_dictionary["train_features"]
|
||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||
|
||||
# model arch
|
||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=[256, 256, 128])
|
||||
net_arch=[128, 128])
|
||||
|
||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||
|
Reference in New Issue
Block a user