add tests. add guardrails.

This commit is contained in:
robcaulk
2022-09-15 00:46:35 +02:00
parent 48140bff91
commit 8aac644009
9 changed files with 84 additions and 37 deletions

View File

@@ -18,13 +18,13 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
User created Reinforcement Learning Model prediction model.
"""
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[512, 512, 256])
net_arch=[128, 128])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@@ -69,8 +69,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
factor = 100
# reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
and self._position == Positions.Neutral:
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
@@ -85,8 +85,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
factor *= 0.5
# discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and \
action == Actions.Neutral.value:
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long

View File

@@ -20,14 +20,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
User created Reinforcement Learning Model prediction model.
"""
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
# model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[256, 256, 128])
net_arch=[128, 128])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,