add tests. add guardrails.

This commit is contained in:
robcaulk 2022-09-15 00:46:35 +02:00
parent 48140bff91
commit 8aac644009
9 changed files with 84 additions and 37 deletions

View File

@ -62,6 +62,7 @@
"train_period_days": 5, "train_period_days": 5,
"backtest_period_days": 2, "backtest_period_days": 2,
"identifier": "unique-id", "identifier": "unique-id",
"continual_learning": false,
"data_kitchen_thread_count": 2, "data_kitchen_thread_count": 2,
"feature_parameters": { "feature_parameters": {
"include_corr_pairlist": [ "include_corr_pairlist": [
@ -91,7 +92,6 @@
"max_trade_duration_candles": 300, "max_trade_duration_candles": 300,
"model_type": "PPO", "model_type": "PPO",
"policy_type": "MlpPolicy", "policy_type": "MlpPolicy",
"continual_learning": false,
"max_training_drawdown_pct": 0.5, "max_training_drawdown_pct": 0.5,
"model_reward_parameters": { "model_reward_parameters": {
"rr": 1, "rr": 1,

View File

@ -21,7 +21,7 @@ from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
from freqtrade.persistence import Trade from freqtrade.persistence import Trade
import pytest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,7 +45,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.eval_callback: EvalCallback = None self.eval_callback: EvalCallback = None
self.model_type = self.freqai_info['rl_config']['model_type'] self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config'] self.rl_config = self.freqai_info['rl_config']
self.continual_learning = self.rl_config.get('continual_learning', False) self.continual_learning = self.freqai_info.get('continual_learning', False)
if self.model_type in SB3_MODELS: if self.model_type in SB3_MODELS:
import_str = 'stable_baselines3' import_str = 'stable_baselines3'
elif self.model_type in SB3_CONTRIB_MODELS: elif self.model_type in SB3_CONTRIB_MODELS:
@ -59,14 +59,30 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.model_type]) self.model_type])
self.MODELCLASS = getattr(mod, self.model_type) self.MODELCLASS = getattr(mod, self.model_type)
self.policy_type = self.freqai_info['rl_config']['policy_type'] self.policy_type = self.freqai_info['rl_config']['policy_type']
self.unset_outlier_removal()
def unset_outlier_removal(self):
    """
    Deactivate user settings that may remove or reorder training points.

    Checks the feature parameters (``self.ft_params``) for SVM and DBSCAN
    outlier removal, and the ``data_split_parameters`` entry of
    ``self.freqai_info`` for ``shuffle``. Every option that is switched on
    is set to False in place and a warning is logged for the user.
    """
    if self.ft_params.get('use_SVM_to_remove_outliers', False):
        self.ft_params.update({'use_SVM_to_remove_outliers': False})
        logger.warning('User tried to use SVM with RL. Deactivating SVM.')

    if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
        # Reset the DBSCAN flag itself; resetting the SVM flag here would
        # leave DBSCAN active despite the warning below.
        self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
        logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')

    # Tolerate a config without a 'data_split_parameters' section: there is
    # then nothing to unset.
    split_params = self.freqai_info.get('data_split_parameters', {})
    if split_params.get('shuffle', False):
        # dict.update() takes a mapping (or keyword arguments), not a
        # positional key/value pair.
        split_params.update({'shuffle': False})
        logger.warning('User tried to shuffle training data. Setting shuffle to False')
def train( def train(
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
) -> Any: ) -> Any:
""" """
Filter the training data and train a model to it. Train makes heavy use of the datakitchen Filter the training data and train a model to it. Train makes heavy use of the datakitchen
for storing, saving, loading, and analyzing the data. for storing, saving, loading, and analyzing the data.
:param unfiltered_dataframe: Full dataframe for the current training period :param unfiltered_df: Full dataframe for the current training period
:param metadata: pair metadata from strategy. :param metadata: pair metadata from strategy.
:returns: :returns:
:model: Trained model which can be used to inference (self.predict) :model: Trained model which can be used to inference (self.predict)
@ -75,7 +91,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
logger.info("--------------------Starting training " f"{pair} --------------------") logger.info("--------------------Starting training " f"{pair} --------------------")
features_filtered, labels_filtered = dk.filter_features( features_filtered, labels_filtered = dk.filter_features(
unfiltered_dataframe, unfiltered_df,
dk.training_features_list, dk.training_features_list,
dk.label_list, dk.label_list,
training_filter=True, training_filter=True,
@ -99,7 +115,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk) self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
model = self.fit_rl(data_dictionary, dk) model = self.fit(data_dictionary, dk)
logger.info(f"--------------------done training {pair}--------------------") logger.info(f"--------------------done training {pair}--------------------")
@ -124,7 +140,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
best_model_save_path=str(dk.data_path)) best_model_save_path=str(dk.data_path))
@abstractmethod @abstractmethod
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen): def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
""" """
Agent customizations and abstract Reinforcement Learning customizations Agent customizations and abstract Reinforcement Learning customizations
go in here. Abstract method, so this function must be overridden by go in here. Abstract method, so this function must be overridden by
@ -142,6 +158,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
# FIXME: mypy typing doesn't like that strategy may be "None" (it never will be) # FIXME: mypy typing doesn't like that strategy may be "None" (it never will be)
# FIXME: get_rate and trade_duration shouldn't work with backtesting, # FIXME: get_rate and trade_duration shouldn't work with backtesting,
# we need to use candle dates and prices to compute that. # we need to use candle dates and prices to compute that.
pytest.set_trace()
current_value = self.strategy.dp._exchange.get_rate( current_value = self.strategy.dp._exchange.get_rate(
pair, refresh=False, side="exit", is_short=trade.is_short) pair, refresh=False, side="exit", is_short=trade.is_short)
openrate = trade.open_rate openrate = trade.open_rate
@ -162,7 +179,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return market_side, current_profit, int(trade_duration) return market_side, current_profit, int(trade_duration)
def predict( def predict(
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
) -> Tuple[DataFrame, npt.NDArray[np.int_]]: ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
""" """
Filter the prediction features data and predict with it. Filter the prediction features data and predict with it.
@ -173,9 +190,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
data (NaNs) or felt uncertain about data (PCA and DI index) data (NaNs) or felt uncertain about data (PCA and DI index)
""" """
dk.find_features(unfiltered_dataframe) dk.find_features(unfiltered_df)
filtered_dataframe, _ = dk.filter_features( filtered_dataframe, _ = dk.filter_features(
unfiltered_dataframe, dk.training_features_list, training_filter=False unfiltered_df, dk.training_features_list, training_filter=False
) )
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe) filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
dk.data_dictionary["prediction_features"] = filtered_dataframe dk.data_dictionary["prediction_features"] = filtered_dataframe
@ -305,8 +322,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor # But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
# all the other existing fit() functions to include dk argument. For now we instantiate and # all the other existing fit() functions to include dk argument. For now we instantiate and
# leave it. # leave it.
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any: # def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
return # return
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int, def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,

View File

@ -553,7 +553,8 @@ class IFreqaiModel(ABC):
# find the features indicated by strategy and store in datakitchen # find the features indicated by strategy and store in datakitchen
dk.find_features(unfiltered_dataframe) dk.find_features(unfiltered_dataframe)
# import pytest
# pytest.set_trace()
model = self.train(unfiltered_dataframe, pair, dk) model = self.train(unfiltered_dataframe, pair, dk)
self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts

View File

@ -18,13 +18,13 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
User created Reinforcement Learning Model prediction model. User created Reinforcement Learning Model prediction model.
""" """
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen): def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
train_df = data_dictionary["train_features"] train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df) total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU, policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[512, 512, 256]) net_arch=[128, 128])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs, model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@ -69,8 +69,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
factor = 100 factor = 100
# reward agent for entering trades # reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \ if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral: and self._position == Positions.Neutral):
return 25 return 25
# discourage agent from not entering trades # discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral: if action == Actions.Neutral.value and self._position == Positions.Neutral:
@ -85,8 +85,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
factor *= 0.5 factor *= 0.5
# discourage sitting in position # discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and \ if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value: action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration return -1 * trade_duration / max_trade_duration
# close long # close long

View File

@ -20,14 +20,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
User created Reinforcement Learning Model prediction model. User created Reinforcement Learning Model prediction model.
""" """
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen): def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
train_df = data_dictionary["train_features"] train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df) total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
# model arch # model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU, policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[256, 256, 128]) net_arch=[128, 128])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs, model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,

View File

@ -29,15 +29,16 @@ def freqai_conf(default_conf, tmpdir):
"enabled": True, "enabled": True,
"startup_candles": 10000, "startup_candles": 10000,
"purge_old_models": True, "purge_old_models": True,
"train_period_days": 5, "train_period_days": 2,
"backtest_period_days": 2, "backtest_period_days": 2,
"live_retrain_hours": 0, "live_retrain_hours": 0,
"expiration_hours": 1, "expiration_hours": 1,
"identifier": "uniqe-id100", "identifier": "uniqe-id100",
"live_trained_timestamp": 0, "live_trained_timestamp": 0,
"data_kitchen_thread_count": 2,
"feature_parameters": { "feature_parameters": {
"include_timeframes": ["5m"], "include_timeframes": ["5m"],
"include_corr_pairlist": ["ADA/BTC", "DASH/BTC"], "include_corr_pairlist": ["ADA/BTC"],
"label_period_candles": 20, "label_period_candles": 20,
"include_shifted_candles": 1, "include_shifted_candles": 1,
"DI_threshold": 0.9, "DI_threshold": 0.9,
@ -47,7 +48,7 @@ def freqai_conf(default_conf, tmpdir):
"stratify_training_data": 0, "stratify_training_data": 0,
"indicator_periods_candles": [10], "indicator_periods_candles": [10],
}, },
"data_split_parameters": {"test_size": 0.33, "random_state": 1}, "data_split_parameters": {"test_size": 0.33, "shuffle": False},
"model_training_parameters": {"n_estimators": 100}, "model_training_parameters": {"n_estimators": 100},
}, },
"config_files": [Path('config_examples', 'config_freqai.example.json')] "config_files": [Path('config_examples', 'config_freqai.example.json')]

View File

@ -90,5 +90,5 @@ def test_use_strategy_to_populate_indicators(mocker, freqai_conf):
df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, 'LTC/BTC') df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, 'LTC/BTC')
assert len(df.columns) == 45 assert len(df.columns) == 33
shutil.rmtree(Path(freqai.dk.full_path)) shutil.rmtree(Path(freqai.dk.full_path))

View File

@ -72,7 +72,7 @@ def test_use_DBSCAN_to_remove_outliers(mocker, freqai_conf, caplog):
# freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 1}) # freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 1})
freqai.dk.use_DBSCAN_to_remove_outliers(predict=False) freqai.dk.use_DBSCAN_to_remove_outliers(predict=False)
assert log_has_re( assert log_has_re(
"DBSCAN found eps of 2.36.", "DBSCAN found eps of 1.75.",
caplog, caplog,
) )
@ -81,7 +81,7 @@ def test_compute_distances(mocker, freqai_conf):
freqai = make_data_dictionary(mocker, freqai_conf) freqai = make_data_dictionary(mocker, freqai_conf)
freqai_conf['freqai']['feature_parameters'].update({"DI_threshold": 1}) freqai_conf['freqai']['feature_parameters'].update({"DI_threshold": 1})
avg_mean_dist = freqai.dk.compute_distances() avg_mean_dist = freqai.dk.compute_distances()
assert round(avg_mean_dist, 2) == 2.54 assert round(avg_mean_dist, 2) == 1.99
def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf, caplog): def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf, caplog):
@ -89,7 +89,7 @@ def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf,
freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 0.1}) freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 0.1})
freqai.dk.use_SVM_to_remove_outliers(predict=False) freqai.dk.use_SVM_to_remove_outliers(predict=False)
assert log_has_re( assert log_has_re(
"SVM detected 8.09%", "SVM detected 7.36%",
caplog, caplog,
) )
@ -128,7 +128,7 @@ def test_normalize_data(mocker, freqai_conf):
freqai = make_data_dictionary(mocker, freqai_conf) freqai = make_data_dictionary(mocker, freqai_conf)
data_dict = freqai.dk.data_dictionary data_dict = freqai.dk.data_dictionary
freqai.dk.normalize_data(data_dict) freqai.dk.normalize_data(data_dict)
assert len(freqai.dk.data) == 56 assert len(freqai.dk.data) == 32
def test_filter_features(mocker, freqai_conf): def test_filter_features(mocker, freqai_conf):
@ -142,7 +142,7 @@ def test_filter_features(mocker, freqai_conf):
training_filter=True, training_filter=True,
) )
assert len(filtered_df.columns) == 26 assert len(filtered_df.columns) == 14
def test_make_train_test_datasets(mocker, freqai_conf): def test_make_train_test_datasets(mocker, freqai_conf):

View File

@ -21,15 +21,40 @@ def is_arm() -> bool:
'LightGBMRegressor', 'LightGBMRegressor',
'XGBoostRegressor', 'XGBoostRegressor',
'CatboostRegressor', 'CatboostRegressor',
'ReinforcementLearner',
'ReinforcementLearner_multiproc'
]) ])
def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model): def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
if is_arm() and model == 'CatboostRegressor': if is_arm() and model == 'CatboostRegressor':
pytest.skip("CatBoost is not supported on ARM") pytest.skip("CatBoost is not supported on ARM")
model_save_ext = 'joblib'
freqai_conf.update({"freqaimodel": model}) freqai_conf.update({"freqaimodel": model})
freqai_conf.update({"timerange": "20180110-20180130"}) freqai_conf.update({"timerange": "20180110-20180130"})
freqai_conf.update({"strategy": "freqai_test_strat"}) freqai_conf.update({"strategy": "freqai_test_strat"})
if 'ReinforcementLearner' in model:
model_save_ext = 'zip'
freqai_conf.update({"strategy": "freqai_rl_test_strat"})
freqai_conf["freqai"].update({"model_training_parameters": {
"learning_rate": 0.00025,
"gamma": 0.9,
"verbose": 1
}})
freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
freqai_conf["freqai"]["rl_config"] = {
"train_cycles": 1,
"thread_count": 2,
"max_trade_duration_candles": 300,
"model_type": "PPO",
"policy_type": "MlpPolicy",
"max_training_drawdown_pct": 0.5,
"model_reward_parameters": {
"rr": 1,
"profit_aim": 0.02,
"win_reward_factor": 2
}}
strategy = get_patched_freqai_strategy(mocker, freqai_conf) strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf)
strategy.dp = DataProvider(freqai_conf, exchange) strategy.dp = DataProvider(freqai_conf, exchange)
@ -42,16 +67,19 @@ def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model):
freqai.dd.pair_dict = MagicMock() freqai.dd.pair_dict = MagicMock()
data_load_timerange = TimeRange.parse_timerange("20180110-20180130") data_load_timerange = TimeRange.parse_timerange("20180125-20180130")
new_timerange = TimeRange.parse_timerange("20180120-20180130") new_timerange = TimeRange.parse_timerange("20180127-20180130")
freqai.extract_data_and_train_model( freqai.extract_data_and_train_model(
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() assert Path(freqai.dk.data_path /
f"{freqai.dk.model_filename}_model.{model_save_ext}").is_file()
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() # if 'ReinforcementLearner' not in model:
# assert Path(freqai.dk.data_path /
# f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
shutil.rmtree(Path(freqai.dk.full_path)) shutil.rmtree(Path(freqai.dk.full_path))
@ -91,7 +119,7 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model):
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
assert len(freqai.dk.data['training_features_list']) == 26 assert len(freqai.dk.data['training_features_list']) == 14
shutil.rmtree(Path(freqai.dk.full_path)) shutil.rmtree(Path(freqai.dk.full_path))