improve typing, improve docstrings, ensure global tests pass

This commit is contained in:
robcaulk
2022-09-23 19:17:27 +02:00
parent 9c361f4422
commit 77c360b264
7 changed files with 124 additions and 40 deletions

View File

@@ -19,7 +19,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
"""
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
"""
User customizable fit method
:params:
data_dictionary: dict = common data dictionary containing all train/test
features/labels/weights.
dk: FreqaiDataKitchen = data kitchen for the current pair.
:returns:
model: Any = trained model to be used for inference in dry/live/backtesting
"""
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
@@ -59,7 +67,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
"""
def calculate_reward(self, action):
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:params:
action: int = The action made by the agent for the current candle.
:returns:
float = the reward to give to the agent for current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict # , Tuple
import torch as th
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
make_env)
@@ -55,11 +55,18 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
return model
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
If user has particular environment configuration needs, they can do that by
overriding this function. In the present case, the user wants to setup training
environments for multiple workers.
User can override this if they are using a custom MyRLEnv
:params:
data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training or testing
dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
@@ -79,4 +86,4 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
in range(num_cpu)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=dk.data_path)
best_model_save_path=str(dk.data_path))