improve price df handling to enable backtesting
This commit is contained in:
@@ -10,8 +10,11 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base3ActionRLEnv import Base3ActionRLEnv, Actions, Positions
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
import torch.multiprocessing
|
||||
import torch as th
|
||||
logger = logging.getLogger(__name__)
|
||||
th.set_num_threads(8)
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
|
||||
|
||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
@@ -46,6 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
dk.fit_labels() # useless for now, but just satiating append methods
|
||||
|
||||
# normalize all data based on train_dataset only
|
||||
prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
|
||||
data_dictionary = dk.normalize_data(data_dictionary)
|
||||
|
||||
# optional additional data cleaning/analysis
|
||||
@@ -56,7 +60,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
)
|
||||
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
|
||||
|
||||
model = self.fit_rl(data_dictionary, pair, dk)
|
||||
model = self.fit_rl(data_dictionary, pair, dk, prices_train, prices_test)
|
||||
|
||||
if pair not in self.dd.historic_predictions:
|
||||
self.set_initial_historic_predictions(
|
||||
@@ -69,7 +73,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
return model
|
||||
|
||||
@abstractmethod
|
||||
def fit_rl(self, data_dictionary: Dict[str, Any], pair: str, dk: FreqaiDataKitchen):
|
||||
def fit_rl(self, data_dictionary: Dict[str, Any], pair: str, dk: FreqaiDataKitchen,
|
||||
prices_train: DataFrame, prices_test: DataFrame):
|
||||
"""
|
||||
Agent customizations and abstract Reinforcement Learning customizations
|
||||
go in here. Abstract method, so this function must be overridden by
|
||||
@@ -141,6 +146,34 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
return output
|
||||
|
||||
def build_ohlc_price_dataframes(self, data_dictionary: dict,
|
||||
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
|
||||
DataFrame]:
|
||||
"""
|
||||
Builds the train prices and test prices for the environment.
|
||||
"""
|
||||
|
||||
coin = pair.split('/')[0]
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
# price data for model training and evaluation
|
||||
tf = self.config['timeframe']
|
||||
ohlc_list = [f'%-{coin}raw_open_{tf}', f'%-{coin}raw_low_{tf}',
|
||||
f'%-{coin}raw_high_{tf}', f'%-{coin}raw_close_{tf}']
|
||||
rename_dict = {f'%-{coin}raw_open_{tf}': 'open', f'%-{coin}raw_low_{tf}': 'low',
|
||||
f'%-{coin}raw_high_{tf}': ' high', f'%-{coin}raw_close_{tf}': 'close'}
|
||||
|
||||
prices_train = train_df.filter(ohlc_list, axis=1)
|
||||
prices_train.rename(columns=rename_dict, inplace=True)
|
||||
prices_train.reset_index(drop=True)
|
||||
|
||||
prices_test = test_df.filter(ohlc_list, axis=1)
|
||||
prices_test.rename(columns=rename_dict, inplace=True)
|
||||
prices_test.reset_index(drop=True)
|
||||
|
||||
return prices_train, prices_test
|
||||
|
||||
def set_initial_historic_predictions(
|
||||
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user