add continual retraining feature, handle mypy typing reqs, improve docstrings
@@ -19,6 +19,7 @@ from typing import Callable
 from datetime import datetime, timezone
 from stable_baselines3.common.utils import set_random_seed
 import gym
+from pathlib import Path
 logger = logging.getLogger(__name__)

 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -40,6 +41,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.eval_env: Base5ActionRLEnv = None
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
+        self.rl_config = self.freqai_info['rl_config']
+        self.continual_retraining = self.rl_config['continual_retraining']
         if self.model_type in SB3_MODELS:
             import_str = 'stable_baselines3'
         elif self.model_type in SB3_CONTRIB_MODELS:
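
The new continual_retraining flag is read directly from freqai.rl_config, so the user configuration has to carry it next to model_type. A minimal sketch of the relevant block as a Python dict; only the model_type and continual_retraining keys are confirmed by this diff, the values shown are illustrative:

    # hypothetical excerpt of a freqtrade config; values are examples only
    config = {
        "freqai": {
            "rl_config": {
                "model_type": "PPO",           # must match an entry in SB3_MODELS or SB3_CONTRIB_MODELS
                "continual_retraining": True,  # reuse the previously trained agent on each retrain
            },
        },
    }
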
@@ -68,7 +71,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         logger.info("--------------------Starting training " f"{pair} --------------------")

-        # filter the features requested by user in the configuration file and elegantly handle NaNs
         features_filtered, labels_filtered = dk.filter_features(
             unfiltered_dataframe,
             dk.training_features_list,
@@ -78,19 +80,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
             features_filtered, labels_filtered)
-        dk.fit_labels()  # useless for now, but just satiating append methods
+        dk.fit_labels()  # FIXME useless for now, but just satiating append methods

         # normalize all data based on train_dataset only
         prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
         data_dictionary = dk.normalize_data(data_dictionary)

-        # optional additional data cleaning/analysis
+        # data cleaning/analysis
         self.data_cleaning_train(dk)

         logger.info(
-            f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features"
+            f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
+            f' features and {len(data_dictionary["train_features"])} data points'
         )
-        logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')

         self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)

@@ -100,9 +102,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         return model

-    def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
+    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
+                                        prices_train: DataFrame, prices_test: DataFrame,
+                                        dk: FreqaiDataKitchen):
         """
-        User overrides this as shown here if they are using a custom MyRLEnv
+        User can override this if they are using a custom MyRLEnv
         """
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
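
The retyped signature and reworded docstring mark set_train_and_eval_environments as the intended override point for custom environments. A rough sketch of such an override, assuming a hypothetical MyCustomEnv and relying on the module's existing imports (Monitor, EvalCallback, Path); the eval_freq value is illustrative:

    class MyRLModel(BaseReinforcementLearningModel):

        def set_train_and_eval_environments(self, data_dictionary, prices_train,
                                            prices_test, dk):
            train_df = data_dictionary["train_features"]
            test_df = data_dictionary["test_features"]
            # hypothetical user environment used in place of the bundled MyRLEnv
            self.train_env = MyCustomEnv(df=train_df, prices=prices_train,
                                         window_size=self.CONV_WIDTH,
                                         reward_kwargs=self.reward_params, config=self.config)
            self.eval_env = Monitor(MyCustomEnv(df=test_df, prices=prices_test,
                                                window_size=self.CONV_WIDTH,
                                                reward_kwargs=self.reward_params,
                                                config=self.config),
                                    str(Path(dk.data_path / 'monitor')))
            self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                              render=False, eval_freq=len(train_df),
                                              best_model_save_path=str(dk.data_path))
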
@@ -114,18 +118,22 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                      reward_kwargs=self.reward_params, config=self.config)
             self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
                                             window_size=self.CONV_WIDTH,
-                                            reward_kwargs=self.reward_params, config=self.config), ".")
+                                            reward_kwargs=self.reward_params, config=self.config),
+                                    str(Path(dk.data_path / 'monitor')))
             self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                               render=False, eval_freq=eval_freq,
-                                              best_model_save_path=dk.data_path)
+                                              best_model_save_path=str(dk.data_path))
         else:
             self.train_env.reset()
             self.eval_env.reset()
             self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
             self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
+            # self.eval_callback.eval_env = self.eval_env
+            # self.eval_callback.best_model_save_path = str(dk.data_path)
+            # self.eval_callback._init_callback()
             self.eval_callback.__init__(self.eval_env, deterministic=True,
                                         render=False, eval_freq=eval_freq,
                                         best_model_save_path=dk.data_path)
-                                        best_model_save_path=dk.data_path)
+                                        best_model_save_path=str(dk.data_path))

     @abstractmethod
     def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
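
fit_rl stays abstract, so each agent module supplies the actual training loop. A minimal sketch of what an implementation might look like with stable-baselines3 PPO, wired to the train_env and eval_callback prepared above; the self.model reuse path and the train_cycles key are assumptions used only to illustrate how continual_retraining could be honoured, not part of this commit:

    from stable_baselines3 import PPO

    def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
        train_df = data_dictionary["train_features"]
        # assumed config key; total training length scales with the dataset size
        total_timesteps = self.rl_config.get("train_cycles", 10) * len(train_df)

        if self.continual_retraining and getattr(self, "model", None) is not None:
            # hypothetical reuse path: keep training the previous agent on the fresh env
            model = self.model
            model.set_env(self.train_env)
        else:
            model = PPO("MlpPolicy", self.train_env, verbose=1,
                        tensorboard_log=str(Path(dk.data_path / "tensorboard")))

        model.learn(total_timesteps=total_timesteps, callback=self.eval_callback)
        return model
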
@@ -137,19 +145,20 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         return

-    def get_state_info(self, pair):
+    def get_state_info(self, pair: str):
         open_trades = Trade.get_trades_proxy(is_open=True)
         market_side = 0.5
-        current_profit = 0
+        current_profit: float = 0
         trade_duration = 0
         for trade in open_trades:
             if trade.pair == pair:
+                # FIXME: mypy typing doesnt like that strategy may be "None" (it never will be)
                 current_value = self.strategy.dp._exchange.get_rate(
                     pair, refresh=False, side="exit", is_short=trade.is_short)
                 openrate = trade.open_rate
                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = (now - trade.open_date.timestamp()) / self.base_tf_seconds
-                if 'long' in trade.enter_tag:
+                trade_duration = int((now - trade.open_date.timestamp()) / self.base_tf_seconds)
+                if 'long' in str(trade.enter_tag):
                     market_side = 1
                     current_profit = (current_value - openrate) / openrate
                 else:
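
The duration arithmetic divides elapsed seconds by self.base_tf_seconds, so the new int() cast reports whole base-timeframe candles instead of a float, in line with the commit's mypy cleanup. A quick worked example, assuming a 5m base timeframe:

    base_tf_seconds = 300                            # 5m candles, illustrative
    elapsed = 1560.0                                 # seconds since trade.open_date
    trade_duration = int(elapsed / base_tf_seconds)  # -> 5 candles
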
@@ -245,8 +254,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         return


-def make_env(env_id: str, rank: int, seed: int, train_df, price,
-             reward_params, window_size, monitor=False, config={}) -> Callable:
+def make_env(env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame,
+             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
+             config: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.

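
make_env returning a Callable follows the usual stable-baselines3 pattern for vectorised environments: each rank builds its own environment copy. A hedged usage sketch, assuming dataframes and parameters prepared as in set_train_and_eval_environments; the env id, seed, and worker count are illustrative, and SubprocVecEnv is the standard stable_baselines3 wrapper:

    from stable_baselines3.common.vec_env import SubprocVecEnv

    num_cpu = 4  # illustrative number of parallel workers
    env_fns = [make_env('train_env', i, 42, train_df, prices_train,
                        reward_params, window_size, config=config)
               for i in range(num_cpu)]
    train_env = SubprocVecEnv(env_fns)
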