add ability to integrate state info or not, and prevent state info integration during backtesting

This commit is contained in:
robcaulk 2022-11-12 18:46:48 +01:00
parent 9c6b97c678
commit e71a8b8ac1
3 changed files with 35 additions and 28 deletions

View File

@ -2,9 +2,7 @@ import logging
from enum import Enum
import numpy as np
import pandas as pd
from gym import spaces
from pandas import DataFrame
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
@ -145,19 +143,6 @@ class Base5ActionRLEnv(BaseEnvironment):
return observation, step_reward, self._done, info
def _get_observation(self):
features_window = self.signal_features[(
self._current_tick - self.window_size):self._current_tick]
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
columns=['current_profit_pct', 'position', 'trade_duration'],
index=features_window.index)
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state
def get_trade_duration(self):
if self._last_trade_tick is None:
return 0

View File

@ -35,6 +35,7 @@ class BaseEnvironment(gym.Env):
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
self.rl_config = config['freqai']['rl_config']
self.add_state_info = self.rl_config.get('add_state_info', False)
self.id = id
self.seed(seed)
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
@ -58,7 +59,11 @@ class BaseEnvironment(gym.Env):
self.fee = 0.0015
# # spaces
self.shape = (window_size, self.signal_features.shape[1] + 3)
if self.add_state_info:
self.total_features = self.signal_features.shape[1] + 3
else:
self.total_features = self.signal_features.shape[1]
self.shape = (window_size, self.total_features)
self.set_action_space()
self.observation_space = spaces.Box(
low=-1, high=1, shape=self.shape, dtype=np.float32)
@ -126,15 +131,20 @@ class BaseEnvironment(gym.Env):
"""
features_window = self.signal_features[(
self._current_tick - self.window_size):self._current_tick]
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
columns=['current_profit_pct', 'position', 'trade_duration'],
index=features_window.index)
if self.add_state_info:
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
columns=['current_profit_pct',
'position',
'trade_duration'],
index=features_window.index)
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state
else:
return features_window
def get_trade_duration(self):
"""

View File

@ -234,11 +234,12 @@ class BaseReinforcementLearningModel(IFreqaiModel):
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
def _predict(window):
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations = dataframe.iloc[window.index]
observations['current_profit_pct'] = current_profit
observations['position'] = market_side
observations['trade_duration'] = trade_duration
if self.live: # self.guard_state_info_if_backtest():
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations['current_profit_pct'] = current_profit
observations['position'] = market_side
observations['trade_duration'] = trade_duration
res, _ = model.predict(observations, deterministic=True)
return res
@ -246,6 +247,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return output
# def guard_state_info_if_backtest(self):
# """
# Ensure that backtesting mode doesnt try to use state information.
# """
# if self.rl_config('add_state_info', False) and not self.live:
# logger.warning('Backtesting with state info is currently unavailable '
# 'turning it off.')
# self.rl_config['add_state_info'] = False
# return not self.rl_config['add_state_info']
def build_ohlc_price_dataframes(self, data_dictionary: dict,
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
DataFrame]: