add ability to integrate state info or not, and prevent state info integration during backtesting
This commit is contained in:
parent
9c6b97c678
commit
e71a8b8ac1
@ -2,9 +2,7 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from pandas import DataFrame
|
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
|
|
||||||
@ -145,19 +143,6 @@ class Base5ActionRLEnv(BaseEnvironment):
|
|||||||
|
|
||||||
return observation, step_reward, self._done, info
|
return observation, step_reward, self._done, info
|
||||||
|
|
||||||
def _get_observation(self):
|
|
||||||
features_window = self.signal_features[(
|
|
||||||
self._current_tick - self.window_size):self._current_tick]
|
|
||||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
|
||||||
columns=['current_profit_pct', 'position', 'trade_duration'],
|
|
||||||
index=features_window.index)
|
|
||||||
|
|
||||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
|
||||||
features_and_state['position'] = self._position.value
|
|
||||||
features_and_state['trade_duration'] = self.get_trade_duration()
|
|
||||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
|
||||||
return features_and_state
|
|
||||||
|
|
||||||
def get_trade_duration(self):
|
def get_trade_duration(self):
|
||||||
if self._last_trade_tick is None:
|
if self._last_trade_tick is None:
|
||||||
return 0
|
return 0
|
||||||
|
@ -35,6 +35,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
|
||||||
|
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config = config['freqai']['rl_config']
|
||||||
|
self.add_state_info = self.rl_config.get('add_state_info', False)
|
||||||
self.id = id
|
self.id = id
|
||||||
self.seed(seed)
|
self.seed(seed)
|
||||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||||
@ -58,7 +59,11 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.fee = 0.0015
|
self.fee = 0.0015
|
||||||
|
|
||||||
# # spaces
|
# # spaces
|
||||||
self.shape = (window_size, self.signal_features.shape[1] + 3)
|
if self.add_state_info:
|
||||||
|
self.total_features = self.signal_features.shape[1] + 3
|
||||||
|
else:
|
||||||
|
self.total_features = self.signal_features.shape[1]
|
||||||
|
self.shape = (window_size, self.total_features)
|
||||||
self.set_action_space()
|
self.set_action_space()
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-1, high=1, shape=self.shape, dtype=np.float32)
|
low=-1, high=1, shape=self.shape, dtype=np.float32)
|
||||||
@ -126,8 +131,11 @@ class BaseEnvironment(gym.Env):
|
|||||||
"""
|
"""
|
||||||
features_window = self.signal_features[(
|
features_window = self.signal_features[(
|
||||||
self._current_tick - self.window_size):self._current_tick]
|
self._current_tick - self.window_size):self._current_tick]
|
||||||
|
if self.add_state_info:
|
||||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
||||||
columns=['current_profit_pct', 'position', 'trade_duration'],
|
columns=['current_profit_pct',
|
||||||
|
'position',
|
||||||
|
'trade_duration'],
|
||||||
index=features_window.index)
|
index=features_window.index)
|
||||||
|
|
||||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||||
@ -135,6 +143,8 @@ class BaseEnvironment(gym.Env):
|
|||||||
features_and_state['trade_duration'] = self.get_trade_duration()
|
features_and_state['trade_duration'] = self.get_trade_duration()
|
||||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||||
return features_and_state
|
return features_and_state
|
||||||
|
else:
|
||||||
|
return features_window
|
||||||
|
|
||||||
def get_trade_duration(self):
|
def get_trade_duration(self):
|
||||||
"""
|
"""
|
||||||
|
@ -234,8 +234,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
||||||
|
|
||||||
def _predict(window):
|
def _predict(window):
|
||||||
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
|
|
||||||
observations = dataframe.iloc[window.index]
|
observations = dataframe.iloc[window.index]
|
||||||
|
if self.live: # self.guard_state_info_if_backtest():
|
||||||
|
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
|
||||||
observations['current_profit_pct'] = current_profit
|
observations['current_profit_pct'] = current_profit
|
||||||
observations['position'] = market_side
|
observations['position'] = market_side
|
||||||
observations['trade_duration'] = trade_duration
|
observations['trade_duration'] = trade_duration
|
||||||
@ -246,6 +247,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
# def guard_state_info_if_backtest(self):
|
||||||
|
# """
|
||||||
|
# Ensure that backtesting mode doesnt try to use state information.
|
||||||
|
# """
|
||||||
|
# if self.rl_config('add_state_info', False) and not self.live:
|
||||||
|
# logger.warning('Backtesting with state info is currently unavailable '
|
||||||
|
# 'turning it off.')
|
||||||
|
# self.rl_config['add_state_info'] = False
|
||||||
|
|
||||||
|
# return not self.rl_config['add_state_info']
|
||||||
|
|
||||||
def build_ohlc_price_dataframes(self, data_dictionary: dict,
|
def build_ohlc_price_dataframes(self, data_dictionary: dict,
|
||||||
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
|
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
|
||||||
DataFrame]:
|
DataFrame]:
|
||||||
|
Loading…
Reference in New Issue
Block a user