From 16cec7dfbd51f34c479d842ca023c8cd34aa79a7 Mon Sep 17 00:00:00 2001 From: robcaulk Date: Tue, 16 Aug 2022 12:18:06 +0200 Subject: [PATCH] fix save/reload functionality for stablebaselines --- .../config_reinforcementlearning_example.json | 110 ------------------ freqtrade/freqai/data_drawer.py | 6 +- 2 files changed, 3 insertions(+), 113 deletions(-) delete mode 100644 config_examples/config_reinforcementlearning_example.json diff --git a/config_examples/config_reinforcementlearning_example.json b/config_examples/config_reinforcementlearning_example.json deleted file mode 100644 index 29f088ef3..000000000 --- a/config_examples/config_reinforcementlearning_example.json +++ /dev/null @@ -1,110 +0,0 @@ -{ - "trading_mode": "futures", - "new_pairs_days": 30, - "margin_mode": "isolated", - "max_open_trades": 8, - "stake_currency": "USDT", - "stake_amount": 1000, - "tradable_balance_ratio": 1, - "fiat_display_currency": "USD", - "dry_run": true, - "timeframe": "5m", - "dataformat_ohlcv": "json", - "dry_run_wallet": 12000, - "cancel_open_orders_on_exit": true, - "unfilledtimeout": { - "entry": 10, - "exit": 30 - }, - "exchange": { - "name": "binance", - "key": "", - "secret": "", - "ccxt_config": { - "enableRateLimit": true - }, - "ccxt_async_config": { - "enableRateLimit": true, - "rateLimit": 200 - }, - "pair_whitelist": [ - "1INCH/USDT", - "AAVE/USDT" - ], - "pair_blacklist": [] - }, - "entry_pricing": { - "price_side": "same", - "purge_old_models": true, - "use_order_book": true, - "order_book_top": 1, - "price_last_balance": 0.0, - "check_depth_of_market": { - "enabled": false, - "bids_to_ask_delta": 1 - } - }, - "exit_pricing": { - "price_side": "other", - "use_order_book": true, - "order_book_top": 1 - }, - "pairlists": [ - { - "method": "StaticPairList" - } - ], - "freqai": { - "model_save_type": "stable_baselines", - "conv_width": 10, - "follow_mode": false, - "purge_old_models": true, - "expiration_hours": 1, - "train_period_days": 10, - "backtest_period_days": 2, - "identifier": "test_rl10", - "feature_parameters": { - "include_corr_pairlist": [ - "BTC/USDT", - "ETH/USDT" - ], - "include_timeframes": [ - "15m", - "30m" - ], - "label_period_candles": 80, - "include_shifted_candles": 0, - "DI_threshold": 0, - "weight_factor": 0.9, - "principal_component_analysis": false, - "use_SVM_to_remove_outliers": false, - "svm_params": {"shuffle": true, "nu": 0.1}, - "stratify_training_data": 0, - "indicator_max_period_candles": 10, - "indicator_periods_candles": [5] - }, - "data_split_parameters": { - "test_size": 0.5, - "random_state": 1, - "shuffle": false - }, - "model_training_parameters": { - "n_steps": 2048, - "ent_coef": 0.005, - "learning_rate": 0.000025, - "batch_size": 256, - "eval_cycles" : 5, - "train_cycles" : 15 - }, - "model_reward_parameters": { - "rr": 1, - "profit_aim": 0.01 - } - }, - "bot_name": "RL_test", - "force_entry_enable": true, - "initial_state": "running", - "internals": { - "process_throttle_secs": 5 - } -} \ No newline at end of file diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 68f688ed4..9603fb9ab 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -395,7 +395,7 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif model_type == 'stable_baselines': + elif 'stable_baselines' in model_type: model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -473,10 +473,10 @@ class FreqaiDataDrawer: model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5") elif model_type == 'stable_baselines_ppo': from stable_baselines3.ppo.ppo import PPO - model = PPO.load(dk.data_path / f"{dk.model_filename}_model.zip") + model = PPO.load(dk.data_path / f"{dk.model_filename}_model") elif model_type == 'stable_baselines_dqn': from stable_baselines3 import DQN - model = DQN.load(dk.data_path / f"{dk.model_filename}_model.zip") + model = DQN.load(dk.data_path / f"{dk.model_filename}_model") if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file(): dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")