Merge branch 'develop' into freqai_bt_from_predictions_improvement

This commit is contained in:
Wagner Costa
2022-12-05 18:00:55 -03:00
14 changed files with 54 additions and 30 deletions

View File

@@ -194,12 +194,12 @@ class BaseEnvironment(gym.Env):
if self._position == Positions.Neutral:
return 0.
elif self._position == Positions.Short:
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
return (last_trade_price - current_price) / last_trade_price
elif self._position == Positions.Long:
current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
return (last_trade_price - current_price) / last_trade_price
elif self._position == Positions.Long:
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
return (current_price - last_trade_price) / last_trade_price
else:
return 0.

View File

@@ -64,7 +64,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.policy_type = self.freqai_info['rl_config']['policy_type']
self.unset_outlier_removal()
self.net_arch = self.rl_config.get('net_arch', [128, 128])
self.dd.model_type = "stable_baselines"
self.dd.model_type = import_str
def unset_outlier_removal(self):
"""

View File

@@ -503,7 +503,7 @@ class FreqaiDataDrawer:
dump(model, save_path / f"{dk.model_filename}_model.joblib")
elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in self.model_type:
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None:
@@ -589,9 +589,9 @@ class FreqaiDataDrawer:
elif self.model_type == 'keras':
from tensorflow import keras
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
elif self.model_type == 'stable_baselines':
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
mod = importlib.import_module(
'stable_baselines3', self.freqai_info['rl_config']['model_type'])
self.model_type, self.freqai_info['rl_config']['model_type'])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")

View File

@@ -7,6 +7,8 @@ import logging
import sys
from typing import Any, List
from freqtrade.util.gc_setup import gc_set_threshold
# check min. python version
if sys.version_info < (3, 8): # pragma: no cover
@@ -36,6 +38,7 @@ def main(sysargv: List[str] = None) -> None:
# Call subcommand.
if 'func' in args:
logger.info(f'freqtrade {__version__}')
gc_set_threshold()
return_code = args['func'](args)
else:
# No subcommand was issued.

View File

@@ -0,0 +1,18 @@
import gc
import logging
import platform
logger = logging.getLogger(__name__)
def gc_set_threshold():
"""
Reduce number of GC runs to improve performance (explanation video)
https://www.youtube.com/watch?v=p4Sn6UcFTOU
"""
if platform.python_implementation() == "CPython":
# allocs, g1, g2 = gc.get_threshold()
gc.set_threshold(50_000, 500, 1000)
logger.debug("Adjusting python allocations to reduce GC runs")