add tests. add guardrails.
This commit is contained in:
parent
48140bff91
commit
8aac644009
@ -62,6 +62,7 @@
|
|||||||
"train_period_days": 5,
|
"train_period_days": 5,
|
||||||
"backtest_period_days": 2,
|
"backtest_period_days": 2,
|
||||||
"identifier": "unique-id",
|
"identifier": "unique-id",
|
||||||
|
"continual_learning": false,
|
||||||
"data_kitchen_thread_count": 2,
|
"data_kitchen_thread_count": 2,
|
||||||
"feature_parameters": {
|
"feature_parameters": {
|
||||||
"include_corr_pairlist": [
|
"include_corr_pairlist": [
|
||||||
@ -91,7 +92,6 @@
|
|||||||
"max_trade_duration_candles": 300,
|
"max_trade_duration_candles": 300,
|
||||||
"model_type": "PPO",
|
"model_type": "PPO",
|
||||||
"policy_type": "MlpPolicy",
|
"policy_type": "MlpPolicy",
|
||||||
"continual_learning": false,
|
|
||||||
"max_training_drawdown_pct": 0.5,
|
"max_training_drawdown_pct": 0.5,
|
||||||
"model_reward_parameters": {
|
"model_reward_parameters": {
|
||||||
"rr": 1,
|
"rr": 1,
|
||||||
|
@ -21,7 +21,7 @@ from freqtrade.freqai.freqai_interface import IFreqaiModel
|
|||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
|
import pytest
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.eval_callback: EvalCallback = None
|
self.eval_callback: EvalCallback = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.rl_config = self.freqai_info['rl_config']
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
self.continual_learning = self.rl_config.get('continual_learning', False)
|
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||||
if self.model_type in SB3_MODELS:
|
if self.model_type in SB3_MODELS:
|
||||||
import_str = 'stable_baselines3'
|
import_str = 'stable_baselines3'
|
||||||
elif self.model_type in SB3_CONTRIB_MODELS:
|
elif self.model_type in SB3_CONTRIB_MODELS:
|
||||||
@ -59,14 +59,30 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.model_type])
|
self.model_type])
|
||||||
self.MODELCLASS = getattr(mod, self.model_type)
|
self.MODELCLASS = getattr(mod, self.model_type)
|
||||||
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
||||||
|
self.unset_outlier_removal()
|
||||||
|
|
||||||
|
def unset_outlier_removal(self):
|
||||||
|
"""
|
||||||
|
If user has activated any function that may remove training points, this
|
||||||
|
function will set them to false and warn them
|
||||||
|
"""
|
||||||
|
if self.ft_params.get('use_SVM_to_remove_outliers', False):
|
||||||
|
self.ft_params.update({'use_SVM_to_remove_outliers': False})
|
||||||
|
logger.warning('User tried to use SVM with RL. Deactivating SVM.')
|
||||||
|
if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
|
||||||
|
self.ft_params.update({'use_SVM_to_remove_outliers': False})
|
||||||
|
logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
|
||||||
|
if self.freqai_info['data_split_parameters'].get('shuffle', False):
|
||||||
|
self.freqai_info['data_split_parameters'].update('shuffle', False)
|
||||||
|
logger.warning('User tried to shuffle training data. Setting shuffle to False')
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||||
for storing, saving, loading, and analyzing the data.
|
for storing, saving, loading, and analyzing the data.
|
||||||
:param unfiltered_dataframe: Full dataframe for the current training period
|
:param unfiltered_df: Full dataframe for the current training period
|
||||||
:param metadata: pair metadata from strategy.
|
:param metadata: pair metadata from strategy.
|
||||||
:returns:
|
:returns:
|
||||||
:model: Trained model which can be used to inference (self.predict)
|
:model: Trained model which can be used to inference (self.predict)
|
||||||
@ -75,7 +91,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
logger.info("--------------------Starting training " f"{pair} --------------------")
|
logger.info("--------------------Starting training " f"{pair} --------------------")
|
||||||
|
|
||||||
features_filtered, labels_filtered = dk.filter_features(
|
features_filtered, labels_filtered = dk.filter_features(
|
||||||
unfiltered_dataframe,
|
unfiltered_df,
|
||||||
dk.training_features_list,
|
dk.training_features_list,
|
||||||
dk.label_list,
|
dk.label_list,
|
||||||
training_filter=True,
|
training_filter=True,
|
||||||
@ -99,7 +115,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
||||||
|
|
||||||
model = self.fit_rl(data_dictionary, dk)
|
model = self.fit(data_dictionary, dk)
|
||||||
|
|
||||||
logger.info(f"--------------------done training {pair}--------------------")
|
logger.info(f"--------------------done training {pair}--------------------")
|
||||||
|
|
||||||
@ -124,7 +140,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
best_model_save_path=str(dk.data_path))
|
best_model_save_path=str(dk.data_path))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
"""
|
"""
|
||||||
Agent customizations and abstract Reinforcement Learning customizations
|
Agent customizations and abstract Reinforcement Learning customizations
|
||||||
go in here. Abstract method, so this function must be overridden by
|
go in here. Abstract method, so this function must be overridden by
|
||||||
@ -142,6 +158,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
# FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
|
# FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
|
||||||
# FIXME: get_rate and trade_udration shouldn't work with backtesting,
|
# FIXME: get_rate and trade_udration shouldn't work with backtesting,
|
||||||
# we need to use candle dates and prices to compute that.
|
# we need to use candle dates and prices to compute that.
|
||||||
|
pytest.set_trace()
|
||||||
current_value = self.strategy.dp._exchange.get_rate(
|
current_value = self.strategy.dp._exchange.get_rate(
|
||||||
pair, refresh=False, side="exit", is_short=trade.is_short)
|
pair, refresh=False, side="exit", is_short=trade.is_short)
|
||||||
openrate = trade.open_rate
|
openrate = trade.open_rate
|
||||||
@ -162,7 +179,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
return market_side, current_profit, int(trade_duration)
|
return market_side, current_profit, int(trade_duration)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
@ -173,9 +190,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dk.find_features(unfiltered_dataframe)
|
dk.find_features(unfiltered_df)
|
||||||
filtered_dataframe, _ = dk.filter_features(
|
filtered_dataframe, _ = dk.filter_features(
|
||||||
unfiltered_dataframe, dk.training_features_list, training_filter=False
|
unfiltered_df, dk.training_features_list, training_filter=False
|
||||||
)
|
)
|
||||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||||
@ -305,8 +322,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
# But FreqaiRL needs more objects passed to fit() (like DK) and we dont want to go refactor
|
||||||
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
# all the other existing fit() functions to include dk argument. For now we instantiate and
|
||||||
# leave it.
|
# leave it.
|
||||||
def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
# def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
|
||||||
return
|
# return
|
||||||
|
|
||||||
|
|
||||||
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
|
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
|
||||||
|
@ -553,7 +553,8 @@ class IFreqaiModel(ABC):
|
|||||||
|
|
||||||
# find the features indicated by strategy and store in datakitchen
|
# find the features indicated by strategy and store in datakitchen
|
||||||
dk.find_features(unfiltered_dataframe)
|
dk.find_features(unfiltered_dataframe)
|
||||||
|
# import pytest
|
||||||
|
# pytest.set_trace()
|
||||||
model = self.train(unfiltered_dataframe, pair, dk)
|
model = self.train(unfiltered_dataframe, pair, dk)
|
||||||
|
|
||||||
self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts
|
self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts
|
||||||
|
@ -18,13 +18,13 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
User created Reinforcement Learning Model prediction model.
|
User created Reinforcement Learning Model prediction model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
|
|
||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||||
|
|
||||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||||
net_arch=[512, 512, 256])
|
net_arch=[128, 128])
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||||
@ -69,8 +69,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
factor = 100
|
factor = 100
|
||||||
|
|
||||||
# reward agent for entering trades
|
# reward agent for entering trades
|
||||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||||
and self._position == Positions.Neutral:
|
and self._position == Positions.Neutral):
|
||||||
return 25
|
return 25
|
||||||
# discourage agent from not entering trades
|
# discourage agent from not entering trades
|
||||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
@ -85,8 +85,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
factor *= 0.5
|
factor *= 0.5
|
||||||
|
|
||||||
# discourage sitting in position
|
# discourage sitting in position
|
||||||
if self._position in (Positions.Short, Positions.Long) and \
|
if (self._position in (Positions.Short, Positions.Long) and
|
||||||
action == Actions.Neutral.value:
|
action == Actions.Neutral.value):
|
||||||
return -1 * trade_duration / max_trade_duration
|
return -1 * trade_duration / max_trade_duration
|
||||||
|
|
||||||
# close long
|
# close long
|
||||||
|
@ -20,14 +20,14 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
|||||||
User created Reinforcement Learning Model prediction model.
|
User created Reinforcement Learning Model prediction model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
|
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
|
|
||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||||
|
|
||||||
# model arch
|
# model arch
|
||||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||||
net_arch=[256, 256, 128])
|
net_arch=[128, 128])
|
||||||
|
|
||||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||||
|
@ -29,15 +29,16 @@ def freqai_conf(default_conf, tmpdir):
|
|||||||
"enabled": True,
|
"enabled": True,
|
||||||
"startup_candles": 10000,
|
"startup_candles": 10000,
|
||||||
"purge_old_models": True,
|
"purge_old_models": True,
|
||||||
"train_period_days": 5,
|
"train_period_days": 2,
|
||||||
"backtest_period_days": 2,
|
"backtest_period_days": 2,
|
||||||
"live_retrain_hours": 0,
|
"live_retrain_hours": 0,
|
||||||
"expiration_hours": 1,
|
"expiration_hours": 1,
|
||||||
"identifier": "uniqe-id100",
|
"identifier": "uniqe-id100",
|
||||||
"live_trained_timestamp": 0,
|
"live_trained_timestamp": 0,
|
||||||
|
"data_kitchen_thread_count": 2,
|
||||||
"feature_parameters": {
|
"feature_parameters": {
|
||||||
"include_timeframes": ["5m"],
|
"include_timeframes": ["5m"],
|
||||||
"include_corr_pairlist": ["ADA/BTC", "DASH/BTC"],
|
"include_corr_pairlist": ["ADA/BTC"],
|
||||||
"label_period_candles": 20,
|
"label_period_candles": 20,
|
||||||
"include_shifted_candles": 1,
|
"include_shifted_candles": 1,
|
||||||
"DI_threshold": 0.9,
|
"DI_threshold": 0.9,
|
||||||
@ -47,7 +48,7 @@ def freqai_conf(default_conf, tmpdir):
|
|||||||
"stratify_training_data": 0,
|
"stratify_training_data": 0,
|
||||||
"indicator_periods_candles": [10],
|
"indicator_periods_candles": [10],
|
||||||
},
|
},
|
||||||
"data_split_parameters": {"test_size": 0.33, "random_state": 1},
|
"data_split_parameters": {"test_size": 0.33, "shuffle": False},
|
||||||
"model_training_parameters": {"n_estimators": 100},
|
"model_training_parameters": {"n_estimators": 100},
|
||||||
},
|
},
|
||||||
"config_files": [Path('config_examples', 'config_freqai.example.json')]
|
"config_files": [Path('config_examples', 'config_freqai.example.json')]
|
||||||
|
@ -90,5 +90,5 @@ def test_use_strategy_to_populate_indicators(mocker, freqai_conf):
|
|||||||
|
|
||||||
df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, 'LTC/BTC')
|
df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, 'LTC/BTC')
|
||||||
|
|
||||||
assert len(df.columns) == 45
|
assert len(df.columns) == 33
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
@ -72,7 +72,7 @@ def test_use_DBSCAN_to_remove_outliers(mocker, freqai_conf, caplog):
|
|||||||
# freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 1})
|
# freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 1})
|
||||||
freqai.dk.use_DBSCAN_to_remove_outliers(predict=False)
|
freqai.dk.use_DBSCAN_to_remove_outliers(predict=False)
|
||||||
assert log_has_re(
|
assert log_has_re(
|
||||||
"DBSCAN found eps of 2.36.",
|
"DBSCAN found eps of 1.75.",
|
||||||
caplog,
|
caplog,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ def test_compute_distances(mocker, freqai_conf):
|
|||||||
freqai = make_data_dictionary(mocker, freqai_conf)
|
freqai = make_data_dictionary(mocker, freqai_conf)
|
||||||
freqai_conf['freqai']['feature_parameters'].update({"DI_threshold": 1})
|
freqai_conf['freqai']['feature_parameters'].update({"DI_threshold": 1})
|
||||||
avg_mean_dist = freqai.dk.compute_distances()
|
avg_mean_dist = freqai.dk.compute_distances()
|
||||||
assert round(avg_mean_dist, 2) == 2.54
|
assert round(avg_mean_dist, 2) == 1.99
|
||||||
|
|
||||||
|
|
||||||
def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf, caplog):
|
def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf, caplog):
|
||||||
@ -89,7 +89,7 @@ def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf,
|
|||||||
freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 0.1})
|
freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 0.1})
|
||||||
freqai.dk.use_SVM_to_remove_outliers(predict=False)
|
freqai.dk.use_SVM_to_remove_outliers(predict=False)
|
||||||
assert log_has_re(
|
assert log_has_re(
|
||||||
"SVM detected 8.09%",
|
"SVM detected 7.36%",
|
||||||
caplog,
|
caplog,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ def test_normalize_data(mocker, freqai_conf):
|
|||||||
freqai = make_data_dictionary(mocker, freqai_conf)
|
freqai = make_data_dictionary(mocker, freqai_conf)
|
||||||
data_dict = freqai.dk.data_dictionary
|
data_dict = freqai.dk.data_dictionary
|
||||||
freqai.dk.normalize_data(data_dict)
|
freqai.dk.normalize_data(data_dict)
|
||||||
assert len(freqai.dk.data) == 56
|
assert len(freqai.dk.data) == 32
|
||||||
|
|
||||||
|
|
||||||
def test_filter_features(mocker, freqai_conf):
|
def test_filter_features(mocker, freqai_conf):
|
||||||
@ -142,7 +142,7 @@ def test_filter_features(mocker, freqai_conf):
|
|||||||
training_filter=True,
|
training_filter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(filtered_df.columns) == 26
|
assert len(filtered_df.columns) == 14
|
||||||
|
|
||||||
|
|
||||||
def test_make_train_test_datasets(mocker, freqai_conf):
|
def test_make_train_test_datasets(mocker, freqai_conf):
|
||||||
|
@ -21,15 +21,40 @@ def is_arm() -> bool:
|
|||||||
'LightGBMRegressor',
|
'LightGBMRegressor',
|
||||||
'XGBoostRegressor',
|
'XGBoostRegressor',
|
||||||
'CatboostRegressor',
|
'CatboostRegressor',
|
||||||
|
'ReinforcementLearner',
|
||||||
|
'ReinforcementLearner_multiproc'
|
||||||
])
|
])
|
||||||
def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model):
|
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
|
||||||
if is_arm() and model == 'CatboostRegressor':
|
if is_arm() and model == 'CatboostRegressor':
|
||||||
pytest.skip("CatBoost is not supported on ARM")
|
pytest.skip("CatBoost is not supported on ARM")
|
||||||
|
|
||||||
|
model_save_ext = 'joblib'
|
||||||
freqai_conf.update({"freqaimodel": model})
|
freqai_conf.update({"freqaimodel": model})
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
freqai_conf.update({"strategy": "freqai_test_strat"})
|
freqai_conf.update({"strategy": "freqai_test_strat"})
|
||||||
|
|
||||||
|
if 'ReinforcementLearner' in model:
|
||||||
|
model_save_ext = 'zip'
|
||||||
|
freqai_conf.update({"strategy": "freqai_rl_test_strat"})
|
||||||
|
freqai_conf["freqai"].update({"model_training_parameters": {
|
||||||
|
"learning_rate": 0.00025,
|
||||||
|
"gamma": 0.9,
|
||||||
|
"verbose": 1
|
||||||
|
}})
|
||||||
|
freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
|
||||||
|
freqai_conf["freqai"]["rl_config"] = {
|
||||||
|
"train_cycles": 1,
|
||||||
|
"thread_count": 2,
|
||||||
|
"max_trade_duration_candles": 300,
|
||||||
|
"model_type": "PPO",
|
||||||
|
"policy_type": "MlpPolicy",
|
||||||
|
"max_training_drawdown_pct": 0.5,
|
||||||
|
"model_reward_parameters": {
|
||||||
|
"rr": 1,
|
||||||
|
"profit_aim": 0.02,
|
||||||
|
"win_reward_factor": 2
|
||||||
|
}}
|
||||||
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
@ -42,16 +67,19 @@ def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model):
|
|||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
freqai.dd.pair_dict = MagicMock()
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
data_load_timerange = TimeRange.parse_timerange("20180125-20180130")
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
new_timerange = TimeRange.parse_timerange("20180127-20180130")
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(
|
freqai.extract_data_and_train_model(
|
||||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||||
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file()
|
assert Path(freqai.dk.data_path /
|
||||||
|
f"{freqai.dk.model_filename}_model.{model_save_ext}").is_file()
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
# if 'ReinforcementLearner' not in model:
|
||||||
|
# assert Path(freqai.dk.data_path /
|
||||||
|
# f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
@ -91,7 +119,7 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model):
|
|||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
||||||
assert len(freqai.dk.data['training_features_list']) == 26
|
assert len(freqai.dk.data['training_features_list']) == 14
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user