improve typing, improve docstrings, ensure global tests pass
This commit is contained in:
parent
9c361f4422
commit
77c360b264
@ -25,6 +25,17 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
@ -92,7 +103,6 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def is_tradesignal(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
@ -107,7 +117,6 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
(action == Actions.Long_enter.value and self._position == Positions.Short))
|
||||
|
||||
def _is_valid(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
|
@ -60,6 +60,17 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
return self._get_observation()
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
|
@ -43,6 +43,10 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
|
||||
reward_kwargs: dict, starting_point=True):
|
||||
"""
|
||||
Resets the environment when the agent fails (in our case, if the drawdown
|
||||
exceeds the user set max_training_drawdown_pct)
|
||||
"""
|
||||
self.df = df
|
||||
self.signal_features = self.df
|
||||
self.prices = prices
|
||||
@ -133,13 +137,18 @@ class BaseEnvironment(gym.Env):
|
||||
return features_and_state
|
||||
|
||||
def get_trade_duration(self):
|
||||
"""
|
||||
Get the trade duration if the agent is in a trade
|
||||
"""
|
||||
if self._last_trade_tick is None:
|
||||
return 0
|
||||
else:
|
||||
return self._current_tick - self._last_trade_tick
|
||||
|
||||
def get_unrealized_profit(self):
|
||||
|
||||
"""
|
||||
Get the unrealized profit if the agent is in a trade
|
||||
"""
|
||||
if self._last_trade_tick is None:
|
||||
return 0.
|
||||
|
||||
@ -158,7 +167,6 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
@abstractmethod
|
||||
def is_tradesignal(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is a trade signal. This is
|
||||
unique to the actions in the environment, and therefore must be
|
||||
@ -167,7 +175,6 @@ class BaseEnvironment(gym.Env):
|
||||
return
|
||||
|
||||
def _is_valid(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is valid.This is
|
||||
unique to the actions in the environment, and therefore must be
|
||||
@ -191,8 +198,13 @@ class BaseEnvironment(gym.Env):
|
||||
@abstractmethod
|
||||
def calculate_reward(self, action):
|
||||
"""
|
||||
Reward is created by BaseReinforcementLearningModel and can
|
||||
be inherited/edited by the user made ReinforcementLearner file.
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:params:
|
||||
action: int = The action made by the agent for the current candle.
|
||||
:returns:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
|
||||
def _update_unrealized_total_profit(self):
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -19,8 +19,9 @@ from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
from freqtrade.freqai.RL.BaseEnvironment import Positions
|
||||
from freqtrade.persistence import Trade
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -33,15 +34,15 @@ SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
|
||||
|
||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
User created Reinforcement Learning Model prediction class
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(config=kwargs['config'])
|
||||
th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
|
||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
self.train_env: BaseEnvironment = None
|
||||
self.eval_env: BaseEnvironment = None
|
||||
self.train_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_callback: EvalCallback = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
@ -126,6 +127,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
:params:
|
||||
data_dictionary: dict = common data dictionary containing train and test
|
||||
features/labels/weights.
|
||||
prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
|
||||
environment during training
|
||||
or testing
|
||||
dk: FreqaiDataKitchen = the datakitchen for the current pair
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
@ -148,15 +156,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
return
|
||||
|
||||
def get_state_info(self, pair: str):
|
||||
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
|
||||
"""
|
||||
State info during dry/live/backtesting which is fed back
|
||||
into the model.
|
||||
:param:
|
||||
pair: str = COIN/STAKE to get the environment information for
|
||||
:returns:
|
||||
market_side: float = representing short, long, or neutral for
|
||||
pair
|
||||
trade_duration: int = the number of candles that the trade has
|
||||
been open for
|
||||
"""
|
||||
open_trades = Trade.get_trades_proxy(is_open=True)
|
||||
market_side = 0.5
|
||||
current_profit: float = 0
|
||||
trade_duration = 0
|
||||
for trade in open_trades:
|
||||
if trade.pair == pair:
|
||||
# FIXME: get_rate and trade_udration shouldn't work with backtesting,
|
||||
# we need to use candle dates and prices to compute that.
|
||||
if self.strategy.dp._exchange is None: # type: ignore
|
||||
logger.error('No exchange available.')
|
||||
else:
|
||||
@ -172,11 +189,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
market_side = 0
|
||||
current_profit = (openrate - current_value) / openrate
|
||||
|
||||
# total_profit = 0
|
||||
# closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
|
||||
# for trade in closed_trades:
|
||||
# total_profit += trade.close_profit
|
||||
|
||||
return market_side, current_profit, int(trade_duration)
|
||||
|
||||
def predict(
|
||||
@ -209,7 +221,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
def rl_model_predict(self, dataframe: DataFrame,
|
||||
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
|
||||
|
||||
"""
|
||||
A helper function to make predictions in the Reinforcement learning module.
|
||||
:params:
|
||||
dataframe: DataFrame = the dataframe of features to make the predictions on
|
||||
dk: FreqaiDatakitchen = data kitchen for the current pair
|
||||
model: Any = the trained model used to inference the features.
|
||||
"""
|
||||
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
||||
|
||||
def _predict(window):
|
||||
@ -274,26 +292,37 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action):
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:params:
|
||||
action: int = The action made by the agent for the current candle.
|
||||
:returns:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
rew = np.sign(pnl) * (pnl + 1)
|
||||
factor = 100
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||
and self._position == Positions.Neutral:
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
if self._last_trade_tick:
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
else:
|
||||
trade_duration = 0
|
||||
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor *= 1.5
|
||||
@ -301,8 +330,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and \
|
||||
action == Actions.Neutral.value:
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
@ -320,7 +349,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
return 0.
|
||||
|
||||
|
||||
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
|
@ -19,7 +19,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
|
||||
"""
|
||||
User customizable fit method
|
||||
:params:
|
||||
data_dictionary: dict = common data dictionary containing all train/test
|
||||
features/labels/weights.
|
||||
dk: FreqaiDatakitchen = data kitchen for current pair.
|
||||
:returns:
|
||||
model: Any = trained model to be used for inference in dry/live/backtesting
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||
|
||||
@ -59,7 +67,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action):
|
||||
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:params:
|
||||
action: int = The action made by the agent for the current candle.
|
||||
:returns:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict # , Tuple
|
||||
import torch as th
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from pandas import DataFrame
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
|
||||
make_env)
|
||||
@ -55,11 +55,18 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
||||
|
||||
return model
|
||||
|
||||
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
|
||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
|
||||
prices_train: DataFrame, prices_test: DataFrame,
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
If user has particular environment configuration needs, they can do that by
|
||||
overriding this function. In the present case, the user wants to setup training
|
||||
environments for multiple workers.
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
:params:
|
||||
data_dictionary: dict = common data dictionary containing train and test
|
||||
features/labels/weights.
|
||||
prices_train/test: DataFrame = dataframe comprised of the prices to be used in
|
||||
the environment during training
|
||||
or testing
|
||||
dk: FreqaiDataKitchen = the datakitchen for the current pair
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
@ -79,4 +86,4 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
|
||||
in range(num_cpu)])
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=dk.data_path)
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
@ -244,7 +244,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
|
||||
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
|
||||
|
||||
assert len(model_folders) == num_files
|
||||
|
||||
Trade.use_db = True
|
||||
shutil.rmtree(Path(freqai.dk.full_path))
|
||||
|
||||
|
||||
@ -297,7 +297,7 @@ def test_start_backtesting_from_existing_folder(mocker, freqai_conf, caplog):
|
||||
|
||||
assert len(model_folders) == 6
|
||||
|
||||
# without deleting the exiting folder structure, re-run
|
||||
# without deleting the existing folder structure, re-run
|
||||
|
||||
freqai_conf.update({"timerange": "20180120-20180130"})
|
||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||
|
Loading…
Reference in New Issue
Block a user