improve typing, improve docstrings, ensure global tests pass

This commit is contained in:
robcaulk
2022-09-23 19:17:27 +02:00
parent 9c361f4422
commit 77c360b264
7 changed files with 124 additions and 40 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict # , Tuple
import torch as th
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
make_env)
@@ -55,11 +55,18 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
return model
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
If user has particular environment configuration needs, they can do that by
overriding this function. In the present case, the user wants to setup training
environments for multiple workers.
User can override this if they are using a custom MyRLEnv
:params:
data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training
or testing
dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
@@ -79,4 +86,4 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
in range(num_cpu)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=dk.data_path)
best_model_save_path=str(dk.data_path))