add continual retraining feature, handle mypy typing reqs, improve docstrings

commit c0cee5df07
parent b708134c1a
Author: robcaulk
Date:   2022-08-24 12:54:02 +02:00

11 changed files with 387 additions and 362 deletions


@@ -19,6 +19,7 @@ from typing import Callable
 from datetime import datetime, timezone
 from stable_baselines3.common.utils import set_random_seed
 import gym
+from pathlib import Path

 logger = logging.getLogger(__name__)

 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -40,6 +41,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.eval_env: Base5ActionRLEnv = None
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
+        self.rl_config = self.freqai_info['rl_config']
+        self.continual_retraining = self.rl_config['continual_retraining']
         if self.model_type in SB3_MODELS:
             import_str = 'stable_baselines3'
         elif self.model_type in SB3_CONTRIB_MODELS:
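Note: the new continual_retraining flag is read straight from rl_config, so it is expected to be present in the user's freqai configuration. A minimal sketch of where it would sit, written as a Python dict for illustration (only "rl_config", "model_type" and "continual_retraining" appear in this diff; the surrounding structure and the "PPO" value are assumptions):

    # Hypothetical excerpt of a freqai configuration, expressed as a Python dict.
    freqai_config = {
        "rl_config": {
            "model_type": "PPO",           # assumed value; must map to SB3_MODELS or SB3_CONTRIB_MODELS
            "continual_retraining": True,  # the new flag read in __init__ above
        },
    }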
@@ -68,7 +71,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         logger.info("--------------------Starting training " f"{pair} --------------------")

         # filter the features requested by user in the configuration file and elegantly handle NaNs
         features_filtered, labels_filtered = dk.filter_features(
             unfiltered_dataframe,
             dk.training_features_list,
@@ -78,19 +80,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
             features_filtered, labels_filtered)
-        dk.fit_labels()  # useless for now, but just satiating append methods
+        dk.fit_labels()  # FIXME useless for now, but just satiating append methods

         # normalize all data based on train_dataset only
         prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
         data_dictionary = dk.normalize_data(data_dictionary)

-        # optional additional data cleaning/analysis
+        # data cleaning/analysis
         self.data_cleaning_train(dk)

         logger.info(
-            f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
+            f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
+            f' features and {len(data_dictionary["train_features"])} data points'
         )
-        logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')

         self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
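The flag is only stored here; concrete fit_rl() implementations decide what continual retraining means in practice. A rough sketch of the intended pattern, assuming the previous model is cached per pair and reused when the flag is set (PPO, self.saved_models, the timestep count and dk.pair usage are placeholders, not taken from this diff):

    # Hypothetical concrete model honoring continual_retraining; names below are illustrative.
    from stable_baselines3 import PPO

    class MyContinualRLModel(BaseReinforcementLearningModel):

        def fit_rl(self, data_dictionary, dk):
            if self.continual_retraining and dk.pair in getattr(self, "saved_models", {}):
                model = self.saved_models[dk.pair]   # reuse weights from the previous retrain
                model.set_env(self.train_env)        # point the old model at the freshly built env
            else:
                model = PPO("MlpPolicy", self.train_env, verbose=0)
            model.learn(total_timesteps=10_000, callback=self.eval_callback)
            return model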
@@ -100,9 +102,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return model

-    def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
+    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
+                                        prices_train: DataFrame, prices_test: DataFrame,
+                                        dk: FreqaiDataKitchen):
         """
-        User overrides this as shown here if they are using a custom MyRLEnv
+        User can override this if they are using a custom MyRLEnv
         """
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
@@ -114,18 +118,22 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                      reward_kwargs=self.reward_params, config=self.config)
             self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
                                             window_size=self.CONV_WIDTH,
-                                            reward_kwargs=self.reward_params, config=self.config), ".")
+                                            reward_kwargs=self.reward_params, config=self.config),
+                                    str(Path(dk.data_path / 'monitor')))
             self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                               render=False, eval_freq=eval_freq,
-                                              best_model_save_path=dk.data_path)
+                                              best_model_save_path=str(dk.data_path))
         else:
             self.train_env.reset()
             self.eval_env.reset()
             self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
             self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
+            # self.eval_callback.eval_env = self.eval_env
+            # self.eval_callback.best_model_save_path = str(dk.data_path)
+            # self.eval_callback._init_callback()
             self.eval_callback.__init__(self.eval_env, deterministic=True,
                                         render=False, eval_freq=eval_freq,
-                                        best_model_save_path=dk.data_path)
+                                        best_model_save_path=str(dk.data_path))

     @abstractmethod
     def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
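As the updated docstring says, set_train_and_eval_environments() is meant to be overridable when a custom environment is used. A minimal sketch of such an override, mirroring the base implementation above but swapping in a user-defined environment class (MyCustomEnv, the subclass name and the eval_freq choice are assumptions; the continual-retraining reset branch is omitted for brevity):

    # Hypothetical user override; MyCustomEnv stands in for the user's own Base5ActionRLEnv subclass.
    from stable_baselines3.common.callbacks import EvalCallback
    from stable_baselines3.common.monitor import Monitor

    class MyRLModel(BaseReinforcementLearningModel):

        def set_train_and_eval_environments(self, data_dictionary, prices_train,
                                            prices_test, dk):
            train_df = data_dictionary["train_features"]
            test_df = data_dictionary["test_features"]
            eval_freq = len(train_df)  # assumption: evaluate once per pass over the training data

            self.train_env = MyCustomEnv(df=train_df, prices=prices_train,
                                         window_size=self.CONV_WIDTH,
                                         reward_kwargs=self.reward_params, config=self.config)
            self.eval_env = Monitor(MyCustomEnv(df=test_df, prices=prices_test,
                                                window_size=self.CONV_WIDTH,
                                                reward_kwargs=self.reward_params,
                                                config=self.config),
                                    str(dk.data_path / 'monitor'))
            self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                              render=False, eval_freq=eval_freq,
                                              best_model_save_path=str(dk.data_path))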
@@ -137,19 +145,20 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return

-    def get_state_info(self, pair):
+    def get_state_info(self, pair: str):
         open_trades = Trade.get_trades_proxy(is_open=True)
         market_side = 0.5
-        current_profit = 0
+        current_profit: float = 0
         trade_duration = 0
         for trade in open_trades:
             if trade.pair == pair:
+                # FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
                 current_value = self.strategy.dp._exchange.get_rate(
                     pair, refresh=False, side="exit", is_short=trade.is_short)
                 openrate = trade.open_rate
                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = (now - trade.open_date.timestamp()) / self.base_tf_seconds
-                if 'long' in trade.enter_tag:
+                trade_duration = int((now - trade.open_date.timestamp()) / self.base_tf_seconds)
+                if 'long' in str(trade.enter_tag):
                     market_side = 1
                     current_profit = (current_value - openrate) / openrate
                 else:
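The int() cast means trade_duration is now a whole number of base-timeframe candles rather than a float. A quick worked example, assuming a 5m base timeframe (base_tf_seconds = 300) and a trade opened one hour ago:

    # Worked example of the trade_duration calculation above; timestamps are illustrative.
    base_tf_seconds = 300              # 5m candles
    now = 1_661_335_200.0              # arbitrary "current" unix timestamp
    open_ts = now - 3600               # trade opened one hour earlier
    trade_duration = int((now - open_ts) / base_tf_seconds)
    assert trade_duration == 12        # 3600 s / 300 s = 12 candles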
@@ -245,8 +254,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return


-def make_env(env_id: str, rank: int, seed: int, train_df, price,
-             reward_params, window_size, monitor=False, config={}) -> Callable:
+def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
+             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
+             config: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.
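For context, make_env() returns a Callable in the usual stable_baselines3 style, so it can be handed to a vectorized environment wrapper. A sketch of how it might be consumed (the env id, seed, CPU count, window size and dataframes are placeholders):

    # Hypothetical usage of make_env with a SubprocVecEnv; all concrete values are illustrative.
    from stable_baselines3.common.vec_env import SubprocVecEnv

    num_cpu = 4
    train_env = SubprocVecEnv([make_env('train_env', i, 1, train_df, prices_train,
                                        reward_params, window_size=10)
                               for i in range(num_cpu)])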