add test coverage, fix bug in base environment. Ensure proper fee is used.
@@ -74,10 +74,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             self.ft_params.update({'use_SVM_to_remove_outliers': False})
             logger.warning('User tried to use SVM with RL. Deactivating SVM.')
         if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
-            self.ft_params.update({'use_SVM_to_remove_outliers': False})
+            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
             logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
         if self.freqai_info['data_split_parameters'].get('shuffle', False):
-            self.freqai_info['data_split_parameters'].update('shuffle', False)
+            self.freqai_info['data_split_parameters'].update({'shuffle': False})
             logger.warning('User tried to shuffle training data. Setting shuffle to False')

     def train(
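
Two fixes land in this hunk: the DBSCAN branch previously disabled SVM again instead of DBSCAN (a copy-paste slip), and dict.update() was being called with two positional arguments, which raises a TypeError instead of setting the key. A minimal, self-contained sketch of the latter:

    params = {'shuffle': True}

    # Old form: dict.update() accepts one mapping (or iterable of pairs),
    # so two positional arguments raise rather than update the key.
    try:
        params.update('shuffle', False)
    except TypeError as exc:
        print(exc)  # update expected at most 1 argument, got 2

    # New form: pass a mapping, as the corrected line does.
    params.update({'shuffle': False})
    assert params['shuffle'] is False
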
@@ -141,11 +141,18 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]

-        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                      reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
-                                             window_size=self.CONV_WIDTH,
-                                             reward_kwargs=self.reward_params, config=self.config))
+        self.train_env = self.MyRLEnv(df=train_df,
+                                      prices=prices_train,
+                                      window_size=self.CONV_WIDTH,
+                                      reward_kwargs=self.reward_params,
+                                      config=self.config,
+                                      dp=self.data_provider)
+        self.eval_env = Monitor(self.MyRLEnv(df=test_df,
+                                             prices=prices_test,
+                                             window_size=self.CONV_WIDTH,
+                                             reward_kwargs=self.reward_params,
+                                             config=self.config,
+                                             dp=self.data_provider))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
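
Both environments are now constructed with the live DataProvider (dp=self.data_provider), which is what lets the base environment look up the pair's actual exchange fee rather than assuming a default, per the commit message ("Ensure proper fee is used"). A hedged sketch of that wiring, assuming the environment reads the fee off the provider's exchange; the _exchange attribute appears in the diff below, but the get_fee usage here is illustrative, not the verbatim freqtrade code:

    from typing import Any, Optional

    class EnvFeeSketch:
        DEFAULT_FEE = 0.0015  # fallback when no live exchange is reachable

        def __init__(self, pair: str, dp: Optional[Any] = None):
            exchange = getattr(dp, '_exchange', None)
            if exchange is not None:
                # Dry/live run: take the real fee for this pair so the
                # reward reflects actual trading costs.
                self.fee: float = exchange.get_fee(symbol=pair)
            else:
                self.fee = self.DEFAULT_FEE  # e.g. backtest without a provider
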
@@ -179,12 +186,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             if trade.pair == pair:
                 if self.data_provider._exchange is None:  # type: ignore
                     logger.error('No exchange available.')
                     return 0, 0, 0
                 else:
                     current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                         pair, refresh=False, side="exit", is_short=trade.is_short)
+
                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = int((now - trade.open_date_utc) / self.base_tf_seconds)
+                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                 current_profit = trade.calc_profit_ratio(current_rate)

         return market_side, current_profit, int(trade_duration)
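
The trade_duration fix addresses a TypeError: now is a float timestamp, and the old expression subtracted a datetime from it. Converting open_date_utc to a timestamp first makes both operands floats, so the division yields candles elapsed. A self-contained reproduction (base_tf_seconds = 300 assumes a 5m base timeframe):

    from datetime import datetime, timedelta, timezone

    base_tf_seconds = 300
    open_date_utc = datetime.now(timezone.utc) - timedelta(hours=2)
    now = datetime.now(timezone.utc).timestamp()

    # Old expression: float minus datetime is unsupported.
    try:
        int((now - open_date_utc) / base_tf_seconds)
    except TypeError as exc:
        print(exc)

    # Fixed expression: timestamp minus timestamp, both floats.
    trade_duration = int((now - open_date_utc.timestamp()) / base_tf_seconds)
    print(trade_duration)  # ~24 five-minute candles for a 2h-old trade
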
@@ -230,7 +238,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         def _predict(window):
             observations = dataframe.iloc[window.index]
-            if self.live:  # self.guard_state_info_if_backtest():
+            if self.live and self.rl_config('add_state_info', False):
                 market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                 observations['current_profit_pct'] = current_profit
                 observations['position'] = market_side
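
State information (position, profit, trade duration) comes from open trades, which only exist in dry/live runs, so _predict now gates it on both self.live and the add_state_info setting, replacing the commented-out guard removed in the next hunk. One caveat worth noting: if rl_config is a plain dict, the conventional spelling of that lookup would be self.rl_config.get('add_state_info', False). A small sketch of the gate under that assumption:

    def should_add_state_info(live: bool, rl_config: dict) -> bool:
        # Open trades exist only in dry/live runs, so backtesting must
        # skip state info even when the user enables it.
        return live and rl_config.get('add_state_info', False)

    assert should_add_state_info(True, {'add_state_info': True})
    assert not should_add_state_info(False, {'add_state_info': True})  # backtest
    assert not should_add_state_info(True, {})  # disabled by default
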
@@ -242,17 +250,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         return output

-    # def guard_state_info_if_backtest(self):
-    #     """
-    #     Ensure that backtesting mode doesnt try to use state information.
-    #     """
-    #     if self.rl_config('add_state_info', False) and not self.live:
-    #         logger.warning('Backtesting with state info is currently unavailable '
-    #                        'turning it off.')
-    #         self.rl_config['add_state_info'] = False
-
-    #     return not self.rl_config['add_state_info']
-
     def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                     pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                                DataFrame]: