stable/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
2022-12-14 22:03:05 +03:00

62 lines
2.9 KiB
Python

import logging
from typing import Any, Dict
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.enums import RunMode
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
logger = logging.getLogger(__name__)
class ReinforcementLearner_multiproc(ReinforcementLearner):
"""
Demonstration of how to build vectorized environments
"""
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
User can override this if they are using a custom MyRLEnv
:param data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
:param prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training
or testing
:param dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
env_info = {"live": False}
if self.data_provider:
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
env_info["fee"] = self.data_provider._exchange \
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
env_id = "train_env"
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config, env_info=env_info) for i
in range(self.max_threads)])
eval_env_id = 'eval_env'
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
test_df, prices_test,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config, env_info=env_info) for i
in range(self.max_threads)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
actions = self.train_env.env_method("get_actions")[0]
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)