diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index 80543bf72..663ecc77e 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -2,9 +2,7 @@ import logging
 from enum import Enum
 
 import numpy as np
-import pandas as pd
 from gym import spaces
-from pandas import DataFrame
 
 from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
 
@@ -145,19 +143,6 @@ class Base5ActionRLEnv(BaseEnvironment):
 
         return observation, step_reward, self._done, info
 
-    def _get_observation(self):
-        features_window = self.signal_features[(
-            self._current_tick - self.window_size):self._current_tick]
-        features_and_state = DataFrame(np.zeros((len(features_window), 3)),
-                                       columns=['current_profit_pct', 'position', 'trade_duration'],
-                                       index=features_window.index)
-
-        features_and_state['current_profit_pct'] = self.get_unrealized_profit()
-        features_and_state['position'] = self._position.value
-        features_and_state['trade_duration'] = self.get_trade_duration()
-        features_and_state = pd.concat([features_window, features_and_state], axis=1)
-        return features_and_state
-
     def get_trade_duration(self):
         if self._last_trade_tick is None:
             return 0
diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 6474483c6..6633bf3e8 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -35,6 +35,7 @@ class BaseEnvironment(gym.Env):
                  id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
 
         self.rl_config = config['freqai']['rl_config']
+        self.add_state_info = self.rl_config.get('add_state_info', False)
         self.id = id
         self.seed(seed)
         self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
@@ -58,7 +59,11 @@ class BaseEnvironment(gym.Env):
         self.fee = 0.0015
 
         # # spaces
-        self.shape = (window_size, self.signal_features.shape[1] + 3)
+        if self.add_state_info:
+            self.total_features = self.signal_features.shape[1] + 3
+        else:
+            self.total_features = self.signal_features.shape[1]
+        self.shape = (window_size, self.total_features)
         self.set_action_space()
         self.observation_space = spaces.Box(
             low=-1, high=1, shape=self.shape, dtype=np.float32)
@@ -126,15 +131,20 @@ class BaseEnvironment(gym.Env):
         """
         features_window = self.signal_features[(
             self._current_tick - self.window_size):self._current_tick]
-        features_and_state = DataFrame(np.zeros((len(features_window), 3)),
-                                       columns=['current_profit_pct', 'position', 'trade_duration'],
-                                       index=features_window.index)
+        if self.add_state_info:
+            features_and_state = DataFrame(np.zeros((len(features_window), 3)),
+                                           columns=['current_profit_pct',
+                                                    'position',
+                                                    'trade_duration'],
+                                           index=features_window.index)
 
-        features_and_state['current_profit_pct'] = self.get_unrealized_profit()
-        features_and_state['position'] = self._position.value
-        features_and_state['trade_duration'] = self.get_trade_duration()
-        features_and_state = pd.concat([features_window, features_and_state], axis=1)
-        return features_and_state
+            features_and_state['current_profit_pct'] = self.get_unrealized_profit()
+            features_and_state['position'] = self._position.value
+            features_and_state['trade_duration'] = self.get_trade_duration()
+            features_and_state = pd.concat([features_window, features_and_state], axis=1)
+            return features_and_state
+        else:
+            return features_window
 
     def get_trade_duration(self):
         """
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 323cfd782..885918ffb 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -234,11 +234,12 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
 
         def _predict(window):
-            market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
             observations = dataframe.iloc[window.index]
-            observations['current_profit_pct'] = current_profit
-            observations['position'] = market_side
-            observations['trade_duration'] = trade_duration
+            if self.live:  # self.guard_state_info_if_backtest():
+                market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
+                observations['current_profit_pct'] = current_profit
+                observations['position'] = market_side
+                observations['trade_duration'] = trade_duration
             res, _ = model.predict(observations, deterministic=True)
             return res
 
@@ -246,6 +247,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
         return output
 
+    # def guard_state_info_if_backtest(self):
+    #     """
+    #     Ensure that backtesting mode doesn't try to use state information.
+    #     """
+    #     if self.rl_config.get('add_state_info', False) and not self.live:
+    #         logger.warning('Backtesting with state info is currently unavailable, '
+    #                        'turning it off.')
+    #         self.rl_config['add_state_info'] = False
+
+    #     return not self.rl_config['add_state_info']
+
     def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                     pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                                DataFrame]:
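
Usage sketch (an illustration, not part of the diff): the only configuration key this change reads is add_state_info, fetched in BaseEnvironment.__init__ via self.rl_config.get('add_state_info', False) from config['freqai']['rl_config']. The Python dict below is a hypothetical fragment; every key other than add_state_info is shown only to locate it under freqai -> rl_config.

    # Hypothetical config fragment in dict form; only "add_state_info" is defined by the diff.
    config = {
        "freqai": {
            "rl_config": {
                # True: BaseEnvironment appends three state columns
                # ('current_profit_pct', 'position', 'trade_duration') to each
                # observation window and widens self.shape by 3 features.
                # False (default): observations contain signal_features only.
                "add_state_info": True,
            },
        },
    }

As drafted, _predict only attaches the state columns when self.live is True, and the commented-out guard_state_info_if_backtest suggests the flag is intended to be forced off during backtesting.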