Add the ability to toggle state info integration, and prevent state info integration during backtesting

This commit is contained in:
robcaulk 2022-11-12 18:46:48 +01:00
parent 9c6b97c678
commit e71a8b8ac1
3 changed files with 35 additions and 28 deletions

View File

@ -2,9 +2,7 @@ import logging
from enum import Enum from enum import Enum
import numpy as np import numpy as np
import pandas as pd
from gym import spaces from gym import spaces
from pandas import DataFrame
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
@ -145,19 +143,6 @@ class Base5ActionRLEnv(BaseEnvironment):
return observation, step_reward, self._done, info return observation, step_reward, self._done, info
def _get_observation(self):
features_window = self.signal_features[(
self._current_tick - self.window_size):self._current_tick]
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
columns=['current_profit_pct', 'position', 'trade_duration'],
index=features_window.index)
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state
def get_trade_duration(self): def get_trade_duration(self):
if self._last_trade_tick is None: if self._last_trade_tick is None:
return 0 return 0

View File

@ -35,6 +35,7 @@ class BaseEnvironment(gym.Env):
id: str = 'baseenv-1', seed: int = 1, config: dict = {}): id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
self.rl_config = config['freqai']['rl_config'] self.rl_config = config['freqai']['rl_config']
self.add_state_info = self.rl_config.get('add_state_info', False)
self.id = id self.id = id
self.seed(seed) self.seed(seed)
self.reset_env(df, prices, window_size, reward_kwargs, starting_point) self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
@ -58,7 +59,11 @@ class BaseEnvironment(gym.Env):
self.fee = 0.0015 self.fee = 0.0015
# # spaces # # spaces
self.shape = (window_size, self.signal_features.shape[1] + 3) if self.add_state_info:
self.total_features = self.signal_features.shape[1] + 3
else:
self.total_features = self.signal_features.shape[1]
self.shape = (window_size, self.total_features)
self.set_action_space() self.set_action_space()
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-1, high=1, shape=self.shape, dtype=np.float32) low=-1, high=1, shape=self.shape, dtype=np.float32)
@ -126,15 +131,20 @@ class BaseEnvironment(gym.Env):
""" """
features_window = self.signal_features[( features_window = self.signal_features[(
self._current_tick - self.window_size):self._current_tick] self._current_tick - self.window_size):self._current_tick]
features_and_state = DataFrame(np.zeros((len(features_window), 3)), if self.add_state_info:
columns=['current_profit_pct', 'position', 'trade_duration'], features_and_state = DataFrame(np.zeros((len(features_window), 3)),
index=features_window.index) columns=['current_profit_pct',
'position',
'trade_duration'],
index=features_window.index)
features_and_state['current_profit_pct'] = self.get_unrealized_profit() features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration() features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1) features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state return features_and_state
else:
return features_window
def get_trade_duration(self): def get_trade_duration(self):
""" """

View File

@ -234,11 +234,12 @@ class BaseReinforcementLearningModel(IFreqaiModel):
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list) output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
def _predict(window): def _predict(window):
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations = dataframe.iloc[window.index] observations = dataframe.iloc[window.index]
observations['current_profit_pct'] = current_profit if self.live: # self.guard_state_info_if_backtest():
observations['position'] = market_side market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations['trade_duration'] = trade_duration observations['current_profit_pct'] = current_profit
observations['position'] = market_side
observations['trade_duration'] = trade_duration
res, _ = model.predict(observations, deterministic=True) res, _ = model.predict(observations, deterministic=True)
return res return res
@ -246,6 +247,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return output return output
# def guard_state_info_if_backtest(self):
# """
# Ensure that backtesting mode doesn't try to use state information.
# """
# if self.rl_config.get('add_state_info', False) and not self.live:
# logger.warning('Backtesting with state info is currently unavailable, '
# 'turning it off.')
# self.rl_config['add_state_info'] = False
# return not self.rl_config['add_state_info']
def build_ohlc_price_dataframes(self, data_dictionary: dict, def build_ohlc_price_dataframes(self, data_dictionary: dict,
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame, pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
DataFrame]: DataFrame]: