Merge branch 'freqtrade:develop' into develop
This commit is contained in:
135
freqtrade/freqai/RL/Base4ActionRLEnv.py
Normal file
135
freqtrade/freqai/RL/Base4ActionRLEnv.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
    """Discrete actions the agent can choose from in the 4-action environment."""

    # Leave the current state untouched.
    Neutral = 0
    # Close the open position (long or short) and realize its profit.
    Exit = 1
    # Open a long position.
    Long_enter = 2
    # Open a short position.
    Short_enter = 3
|
||||
|
||||
|
||||
class Base4ActionRLEnv(BaseEnvironment):
    """
    Base class for a 4 action environment.

    On every candle the agent picks one member of `Actions` (Neutral, Exit,
    Long_enter or Short_enter). A position is only opened from the neutral
    state and is closed with the single `Exit` action.
    """

    def set_action_space(self):
        """Create a discrete action space with one entry per `Actions` member."""
        self.action_space = spaces.Discrete(len(Actions))

    def step(self, action: int):
        """
        Logic for a single step (incrementing one candle in time)
        by the agent
        :param: action: int = the action type that the agent plans
            to take for the current step.
        :returns:
            observation = current state of environment
            step_reward = the reward from `calculate_reward()`
            _done = if the agent "died" or if the candles finished
            info = dict passed back to openai gym lib
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        self._update_unrealized_total_profit()

        # The reward is computed before the action changes the position below.
        step_reward = self.calculate_reward(action)
        self.total_reward += step_reward

        trade_type = None
        if self.is_tradesignal(action):
            # Transitions that reach this point with `is_tradesignal()` as
            # defined in this class:
            #   Exit,        position Long/Short -> realize profit, go Neutral
            #   Long_enter,  position Neutral    -> open Long
            #   Short_enter, position Neutral    -> open Short
            # The Neutral branch is kept for subclasses that override
            # `is_tradesignal()`.
            if action == Actions.Neutral.value:
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            elif action == Actions.Long_enter.value:
                self._position = Positions.Long
                trade_type = "long"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Short_enter.value:
                self._position = Positions.Short
                trade_type = "short"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Exit.value:
                # Realize the profit while the position is still set.
                self._update_total_profit()
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            else:
                # Reached for action values that are not members of `Actions`.
                logger.warning("case not defined")

            if trade_type is not None:
                self.trade_history.append(
                    {'price': self.current_price(), 'index': self._current_tick,
                     'type': trade_type})

        # End the episode when either realized or unrealized profit drops
        # below the drawdown floor set in `BaseEnvironment.__init__`, the
        # same rule the 5-action environment applies.
        if (self._total_profit < self.max_drawdown or
                self._total_unrealized_profit < self.max_drawdown):
            self._done = True

        self._position_history.append(self._position)

        info = dict(
            tick=self._current_tick,
            total_reward=self.total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )

        observation = self._get_observation()

        self._update_history(info)

        return observation, step_reward, self._done, info

    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal, i.e. whether applying
        `action` in the current position changes the position.
        e.g.: Actions.Exit while in Positions.Long is a trade signal,
        Actions.Long_enter while in Positions.Short is not.
        Action values that are not members of `Actions` return True so that
        `step()` reports them.
        """
        in_trade = (self._position == Positions.Short or
                    self._position == Positions.Long)
        is_flat = self._position == Positions.Neutral

        if action == Actions.Neutral.value:
            # Doing nothing never trades, whatever the position.
            return not (is_flat or in_trade)
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            # Entries are ignored while a position is already open.
            return not in_trade
        if action == Actions.Exit.value:
            # There is nothing to exit from the neutral state.
            return not is_flat
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid.
        e.g.: agent wants a Actions.Exit while it is in Positions.Neutral,
        or an enter action while it already holds a position.
        """
        # Agent should only try to exit if it is in position
        if action == Actions.Exit.value:
            if self._position not in (Positions.Short, Positions.Long):
                return False

        # Agent should only try to enter if it is not in position
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            if self._position != Positions.Neutral:
                return False

        return True
|
145
freqtrade/freqai/RL/Base5ActionRLEnv.py
Normal file
145
freqtrade/freqai/RL/Base5ActionRLEnv.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
    """Discrete actions the agent can choose from in the 5-action environment."""

    # Leave the current state untouched.
    Neutral = 0
    # Open a long position.
    Long_enter = 1
    # Close an open long position.
    Long_exit = 2
    # Open a short position.
    Short_enter = 3
    # Close an open short position.
    Short_exit = 4
|
||||
|
||||
|
||||
class Base5ActionRLEnv(BaseEnvironment):
    """
    Base class for a 5 action environment.

    On every candle the agent picks one member of `Actions` (Neutral,
    Long_enter, Long_exit, Short_enter or Short_exit). A position is only
    opened from the neutral state and is closed with the exit action of the
    matching side.
    """

    def set_action_space(self):
        """Create a discrete action space with one entry per `Actions` member."""
        self.action_space = spaces.Discrete(len(Actions))

    def step(self, action: int):
        """
        Logic for a single step (incrementing one candle in time)
        by the agent
        :param: action: int = the action type that the agent plans
            to take for the current step.
        :returns:
            observation = current state of environment
            step_reward = the reward from `calculate_reward()`
            _done = if the agent "died" or if the candles finished
            info = dict passed back to openai gym lib
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        self._update_unrealized_total_profit()

        # The reward is computed before the action changes the position below.
        step_reward = self.calculate_reward(action)
        self.total_reward += step_reward

        trade_type = None
        if self.is_tradesignal(action):
            # Transitions that reach this point with `is_tradesignal()` as
            # defined in this class:
            #   Long_enter,  position Neutral -> open Long
            #   Short_enter, position Neutral -> open Short
            #   Long_exit,   position Long    -> realize profit, go Neutral
            #   Short_exit,  position Short   -> realize profit, go Neutral
            # The Neutral branch is kept for subclasses that override
            # `is_tradesignal()`.
            if action == Actions.Neutral.value:
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            elif action == Actions.Long_enter.value:
                self._position = Positions.Long
                trade_type = "long"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Short_enter.value:
                self._position = Positions.Short
                trade_type = "short"
                self._last_trade_tick = self._current_tick
            elif action in (Actions.Long_exit.value, Actions.Short_exit.value):
                # Both exits are handled identically: realize the profit
                # while the position is still set, then go neutral.
                self._update_total_profit()
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            else:
                # Reached for action values that are not members of `Actions`.
                logger.warning("case not defined")

            if trade_type is not None:
                self.trade_history.append(
                    {'price': self.current_price(), 'index': self._current_tick,
                     'type': trade_type})

        # End the episode when either realized or unrealized profit drops
        # below the drawdown floor set in `BaseEnvironment.__init__`.
        if (self._total_profit < self.max_drawdown or
                self._total_unrealized_profit < self.max_drawdown):
            self._done = True

        self._position_history.append(self._position)

        info = dict(
            tick=self._current_tick,
            total_reward=self.total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )

        observation = self._get_observation()

        self._update_history(info)

        return observation, step_reward, self._done, info

    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal, i.e. whether applying
        `action` in the current position changes the position.
        e.g.: Actions.Long_exit while in Positions.Long is a trade signal,
        Actions.Long_exit while in Positions.Short is not.
        Action values that are not members of `Actions` return True so that
        `step()` reports them.
        """
        long_held = self._position == Positions.Long
        short_held = self._position == Positions.Short
        is_flat = self._position == Positions.Neutral

        if action == Actions.Neutral.value:
            # Doing nothing never trades, whatever the position.
            return not (is_flat or short_held or long_held)
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            # Entries are ignored while a position is already open.
            return not (short_held or long_held)
        if action == Actions.Short_exit.value:
            # A short exit is ignored while long or neutral.
            return not (long_held or is_flat)
        if action == Actions.Long_exit.value:
            # A long exit is ignored while short or neutral.
            return not (short_held or is_flat)
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid.
        e.g.: agent wants an exit action while it is in Positions.Neutral,
        or an enter action while it already holds a position.
        Note that an exit of the wrong side (Long_exit while short) is still
        reported as valid by this check.
        """
        # Agent should only try to exit if it is in position
        if action in (Actions.Short_exit.value, Actions.Long_exit.value):
            if self._position not in (Positions.Short, Positions.Long):
                return False

        # Agent should only try to enter if it is not in position
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            if self._position != Positions.Neutral:
                return False

        return True
|
307
freqtrade/freqai/RL/BaseEnvironment.py
Normal file
307
freqtrade/freqai/RL/BaseEnvironment.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import logging
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Positions(Enum):
    """
    Market position held by the agent. The member values are numeric so
    they can be written into the observation and info dict as-is.
    """

    Short = 0
    Long = 1
    Neutral = 0.5

    def opposite(self):
        """
        Return Short when called on Long. Every other member (Short and
        Neutral) maps to Long.
        """
        if self == Positions.Long:
            return Positions.Short
        return Positions.Long
|
||||
|
||||
|
||||
class BaseEnvironment(gym.Env):
    """
    Base class for environments. This class is agnostic to action count.
    Inherited classes customize this to include varying action counts/types,
    See RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py
    """

    def __init__(self, df: Optional[DataFrame] = None, prices: Optional[DataFrame] = None,
                 reward_kwargs: Optional[dict] = None, window_size=10, starting_point=True,
                 id: str = 'baseenv-1', seed: int = 1, config: Optional[dict] = None,
                 dp: Optional[DataProvider] = None):
        """
        Initializes the training/eval environment.
        :param df: dataframe of features (an empty DataFrame when omitted)
        :param prices: dataframe of prices to be used in the training environment
            (an empty DataFrame when omitted)
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`
            (an empty dict when omitted)
        :param starting_point: start at edge of window or not
        :param id: string id of the environment (used in backend for multiprocessed env)
        :param seed: Sets the seed of the environment higher in the gym.Env object
        :param config: Typical user configuration file (an empty dict when omitted)
        :param dp: dataprovider from freqtrade
        """
        # Build the defaults per call: these objects are stored on the
        # instance, so shared default objects would leak state between
        # environments.
        df = DataFrame() if df is None else df
        prices = DataFrame() if prices is None else prices
        reward_kwargs = {} if reward_kwargs is None else reward_kwargs
        config = {} if config is None else config

        self.config = config
        self.rl_config = config['freqai']['rl_config']
        self.add_state_info = self.rl_config.get('add_state_info', False)
        self.id = id
        self.seed(seed)
        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
        # Profit floor: `step()` ends the episode once profit falls below it.
        self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
        self.compound_trades = config['stake_amount'] == 'unlimited'

        # Fee precedence: explicit config value, then the exchange fee for
        # the first whitelisted pair, then a fixed fallback.
        if self.config.get('fee', None) is not None:
            self.fee = self.config['fee']
        elif dp is not None:
            self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore
        else:
            self.fee = 0.0015

    def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                  reward_kwargs: dict, starting_point=True):
        """
        Resets the environment when the agent fails (in our case, if the drawdown
        exceeds the user set max_training_drawdown_pct)
        :param df: dataframe of features
        :param prices: dataframe of prices to be used in the training environment
        :param window_size: size of window (temporal) to pass to the agent
        :param reward_kwargs: extra config settings assigned by user in `rl_config`;
            must contain the keys "rr" and "profit_aim"
        :param starting_point: start at edge of window or not
        """
        self.df = df
        self.signal_features = self.df
        self.prices = prices
        self.window_size = window_size
        self.starting_point = starting_point
        self.rr = reward_kwargs["rr"]
        self.profit_aim = reward_kwargs["profit_aim"]

        # spaces
        # Three extra columns are appended by `_get_observation()` when
        # state info is enabled.
        if self.add_state_info:
            self.total_features = self.signal_features.shape[1] + 3
        else:
            self.total_features = self.signal_features.shape[1]
        self.shape = (window_size, self.total_features)
        self.set_action_space()
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick: int = self.window_size
        self._end_tick: int = len(self.prices) - 1
        self._done: bool = False
        self._current_tick: int = self._start_tick
        self._last_trade_tick: Optional[int] = None
        self._position = Positions.Neutral
        self._position_history: list = [None]
        self.total_reward: float = 0
        self._total_profit: float = 1
        self._total_unrealized_profit: float = 1
        self.history: dict = {}
        self.trade_history: list = []

    @abstractmethod
    def set_action_space(self):
        """
        Unique to the environment action count. Must be inherited.
        """

    def seed(self, seed: int = 1):
        """
        Create the environment's random generator from `seed`.
        :return: single-element list with the seed reported by gym's seeding helper
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """
        Start a new episode: clear all per-episode state and return the
        first observation.
        """
        self._done = False

        # Reset the position before building the history, so the history
        # starts from Neutral and not from the previous episode's final
        # position.
        self._position = Positions.Neutral

        if self.starting_point is True:
            if self.rl_config.get('randomize_starting_position', False):
                length_of_data = int(self._end_tick / 4)
                self._start_tick = random.randint(self.window_size + 1, length_of_data)
            # Rebuilt on every reset, also without randomization, so the
            # history does not keep growing across episodes.
            self._position_history = (self._start_tick * [None]) + [self._position]
        else:
            self._position_history = (self.window_size * [None]) + [self._position]

        self._current_tick = self._start_tick
        self._last_trade_tick = None

        self.total_reward = 0.
        self._total_profit = 1.  # unit
        self.history = {}
        self.trade_history = []
        self.portfolio_log_returns = np.zeros(len(self.prices))

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
        self._total_unrealized_profit = 1

        return self._get_observation()

    @abstractmethod
    def step(self, action: int):
        """
        Step depends on action types, this must be inherited.
        """
        return

    def _get_observation(self):
        """
        This may or may not be independent of action types, user can inherit
        this in their custom "MyRLEnv"
        :return: the last `window_size` rows of features before the current
            tick, plus the columns 'current_profit_pct', 'position' and
            'trade_duration' when `add_state_info` is enabled
        """
        features_window = self.signal_features[(
            self._current_tick - self.window_size):self._current_tick]
        if self.add_state_info:
            features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                           columns=['current_profit_pct',
                                                    'position',
                                                    'trade_duration'],
                                           index=features_window.index)

            # The same scalar state is broadcast to every row of the window.
            features_and_state['current_profit_pct'] = self.get_unrealized_profit()
            features_and_state['position'] = self._position.value
            features_and_state['trade_duration'] = self.get_trade_duration()
            features_and_state = pd.concat([features_window, features_and_state], axis=1)
            return features_and_state
        else:
            return features_window

    def get_trade_duration(self):
        """
        Get the trade duration if the agent is in a trade
        :return: ticks elapsed since the last entry, 0 when no trade is open
        """
        if self._last_trade_tick is None:
            return 0
        else:
            return self._current_tick - self._last_trade_tick

    def get_unrealized_profit(self):
        """
        Get the unrealized profit if the agent is in a trade
        :return: fee-adjusted profit ratio between the open price at entry
            and the open price at the current tick, 0. when no trade is open
        """
        if self._last_trade_tick is None:
            return 0.

        if self._position == Positions.Neutral:
            return 0.
        elif self._position == Positions.Short:
            current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
            return (last_trade_price - current_price) / last_trade_price
        elif self._position == Positions.Long:
            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
            return (current_price - last_trade_price) / last_trade_price
        else:
            return 0.

    @abstractmethod
    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid. This is
        unique to the actions in the environment, and therefore must be
        inherited.
        """
        return True

    def add_entry_fee(self, price):
        """Return `price` increased by the fee."""
        return price * (1 + self.fee)

    def add_exit_fee(self, price):
        """Return `price` reduced by the fee."""
        return price / (1 + self.fee)

    def _update_history(self, info):
        """Append each value of `info` to the per-key lists in `self.history`."""
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)

    @abstractmethod
    def calculate_reward(self, action: int) -> float:
        """
        An example reward function. This is the one function that users will likely
        wish to inject their own creativity into.
        :param action: int = The action made by the agent for the current candle.
        :return:
        float = the reward to give to the agent for current step (used for optimization
            of weights in NN)
        """

    def _update_unrealized_total_profit(self):
        """
        Update the unrealized total profit in case of episode end.
        Only has an effect while a long or short position is open.
        """
        if self._position in (Positions.Long, Positions.Short):
            pnl = self.get_unrealized_profit()
            if self.compound_trades:
                # assumes unit stake and compounding
                unrl_profit = self._total_profit * (1 + pnl)
            else:
                # assumes unit stake and no compounding
                unrl_profit = self._total_profit + pnl
            self._total_unrealized_profit = unrl_profit

    def _update_total_profit(self):
        """
        Fold the unrealized profit of the open trade into the realized total
        profit. Callers invoke this before changing the position.
        """
        pnl = self.get_unrealized_profit()
        if self.compound_trades:
            # assumes unit stake and compounding
            self._total_profit = self._total_profit * (1 + pnl)
        else:
            # assumes unit stake and no compounding
            self._total_profit += pnl

    def current_price(self) -> float:
        """Return the open price of the candle at the current tick."""
        return self.prices.iloc[self._current_tick].open
|
||||
|
||||
# Keeping around in case we want to start building more complex environment
|
||||
# templates in the future.
|
||||
# def most_recent_return(self):
|
||||
# """
|
||||
# Calculate the tick to tick return if in a trade.
|
||||
# Return is generated from rising prices in Long
|
||||
# and falling prices in Short positions.
|
||||
# The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
||||
# """
|
||||
# # Long positions
|
||||
# if self._position == Positions.Long:
|
||||
# current_price = self.prices.iloc[self._current_tick].open
|
||||
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||
|
||||
# if (self._position_history[self._current_tick - 1] == Positions.Short
|
||||
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||
# previous_price = self.add_entry_fee(previous_price)
|
||||
|
||||
# return np.log(current_price) - np.log(previous_price)
|
||||
|
||||
# # Short positions
|
||||
# if self._position == Positions.Short:
|
||||
# current_price = self.prices.iloc[self._current_tick].open
|
||||
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||
# if (self._position_history[self._current_tick - 1] == Positions.Long
|
||||
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||
# previous_price = self.add_exit_fee(previous_price)
|
||||
|
||||
# return np.log(previous_price) - np.log(current_price)
|
||||
|
||||
# return 0
|
||||
|
||||
# def update_portfolio_log_returns(self, action):
|
||||
# self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
|
400
freqtrade/freqai/RL/BaseReinforcementLearningModel.py
Normal file
400
freqtrade/freqai/RL/BaseReinforcementLearningModel.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import importlib
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pandas as pd
|
||||
import torch as th
|
||||
import torch.multiprocessing
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import Positions
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Select the 'file_system' strategy for sharing tensors between processes.
torch.multiprocessing.set_sharing_strategy('file_system')

# Model class names that are looked up in the `stable_baselines3` package.
SB3_MODELS = [
    'PPO',
    'A2C',
    'DQN',
]
# Model class names that are looked up in the `sb3_contrib` package.
SB3_CONTRIB_MODELS = [
    'TRPO',
    'ARS',
    'RecurrentPPO',
    'MaskablePPO',
]
|
||||
|
||||
|
||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction class
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """
    Configure the reinforcement learning model from the user configuration.
    :param kwargs: must contain `config`, which is forwarded to the parent
        constructor.
    :raises OperationalException: if the configured `model_type` is listed in
        neither SB3_MODELS nor SB3_CONTRIB_MODELS.
    """
    super().__init__(config=kwargs['config'])
    self.rl_config = self.freqai_info['rl_config']
    # Use the user-requested `cpu_count` (default 1), capped at half of the
    # system threads but never below one.
    self.max_threads = min(self.rl_config.get('cpu_count', 1),
                           max(int(self.max_system_threads / 2), 1))
    th.set_num_threads(self.max_threads)
    self.reward_params = self.rl_config['model_reward_parameters']
    # The environments and the callback stay unset until
    # `set_train_and_eval_environments()` creates them.
    self.train_env: Optional[Union[SubprocVecEnv, gym.Env]] = None
    self.eval_env: Optional[Union[SubprocVecEnv, gym.Env]] = None
    self.eval_callback: Optional[EvalCallback] = None
    self.model_type = self.rl_config['model_type']
    self.continual_learning = self.freqai_info.get('continual_learning', False)

    # Pick the package that provides the requested model class.
    if self.model_type in SB3_MODELS:
        import_str = 'stable_baselines3'
    elif self.model_type in SB3_CONTRIB_MODELS:
        import_str = 'sb3_contrib'
    else:
        raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
                                   f'sb3_contrib. please choose one of {SB3_MODELS} or '
                                   f'{SB3_CONTRIB_MODELS}')

    # `import_str` is an absolute module name, so no `package` anchor is
    # passed (that argument is only used to resolve relative names).
    mod = importlib.import_module(import_str)
    self.MODELCLASS = getattr(mod, self.model_type)
    self.policy_type = self.rl_config['policy_type']
    self.unset_outlier_removal()
    self.net_arch = self.rl_config.get('net_arch', [128, 128])
    self.dd.model_type = import_str
|
||||
|
||||
def unset_outlier_removal(self):
    """
    Switch off every user option that may remove or reorder training
    points (SVM / DBSCAN outlier removal and data shuffling) and log a
    warning for each option that had been enabled.
    """
    outlier_switches = (
        ('use_SVM_to_remove_outliers',
         'User tried to use SVM with RL. Deactivating SVM.'),
        ('use_DBSCAN_to_remove_outliers',
         'User tried to use DBSCAN with RL. Deactivating DBSCAN.'),
    )
    for option, warning in outlier_switches:
        if self.ft_params.get(option, False):
            self.ft_params.update({option: False})
            logger.warning(warning)

    split_params = self.freqai_info['data_split_parameters']
    if split_params.get('shuffle', False):
        split_params.update({'shuffle': False})
        logger.warning('User tried to shuffle training data. Setting shuffle to False')
|
||||
|
||||
def train(
    self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
) -> Any:
    """
    Filter the training data and train a model to it. Train makes heavy use of the datakitchen
    for storing, saving, loading, and analyzing the data.
    :param unfiltered_df: Full dataframe for the current training period
    :param pair: pair the model is trained for
    :param dk: datakitchen for the current pair
    :returns:
    :model: Trained model which can be used to inference (self.predict)
    """

    logger.info("--------------------Starting training " f"{pair} --------------------")

    features, labels = dk.filter_features(
        unfiltered_df,
        dk.training_features_list,
        dk.label_list,
        training_filter=True,
    )

    data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(features, labels)
    dk.fit_labels()  # FIXME useless for now, but just satiating append methods

    # Prices are read from the datakitchen's dictionary before the
    # normalization step below runs
    # (presumably so the environments get raw prices - TODO confirm).
    prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
    # normalize all data based on train_dataset only
    data_dictionary = dk.normalize_data(data_dictionary)

    # data cleaning/analysis
    self.data_cleaning_train(dk)

    feature_count = len(dk.data_dictionary["train_features"].columns)
    point_count = len(data_dictionary["train_features"])
    logger.info(
        f'Training model on {feature_count}'
        f' features and {point_count} data points'
    )

    self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)

    model = self.fit(data_dictionary, dk)

    logger.info(f"--------------------done training {pair}--------------------")

    return model
|
||||
|
||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
                                    prices_train: DataFrame, prices_test: DataFrame,
                                    dk: FreqaiDataKitchen):
    """
    User can override this if they are using a custom MyRLEnv
    :param data_dictionary: dict = common data dictionary containing train and test
        features/labels/weights.
    :param prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
        environment during training or testing
    :param dk: FreqaiDataKitchen = the datakitchen for the current pair
    """
    train_df = data_dictionary["train_features"]
    test_df = data_dictionary["test_features"]

    # Settings shared by the training and the evaluation environment.
    shared_env_kwargs = dict(
        window_size=self.CONV_WIDTH,
        reward_kwargs=self.reward_params,
        config=self.config,
        dp=self.data_provider,
    )

    self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **shared_env_kwargs)
    # The evaluation environment is wrapped in a Monitor before it is
    # handed to the evaluation callback.
    self.eval_env = Monitor(
        self.MyRLEnv(df=test_df, prices=prices_test, **shared_env_kwargs))
    # Evaluate once per pass over the training data; the best model is
    # saved under the datakitchen's data path.
    self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                      render=False, eval_freq=len(train_df),
                                      best_model_save_path=str(dk.data_path))
|
||||
|
||||
@abstractmethod
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
    """
    Hook for the agent setup and the reinforcement learning customizations.
    This is an abstract method, so the user class has to override it; the
    base implementation does nothing and returns None.
    :param data_dictionary: common data dictionary passed on by `train()`
    :param dk: the datakitchen for the current pair
    """
    return None
|
||||
|
||||
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
    """
    State info during dry/live (not backtesting) which is fed back
    into the model.
    :param pair: str = COIN/STAKE to get the environment information for
    :return:
    :market_side: float = representing short (0), long (1), or neutral (0.5)
        for pair
    :current_profit: float = unrealized profit of the current trade
    :trade_duration: int = the number of candles that the trade has
        been open for
    If no exchange is available, the neutral state (0.5, 0, 0) is returned.
    """
    open_trades = Trade.get_trades_proxy(is_open=True)
    # Defaults describe "no open trade for this pair".
    market_side = 0.5
    current_profit: float = 0
    trade_duration = 0
    for trade in open_trades:
        if trade.pair == pair:
            if self.data_provider._exchange is None:  # type: ignore
                logger.error('No exchange available.')
                # Report the neutral side: 0 would be read as a short
                # position.
                return 0.5, 0, 0
            else:
                current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                    pair, refresh=False, side="exit", is_short=trade.is_short)

            # Trade age in seconds, expressed in base-timeframe candles.
            now = datetime.now(timezone.utc).timestamp()
            trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
            current_profit = trade.calc_profit_ratio(current_rate)
            if trade.is_short:
                market_side = 0
            else:
                market_side = 1

    return market_side, current_profit, int(trade_duration)
|
||||
|
||||
def predict(
    self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
    """
    Filter the prediction features data and predict with it.
    :param unfiltered_df: Full dataframe for the current backtest period.
    :param dk: datakitchen for the current pair
    :return:
    :pred_df: dataframe containing the predictions
    :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
    data (NaNs) or felt uncertain about data (PCA and DI index)
    """

    dk.find_features(unfiltered_df)
    prediction_features, _ = dk.filter_features(
        unfiltered_df, dk.training_features_list, training_filter=False
    )
    dk.data_dictionary["prediction_features"] = dk.normalize_data_from_metadata(
        prediction_features)

    # optional additional data cleaning/analysis
    self.data_cleaning_predict(dk)

    # The features are read back from the datakitchen after the cleaning
    # step above has run.
    pred_df = self.rl_model_predict(
        dk.data_dictionary["prediction_features"], dk, self.model)
    # Missing predictions are replaced by 0.
    pred_df.fillna(0, inplace=True)

    return (pred_df, dk.do_predict)
|
||||
|
||||
def rl_model_predict(self, dataframe: DataFrame,
                     dk: FreqaiDataKitchen, model: Any) -> DataFrame:
    """
    A helper function to make predictions in the Reinforcement learning module.
    One prediction is made per rolling window of `self.CONV_WIDTH` feature rows.
    :param dataframe: DataFrame = the dataframe of features to make the predictions on
    :param dk: FreqaiDatakitchen = data kitchen for the current pair
    :param model: Any = the trained model used to inference the features.
    :return: DataFrame with the columns of `dk.label_list`; rows that do not yet
             have a full window are NaN.
    """

    def _act_on_window(window):
        # The rolling values themselves are zeros; only the window's index is
        # used, to pick the matching feature rows by position.
        window_features = dataframe.iloc[window.index]
        if self.live and self.rl_config.get('add_state_info', False):
            side, profit, duration = self.get_state_info(dk.pair)
            # column order matters: it defines the layout of the observation
            window_features['current_profit_pct'] = profit
            window_features['position'] = side
            window_features['trade_duration'] = duration
        action, _state = model.predict(window_features, deterministic=True)
        return action

    # zero-filled carrier frame with a default 0..n-1 index, one row per feature row
    carrier = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)

    return carrier.rolling(window=self.CONV_WIDTH).apply(_act_on_window)
|
||||
|
||||
def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                           DataFrame]:
    """
    Builds the train prices and test prices for the environment.
    :param data_dictionary: dict holding the "train_features" and "test_features"
                            dataframes
    :param pair: pair whose raw price columns are extracted (any ':' is removed
                 before the column names are built)
    :param dk: data kitchen for the current pair (not used here)
    :return: (prices_train, prices_test), each with the columns
             open / low / high / close and a fresh 0..n-1 index
    :raises OperationalException: if none of the raw price columns are present in
                                  the train features
    """

    pair = pair.replace(':', '')
    train_df = data_dictionary["train_features"]
    test_df = data_dictionary["test_features"]

    # price data for model training and evaluation
    tf = self.config['timeframe']
    # raw feature column -> plain OHLC name. The target for the high column used
    # to be ' high' (leading space), which made prices["high"] unavailable.
    rename_dict = {f'%-{pair}raw_{column}_{tf}': column
                   for column in ('open', 'low', 'high', 'close')}
    ohlc_list = list(rename_dict)

    prices_train = train_df.filter(ohlc_list, axis=1)
    if prices_train.empty:
        raise OperationalException('Reinforcement learning module didnt find the raw prices '
                                   'assigned in populate_any_indicators. Please assign them '
                                   'with:\n'
                                   'informative[f"%-{pair}raw_close"] = informative["close"]\n'
                                   'informative[f"%-{pair}raw_open"] = informative["open"]\n'
                                   'informative[f"%-{pair}raw_high"] = informative["high"]\n'
                                   'informative[f"%-{pair}raw_low"] = informative["low"]\n')
    # reset_index returns a new frame; the result has to be kept, otherwise the
    # reset is silently lost
    prices_train = prices_train.rename(columns=rename_dict).reset_index(drop=True)

    prices_test = test_df.filter(ohlc_list, axis=1)
    prices_test = prices_test.rename(columns=rename_dict).reset_index(drop=True)

    return prices_train, prices_test
|
||||
|
||||
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
    """
    Can be used by user if they are trying to limit_ram_usage *and*
    perform continual learning.
    For now, this is unused.
    :param dk: data kitchen providing `data_path` and `model_filename`
    :return: the model loaded through `self.MODELCLASS.load`, or None when no
             saved model file exists
    """
    # The data drawer saves these models as "<model_filename>_model.zip", so the
    # existence check has to include the suffix; without it the file is never found.
    model_path = Path(dk.data_path / f"{dk.model_filename}_model.zip")
    if not model_path.is_file():
        logger.info('No model file on disk to continue learning from.')
        # explicit None instead of falling through to an unbound local
        return None

    return self.MODELCLASS.load(model_path)
|
||||
|
||||
def _on_stop(self):
|
||||
"""
|
||||
Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
|
||||
"""
|
||||
|
||||
if self.train_env:
|
||||
self.train_env.close()
|
||||
|
||||
if self.eval_env:
|
||||
self.eval_env.close()
|
||||
|
||||
# Nested class which can be overridden by user to customize further
|
||||
class MyRLEnv(Base5ActionRLEnv):
    """
    User can override any function in BaseRLEnv and gym.Env. Here the user
    sets a custom reward based on profit and trade duration.
    """

    def calculate_reward(self, action: int) -> float:
        """
        An example reward function. This is the one function that users will likely
        wish to inject their own creativity into.
        :param action: int = The action made by the agent for the current candle.
        :return:
        float = the reward to give to the agent for current step (used for optimization
            of weights in NN)
        """
        # first, penalize if the action is not valid
        if not self._is_valid(action):
            return -2

        pnl = self.get_unrealized_profit()
        factor = 100.

        if self._position == Positions.Neutral:
            # reward agent for entering trades
            if action in (Actions.Long_enter.value, Actions.Short_enter.value):
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value:
                return -1

        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
        # candles elapsed since the last trade tick; 0 when no trade tick is set
        trade_duration = 0
        if self._last_trade_tick:
            trade_duration = self._current_tick - self._last_trade_tick

        # scale the reward up while within the maximum duration, down beyond it
        if trade_duration > max_trade_duration:
            factor *= 0.5
        elif trade_duration <= max_trade_duration:
            factor *= 1.5

        # discourage sitting in position
        holding = self._position in (Positions.Short, Positions.Long)
        if holding and action == Actions.Neutral.value:
            return -1 * trade_duration / max_trade_duration

        # close long / close short: both exits are rewarded the same way
        if ((action == Actions.Long_exit.value and self._position == Positions.Long)
                or (action == Actions.Short_exit.value
                    and self._position == Positions.Short)):
            if pnl > self.profit_aim * self.rr:
                factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
            return float(pnl * factor)

        return 0.
|
||||
|
||||
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
             seed: int, train_df: DataFrame, price: DataFrame,
             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
             config: Dict[str, Any] = {}) -> Callable:
    """
    Utility function for multiprocessed env.

    :param MyRLEnv: environment class to instantiate
    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess; added to `seed` for this env
    :param seed: (int) the initial seed for RNG
    :param train_df: features dataframe handed to the environment as `df`
    :param price: price dataframe handed to the environment as `prices`
    :param reward_params: handed to the environment as `reward_kwargs`
    :param window_size: (int) handed to the environment unchanged
    :param monitor: (bool) wrap the environment in `Monitor` when True
    :param config: configuration dict handed to the environment (the shared
                   default dict is only passed through by this function)
    :return: (Callable) factory that builds the environment when called
    """
    # seeding happens once, when the factory is created, not on each _init call
    set_random_seed(seed)

    def _init() -> gym.Env:
        created = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                          reward_kwargs=reward_params, id=env_id, seed=seed + rank,
                          config=config)
        return Monitor(created) if monitor else created

    return _init
|
0
freqtrade/freqai/RL/__init__.py
Normal file
0
freqtrade/freqai/RL/__init__.py
Normal file
@@ -1,9 +1,10 @@
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, TypedDict
|
||||
|
||||
@@ -81,6 +82,7 @@ class FreqaiDataDrawer:
|
||||
self.historic_predictions_bkp_path = Path(
|
||||
self.full_path / "historic_predictions.backup.pkl")
|
||||
self.pair_dictionary_path = Path(self.full_path / "pair_dictionary.json")
|
||||
self.global_metadata_path = Path(self.full_path / "global_metadata.json")
|
||||
self.metric_tracker_path = Path(self.full_path / "metric_tracker.json")
|
||||
self.follow_mode = follow_mode
|
||||
if follow_mode:
|
||||
@@ -98,6 +100,7 @@ class FreqaiDataDrawer:
|
||||
self.empty_pair_dict: pair_info = {
|
||||
"model_filename": "", "trained_timestamp": 0,
|
||||
"data_path": "", "extras": {}}
|
||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
|
||||
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
||||
"""
|
||||
@@ -125,6 +128,17 @@ class FreqaiDataDrawer:
|
||||
self.update_metric_tracker('cpu_load5min', load5 / cpus, pair)
|
||||
self.update_metric_tracker('cpu_load15min', load15 / cpus, pair)
|
||||
|
||||
def load_global_metadata_from_disk(self):
    """
    Locate and load a previously saved global metadata in present model folder.
    :return: the parsed contents of `self.global_metadata_path`, or an empty
             dict when that file does not exist
    """
    if not self.global_metadata_path.is_file():
        return {}

    with open(self.global_metadata_path, "r") as metadata_file:
        return rapidjson.load(metadata_file, number_mode=rapidjson.NM_NATIVE)
|
||||
|
||||
def load_drawer_from_disk(self):
|
||||
"""
|
||||
Locate and load a previously saved data drawer full of all pair model metadata in
|
||||
@@ -225,6 +239,15 @@ class FreqaiDataDrawer:
|
||||
rapidjson.dump(self.follower_dict, fp, default=self.np_encoder,
|
||||
number_mode=rapidjson.NM_NATIVE)
|
||||
|
||||
def save_global_metadata_to_disk(self, metadata: Dict[str, Any]):
    """
    Save global metadata json to disk
    :param metadata: dict written to `self.global_metadata_path`, replacing any
                     existing file content
    """
    # hold the save lock for the whole write; the file is opened only after
    # the lock has been acquired
    with self.save_lock, open(self.global_metadata_path, 'w') as metadata_file:
        rapidjson.dump(metadata, metadata_file, default=self.np_encoder,
                       number_mode=rapidjson.NM_NATIVE)
|
||||
|
||||
def create_follower_dict(self):
|
||||
"""
|
||||
Create or dictionary for each follower to maintain unique persistent prediction targets
|
||||
@@ -476,10 +499,12 @@ class FreqaiDataDrawer:
|
||||
save_path = Path(dk.data_path)
|
||||
|
||||
# Save the trained model
|
||||
if not dk.keras:
|
||||
if self.model_type == 'joblib':
|
||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif self.model_type == 'keras':
|
||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if dk.svm_model is not None:
|
||||
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@@ -506,11 +531,10 @@ class FreqaiDataDrawer:
|
||||
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
|
||||
)
|
||||
|
||||
# if self.live:
|
||||
# store as much in ram as possible to increase performance
|
||||
self.model_dictionary[coin] = model
|
||||
self.pair_dict[coin]["model_filename"] = dk.model_filename
|
||||
self.pair_dict[coin]["data_path"] = str(dk.data_path)
|
||||
|
||||
if coin not in self.meta_data_dictionary:
|
||||
self.meta_data_dictionary[coin] = {}
|
||||
self.meta_data_dictionary[coin]["train_df"] = dk.data_dictionary["train_features"]
|
||||
@@ -542,14 +566,6 @@ class FreqaiDataDrawer:
|
||||
if dk.live:
|
||||
dk.model_filename = self.pair_dict[coin]["model_filename"]
|
||||
dk.data_path = Path(self.pair_dict[coin]["data_path"])
|
||||
if self.freqai_info.get("follow_mode", False):
|
||||
# follower can be on a different system which is rsynced from the leader:
|
||||
dk.data_path = Path(
|
||||
self.config["user_data_dir"]
|
||||
/ "models"
|
||||
/ dk.data_path.parts[-2]
|
||||
/ dk.data_path.parts[-1]
|
||||
)
|
||||
|
||||
if coin in self.meta_data_dictionary:
|
||||
dk.data = self.meta_data_dictionary[coin]["meta_data"]
|
||||
@@ -568,12 +584,16 @@ class FreqaiDataDrawer:
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary:
|
||||
model = self.model_dictionary[coin]
|
||||
elif not dk.keras:
|
||||
elif self.model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif self.model_type == 'keras':
|
||||
from tensorflow import keras
|
||||
|
||||
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
|
||||
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
|
||||
mod = importlib.import_module(
|
||||
self.model_type, self.freqai_info['rl_config']['model_type'])
|
||||
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
|
||||
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
|
||||
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@@ -583,6 +603,10 @@ class FreqaiDataDrawer:
|
||||
f"Unable to load model, ensure model exists at " f"{dk.data_path} "
|
||||
)
|
||||
|
||||
# load it into ram if it was loaded from disk
|
||||
if coin not in self.model_dictionary:
|
||||
self.model_dictionary[coin] = model
|
||||
|
||||
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
||||
dk.pca = cloudpickle.load(
|
||||
open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "rb")
|
||||
@@ -693,3 +717,31 @@ class FreqaiDataDrawer:
|
||||
).reset_index(drop=True)
|
||||
|
||||
return corr_dataframes, base_dataframes
|
||||
|
||||
def get_timerange_from_live_historic_predictions(self) -> TimeRange:
    """
    Returns timerange information based on historic predictions file
    :return: timerange calculated from saved live data. It starts at the
             "start_dry_live_date" stored in the global metadata and stops one
             day after the latest `date_pred` found across all pairs.
    :raises OperationalException: if the historic predictions file does not exist
    """
    if not self.historic_predictions_path.is_file():
        raise OperationalException(
            'Historic predictions not found. Historic predictions data is required '
            'to run backtest with the freqai-backtest-live-models option '
        )

    self.load_historic_predictions_from_disk()

    all_pairs_end_dates = [
        self.historic_predictions[pair].date_pred.max()
        for pair in self.historic_predictions
    ]

    global_metadata = self.load_global_metadata_from_disk()
    # Build a timezone-aware datetime: a naive local datetime converted back with
    # .timestamp() can shift around DST transitions.
    start_date = datetime.fromtimestamp(
        int(global_metadata["start_dry_live_date"]), tz=timezone.utc)
    # add 1 day to string timerange to ensure BT module will load all dataframe data
    end_date = max(all_pairs_end_dates) + timedelta(days=1)

    backtesting_timerange = TimeRange(
        'date', 'date', int(start_date.timestamp()), int(end_date.timestamp())
    )
    return backtesting_timerange
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
from math import cos, sin
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from pandas import DataFrame
|
||||
from scipy import stats
|
||||
from sklearn import linear_model
|
||||
@@ -86,12 +87,7 @@ class FreqaiDataKitchen:
|
||||
if not self.live:
|
||||
self.full_path = self.get_full_models_path(self.config)
|
||||
|
||||
if self.backtest_live_models:
|
||||
if self.pair:
|
||||
self.set_timerange_from_ready_models()
|
||||
(self.training_timeranges,
|
||||
self.backtesting_timeranges) = self.split_timerange_live_models()
|
||||
else:
|
||||
if not self.backtest_live_models:
|
||||
self.full_timerange = self.create_fulltimerange(
|
||||
self.config["timerange"], self.freqai_config.get("train_period_days", 0)
|
||||
)
|
||||
@@ -102,7 +98,10 @@ class FreqaiDataKitchen:
|
||||
)
|
||||
|
||||
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
|
||||
self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
|
||||
if not self.freqai_config.get("data_kitchen_thread_count", 0):
|
||||
self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||
else:
|
||||
self.thread_count = self.freqai_config["data_kitchen_thread_count"]
|
||||
self.train_dates: DataFrame = pd.DataFrame()
|
||||
self.unique_classes: Dict[str, list] = {}
|
||||
self.unique_class_list: list = []
|
||||
@@ -456,29 +455,6 @@ class FreqaiDataKitchen:
|
||||
# print(tr_training_list, tr_backtesting_list)
|
||||
return tr_training_list_timerange, tr_backtesting_list_timerange
|
||||
|
||||
def split_timerange_live_models(
|
||||
self
|
||||
) -> Tuple[list, list]:
|
||||
|
||||
tr_backtesting_list_timerange = []
|
||||
asset = self.pair.split("/")[0]
|
||||
if asset not in self.backtest_live_models_data["assets_end_dates"]:
|
||||
raise OperationalException(
|
||||
f"Model not available for pair {self.pair}. "
|
||||
"Please, try again after removing this pair from the configuration file."
|
||||
)
|
||||
asset_data = self.backtest_live_models_data["assets_end_dates"][asset]
|
||||
backtesting_timerange = self.backtest_live_models_data["backtesting_timerange"]
|
||||
model_end_dates = [x for x in asset_data]
|
||||
model_end_dates.append(backtesting_timerange.stopts)
|
||||
model_end_dates.sort()
|
||||
for index, item in enumerate(model_end_dates):
|
||||
if len(model_end_dates) > (index + 1):
|
||||
tr_to_add = TimeRange("date", "date", item, model_end_dates[index + 1])
|
||||
tr_backtesting_list_timerange.append(tr_to_add)
|
||||
|
||||
return tr_backtesting_list_timerange, tr_backtesting_list_timerange
|
||||
|
||||
def slice_dataframe(self, timerange: TimeRange, df: DataFrame) -> DataFrame:
|
||||
"""
|
||||
Given a full dataframe, extract the user desired window
|
||||
@@ -988,7 +964,8 @@ class FreqaiDataKitchen:
|
||||
return weights
|
||||
|
||||
def get_predictions_to_append(self, predictions: DataFrame,
|
||||
do_predict: npt.ArrayLike) -> DataFrame:
|
||||
do_predict: npt.ArrayLike,
|
||||
dataframe_backtest: DataFrame) -> DataFrame:
|
||||
"""
|
||||
Get backtest prediction from current backtest period
|
||||
"""
|
||||
@@ -1010,7 +987,9 @@ class FreqaiDataKitchen:
|
||||
if self.freqai_config["feature_parameters"].get("DI_threshold", 0) > 0:
|
||||
append_df["DI_values"] = self.DI_values
|
||||
|
||||
return append_df
|
||||
dataframe_backtest.reset_index(drop=True, inplace=True)
|
||||
merged_df = pd.concat([dataframe_backtest["date"], append_df], axis=1)
|
||||
return merged_df
|
||||
|
||||
def append_predictions(self, append_df: DataFrame) -> None:
|
||||
"""
|
||||
@@ -1020,23 +999,18 @@ class FreqaiDataKitchen:
|
||||
if self.full_df.empty:
|
||||
self.full_df = append_df
|
||||
else:
|
||||
self.full_df = pd.concat([self.full_df, append_df], axis=0)
|
||||
self.full_df = pd.concat([self.full_df, append_df], axis=0, ignore_index=True)
|
||||
|
||||
def fill_predictions(self, dataframe):
|
||||
"""
|
||||
Back fill values to before the backtesting range so that the dataframe matches size
|
||||
when it goes back to the strategy. These rows are not included in the backtest.
|
||||
"""
|
||||
|
||||
len_filler = len(dataframe) - len(self.full_df.index) # startup_candle_count
|
||||
filler_df = pd.DataFrame(
|
||||
np.zeros((len_filler, len(self.full_df.columns))), columns=self.full_df.columns
|
||||
)
|
||||
|
||||
self.full_df = pd.concat([filler_df, self.full_df], axis=0, ignore_index=True)
|
||||
|
||||
to_keep = [col for col in dataframe.columns if not col.startswith("&")]
|
||||
self.return_dataframe = pd.concat([dataframe[to_keep], self.full_df], axis=1)
|
||||
self.return_dataframe = pd.merge(dataframe[to_keep],
|
||||
self.full_df, how='left', on='date')
|
||||
self.return_dataframe[self.full_df.columns] = (
|
||||
self.return_dataframe[self.full_df.columns].fillna(value=0))
|
||||
self.full_df = DataFrame()
|
||||
|
||||
return
|
||||
@@ -1334,22 +1308,22 @@ class FreqaiDataKitchen:
|
||||
self, append_df: DataFrame
|
||||
) -> None:
|
||||
"""
|
||||
Save prediction dataframe from backtesting to h5 file format
|
||||
Save prediction dataframe from backtesting to feather file format
|
||||
:param append_df: dataframe for backtesting period
|
||||
"""
|
||||
full_predictions_folder = Path(self.full_path / self.backtest_predictions_folder)
|
||||
if not full_predictions_folder.is_dir():
|
||||
full_predictions_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
append_df.to_hdf(self.backtesting_results_path, key='append_df', mode='w')
|
||||
append_df.to_feather(self.backtesting_results_path)
|
||||
|
||||
def get_backtesting_prediction(
|
||||
self
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Get prediction dataframe from h5 file format
|
||||
Get prediction dataframe from feather file format
|
||||
"""
|
||||
append_df = pd.read_hdf(self.backtesting_results_path)
|
||||
append_df = pd.read_feather(self.backtesting_results_path)
|
||||
return append_df
|
||||
|
||||
def check_if_backtest_prediction_is_valid(
|
||||
@@ -1365,19 +1339,20 @@ class FreqaiDataKitchen:
|
||||
"""
|
||||
path_to_predictionfile = Path(self.full_path /
|
||||
self.backtest_predictions_folder /
|
||||
f"{self.model_filename}_prediction.h5")
|
||||
f"{self.model_filename}_prediction.feather")
|
||||
self.backtesting_results_path = path_to_predictionfile
|
||||
|
||||
file_exists = path_to_predictionfile.is_file()
|
||||
|
||||
if file_exists:
|
||||
append_df = self.get_backtesting_prediction()
|
||||
if len(append_df) == len_backtest_df:
|
||||
if len(append_df) == len_backtest_df and 'date' in append_df:
|
||||
logger.info(f"Found backtesting prediction file at {path_to_predictionfile}")
|
||||
return True
|
||||
else:
|
||||
logger.info("A new backtesting prediction file is required. "
|
||||
"(Number of predictions is different from dataframe length).")
|
||||
"(Number of predictions is different from dataframe length or "
|
||||
"old prediction file version).")
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1385,17 +1360,6 @@ class FreqaiDataKitchen:
|
||||
)
|
||||
return False
|
||||
|
||||
def set_timerange_from_ready_models(self):
|
||||
backtesting_timerange, \
|
||||
assets_end_dates = (
|
||||
self.get_timerange_and_assets_end_dates_from_ready_models(self.full_path))
|
||||
|
||||
self.backtest_live_models_data = {
|
||||
"backtesting_timerange": backtesting_timerange,
|
||||
"assets_end_dates": assets_end_dates
|
||||
}
|
||||
return
|
||||
|
||||
def get_full_models_path(self, config: Config) -> Path:
|
||||
"""
|
||||
Returns default FreqAI model path
|
||||
@@ -1406,88 +1370,6 @@ class FreqaiDataKitchen:
|
||||
config["user_data_dir"] / "models" / str(freqai_config.get("identifier"))
|
||||
)
|
||||
|
||||
def get_timerange_and_assets_end_dates_from_ready_models(
|
||||
self, models_path: Path) -> Tuple[TimeRange, Dict[str, Any]]:
|
||||
"""
|
||||
Returns timerange information based on a FreqAI model directory
|
||||
:param models_path: FreqAI model path
|
||||
|
||||
:return: a Tuple with (Timerange calculated from directory and
|
||||
a Dict with pair and model end training dates info)
|
||||
"""
|
||||
all_models_end_dates = []
|
||||
assets_end_dates: Dict[str, Any] = self.get_assets_timestamps_training_from_ready_models(
|
||||
models_path)
|
||||
for key in assets_end_dates:
|
||||
for model_end_date in assets_end_dates[key]:
|
||||
if model_end_date not in all_models_end_dates:
|
||||
all_models_end_dates.append(model_end_date)
|
||||
|
||||
if len(all_models_end_dates) == 0:
|
||||
raise OperationalException(
|
||||
'At least 1 saved model is required to '
|
||||
'run backtest with the freqai-backtest-live-models option'
|
||||
)
|
||||
|
||||
if len(all_models_end_dates) == 1:
|
||||
logger.warning(
|
||||
"Only 1 model was found. Backtesting will run with the "
|
||||
"timerange from the end of the training date to the current date"
|
||||
)
|
||||
|
||||
finish_timestamp = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
if len(all_models_end_dates) > 1:
|
||||
# After last model end date, use the same period from previous model
|
||||
# to finish the backtest
|
||||
all_models_end_dates.sort(reverse=True)
|
||||
finish_timestamp = all_models_end_dates[0] + \
|
||||
(all_models_end_dates[0] - all_models_end_dates[1])
|
||||
|
||||
all_models_end_dates.append(finish_timestamp)
|
||||
all_models_end_dates.sort()
|
||||
start_date = (datetime(*datetime.fromtimestamp(min(all_models_end_dates),
|
||||
timezone.utc).timetuple()[:3], tzinfo=timezone.utc))
|
||||
end_date = (datetime(*datetime.fromtimestamp(max(all_models_end_dates),
|
||||
timezone.utc).timetuple()[:3], tzinfo=timezone.utc))
|
||||
|
||||
# add 1 day to string timerange to ensure BT module will load all dataframe data
|
||||
end_date = end_date + timedelta(days=1)
|
||||
backtesting_timerange = TimeRange(
|
||||
'date', 'date', int(start_date.timestamp()), int(end_date.timestamp())
|
||||
)
|
||||
return backtesting_timerange, assets_end_dates
|
||||
|
||||
def get_assets_timestamps_training_from_ready_models(
|
||||
self, models_path: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Scan the models path and returns all assets end training dates (timestamp)
|
||||
:param models_path: FreqAI model path
|
||||
|
||||
:return: a Dict with asset and model end training dates info
|
||||
"""
|
||||
assets_end_dates: Dict[str, Any] = {}
|
||||
if not models_path.is_dir():
|
||||
raise OperationalException(
|
||||
'Model folders not found. Saved models are required '
|
||||
'to run backtest with the freqai-backtest-live-models option'
|
||||
)
|
||||
for model_dir in models_path.iterdir():
|
||||
if str(model_dir.name).startswith("sub-train"):
|
||||
model_end_date = int(model_dir.name.split("_")[1])
|
||||
asset = model_dir.name.split("_")[0].replace("sub-train-", "")
|
||||
model_file_name = (
|
||||
f"cb_{str(model_dir.name).replace('sub-train-', '').lower()}"
|
||||
"_model.joblib"
|
||||
)
|
||||
|
||||
model_path_file = Path(model_dir / model_file_name)
|
||||
if model_path_file.is_file():
|
||||
if asset not in assets_end_dates:
|
||||
assets_end_dates[asset] = []
|
||||
assets_end_dates[asset].append(model_end_date)
|
||||
|
||||
return assets_end_dates
|
||||
|
||||
def remove_special_chars_from_feature_names(self, dataframe: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Remove all special characters from feature strings (:)
|
||||
|
@@ -5,15 +5,17 @@ from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from numpy.typing import NDArray
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.configuration import TimeRange
|
||||
from freqtrade.constants import Config
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
from freqtrade.enums import RunMode
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.exchange import timeframe_to_seconds
|
||||
@@ -67,6 +69,7 @@ class IFreqaiModel(ABC):
|
||||
self.save_backtest_models: bool = self.freqai_info.get("save_backtest_models", True)
|
||||
if self.save_backtest_models:
|
||||
logger.info('Backtesting module configured to save all models.')
|
||||
|
||||
self.dd = FreqaiDataDrawer(Path(self.full_path), self.config, self.follow_mode)
|
||||
# set current candle to arbitrary historical date
|
||||
self.current_candle: datetime = datetime.fromtimestamp(637887600, tz=timezone.utc)
|
||||
@@ -98,6 +101,9 @@ class IFreqaiModel(ABC):
|
||||
self.get_corr_dataframes: bool = True
|
||||
self._threads: List[threading.Thread] = []
|
||||
self._stop_event = threading.Event()
|
||||
self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
||||
self.data_provider: Optional[DataProvider] = None
|
||||
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||
|
||||
record_params(config, self.full_path)
|
||||
|
||||
@@ -126,11 +132,13 @@ class IFreqaiModel(ABC):
|
||||
|
||||
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||
self.dd.set_pair_dict_info(metadata)
|
||||
self.data_provider = strategy.dp
|
||||
|
||||
if self.live:
|
||||
self.inference_timer('start')
|
||||
self.dk = FreqaiDataKitchen(self.config, self.live, metadata["pair"])
|
||||
dk = self.start_live(dataframe, metadata, strategy, self.dk)
|
||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||
|
||||
# For backtesting, each pair enters and then gets trained for each window along the
|
||||
# sliding window defined by "train_period_days" (training window) and "live_retrain_hours"
|
||||
@@ -139,20 +147,24 @@ class IFreqaiModel(ABC):
|
||||
# the concatenated results for the full backtesting period back to the strategy.
|
||||
elif not self.follow_mode:
|
||||
self.dk = FreqaiDataKitchen(self.config, self.live, metadata["pair"])
|
||||
if self.dk.backtest_live_models:
|
||||
logger.info(
|
||||
f"Backtesting {len(self.dk.backtesting_timeranges)} timeranges (live models)")
|
||||
else:
|
||||
logger.info(f"Training {len(self.dk.training_timeranges)} timeranges")
|
||||
dataframe = self.dk.use_strategy_to_populate_indicators(
|
||||
strategy, prediction_dataframe=dataframe, pair=metadata["pair"]
|
||||
)
|
||||
dk = self.start_backtesting(dataframe, metadata, self.dk)
|
||||
if not self.config.get("freqai_backtest_live_models", False):
|
||||
logger.info(f"Training {len(self.dk.training_timeranges)} timeranges")
|
||||
dk = self.start_backtesting(dataframe, metadata, self.dk)
|
||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||
else:
|
||||
logger.info(
|
||||
"Backtesting using historic predictions (live models)")
|
||||
dk = self.start_backtesting_from_historic_predictions(
|
||||
dataframe, metadata, self.dk)
|
||||
dataframe = dk.return_dataframe
|
||||
|
||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||
self.clean_up()
|
||||
if self.live:
|
||||
self.inference_timer('stop', metadata["pair"])
|
||||
|
||||
return dataframe
|
||||
|
||||
def clean_up(self):
|
||||
@@ -164,6 +176,13 @@ class IFreqaiModel(ABC):
|
||||
self.model = None
|
||||
self.dk = None
|
||||
|
||||
def _on_stop(self):
|
||||
"""
|
||||
Callback for Subclasses to override to include logic for shutting down resources
|
||||
when SIGINT is sent.
|
||||
"""
|
||||
return
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Cleans up threads on Shutdown, set stop event. Join threads to wait
|
||||
@@ -172,6 +191,9 @@ class IFreqaiModel(ABC):
|
||||
logger.info("Stopping FreqAI")
|
||||
self._stop_event.set()
|
||||
|
||||
self.data_provider = None
|
||||
self._on_stop()
|
||||
|
||||
logger.info("Waiting on Training iteration")
|
||||
for _thread in self._threads:
|
||||
_thread.join()
|
||||
@@ -301,10 +323,11 @@ class IFreqaiModel(ABC):
|
||||
self.model = self.dd.load_data(pair, dk)
|
||||
|
||||
pred_df, do_preds = self.predict(dataframe_backtest, dk)
|
||||
append_df = dk.get_predictions_to_append(pred_df, do_preds)
|
||||
append_df = dk.get_predictions_to_append(pred_df, do_preds, dataframe_backtest)
|
||||
dk.append_predictions(append_df)
|
||||
dk.save_backtesting_prediction(append_df)
|
||||
|
||||
self.backtesting_fit_live_predictions(dk)
|
||||
dk.fill_predictions(dataframe)
|
||||
|
||||
return dk
|
||||
@@ -617,6 +640,8 @@ class IFreqaiModel(ABC):
|
||||
self.dd.historic_predictions[pair] = pred_df
|
||||
hist_preds_df = self.dd.historic_predictions[pair]
|
||||
|
||||
self.set_start_dry_live_date(strat_df)
|
||||
|
||||
for label in hist_preds_df.columns:
|
||||
if hist_preds_df[label].dtype == object:
|
||||
continue
|
||||
@@ -629,7 +654,7 @@ class IFreqaiModel(ABC):
|
||||
hist_preds_df['DI_values'] = 0
|
||||
|
||||
for return_str in dk.data['extra_returns_per_train']:
|
||||
hist_preds_df[return_str] = 0
|
||||
hist_preds_df[return_str] = dk.data['extra_returns_per_train'][return_str]
|
||||
|
||||
hist_preds_df['close_price'] = strat_df['close']
|
||||
hist_preds_df['date_pred'] = strat_df['date']
|
||||
@@ -657,7 +682,8 @@ class IFreqaiModel(ABC):
|
||||
for label in full_labels:
|
||||
if self.dd.historic_predictions[dk.pair][label].dtype == object:
|
||||
continue
|
||||
f = spy.stats.norm.fit(self.dd.historic_predictions[dk.pair][label].tail(num_candles))
|
||||
f = spy.stats.norm.fit(
|
||||
self.dd.historic_predictions[dk.pair][label].tail(num_candles))
|
||||
dk.data["labels_mean"][label], dk.data["labels_std"][label] = f[0], f[1]
|
||||
|
||||
return
|
||||
@@ -811,6 +837,81 @@ class IFreqaiModel(ABC):
|
||||
f"to {tr_train.stop_fmt}, {train_it}/{total_trains} "
|
||||
"trains"
|
||||
)
|
||||
|
||||
def backtesting_fit_live_predictions(self, dk: FreqaiDataKitchen):
    """
    Replay `fit_live_predictions()` over the backtest predictions, row by row.

    For every row that has at least `fit_live_predictions_candles` rows before
    it, the preceding window of `dk.full_df` is installed as a dummy
    `historic_predictions` entry and `fit_live_predictions()` is called, so that
    user logic sees the same kind of data it would see in dry/live operation.
    The resulting label means/stds and extra return values are written back
    into that row of `dk.full_df`.

    Does nothing when `fit_live_predictions_candles` is unset or 0.

    :param dk: datakitchen object holding `full_df`
    """
    window = self.freqai_info.get("fit_live_predictions_candles", 0)
    if not window:
        return

    logger.info("Applying fit_live_predictions in backtesting")

    # Plain label columns: start with "&", are not derived "_mean"/"_std"
    # columns and are not user-defined extra return columns.
    label_columns = []
    for col in dk.full_df.columns:
        if not col.startswith("&"):
            continue
        if col.endswith("_mean") or col.endswith("_std"):
            continue
        if col in self.dk.data["extra_returns_per_train"]:
            continue
        label_columns.append(col)

    for index in range(len(dk.full_df)):
        if index < window:
            # not enough history yet to build a full window
            continue

        # NOTE(review): the window is stored/fitted through `self.dk`, while
        # the frame is read from the `dk` argument - kept exactly as written.
        self.dd.historic_predictions[self.dk.pair] = dk.full_df.iloc[index - window:index]
        self.fit_live_predictions(self.dk, self.dk.pair)

        fitted = self.dk.data
        for label in label_columns:
            if dk.full_df[label].dtype == object:
                continue
            if "labels_mean" in fitted:
                dk.full_df.at[index, f"{label}_mean"] = fitted["labels_mean"][label]
            if "labels_std" in fitted:
                dk.full_df.at[index, f"{label}_std"] = fitted["labels_std"][label]

        for extra_col in fitted["extra_returns_per_train"]:
            dk.full_df.at[index, f"{extra_col}"] = fitted["extra_returns_per_train"][extra_col]
|
||||
|
||||
def update_metadata(self, metadata: Dict[str, Any]):
    """
    Write the global metadata through the data drawer, then keep it on the model.

    The in-memory copy (`self.metadata`) is only replaced after the data
    drawer call has returned.

    :param metadata: new global metadata dict
    """
    drawer = self.dd
    drawer.save_global_metadata_to_disk(metadata)
    self.metadata = metadata
|
||||
|
||||
def set_start_dry_live_date(self, live_dataframe: DataFrame):
    """
    Record, once, the date of the newest row as the dry/live start date.

    If the metadata has no "start_dry_live_date" entry yet, the `date` value of
    the last row of `live_dataframe` is converted to an integer timestamp,
    stored under that key and saved via `update_metadata()`. An existing entry
    is never overwritten.

    :param live_dataframe: dataframe with a `date` column
    """
    key_name = "start_dry_live_date"
    if key_name in self.metadata:
        return

    newest_date = live_dataframe.tail(1)["date"].values[0]
    metadata = self.metadata
    metadata[key_name] = int(pd.to_datetime(newest_date).timestamp())
    self.update_metadata(metadata)
|
||||
|
||||
def start_backtesting_from_historic_predictions(
    self, dataframe: DataFrame, metadata: dict, dk: FreqaiDataKitchen
) -> FreqaiDataKitchen:
    """
    Attach the stored historic predictions of a pair to the strategy dataframe.

    Columns present in both frames are removed from the strategy dataframe, so
    the stored values win. The stored predictions are then left-joined on
    `date` == `date_pred`; rows without a stored prediction keep NaN in the
    prediction columns. The result is placed in `dk.return_dataframe`.

    :param dataframe: DataFrame = strategy passed dataframe
    :param metadata: Dict = pair metadata (the "pair" key is used)
    :param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
    :return:
    FreqaiDataKitchen = the same `dk`, with `return_dataframe` populated
    """
    pair = metadata["pair"]
    dk.return_dataframe = dataframe
    saved_dataframe = self.dd.historic_predictions[pair]

    overlapping = set(saved_dataframe.columns) & set(dk.return_dataframe.columns)
    without_overlap = dk.return_dataframe.drop(columns=list(overlapping))

    dk.return_dataframe = without_overlap.merge(
        saved_dataframe, how='left', left_on='date', right_on="date_pred")
    return dk
|
||||
|
||||
# Following methods which are overridden by user made prediction models.
|
||||
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
|
||||
|
||||
|
141
freqtrade/freqai/prediction_models/ReinforcementLearner.py
Normal file
141
freqtrade/freqai/prediction_models/ReinforcementLearner.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch as th
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearner(BaseReinforcementLearningModel):
    """
    Reference Reinforcement Learning prediction model.

    Inherit from this class to build an RL model with a custom environment
    and/or custom training controls. Define the file as follows:

    ```
    from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner

    class MyCoolRLModel(ReinforcementLearner):
    ```

    Save the file to `user_data/freqaimodels`, then run it with:

    freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat

    Any function available in the `IFreqaiModel` inheritance tree can be
    overridden from there. For RL the most important override is `MyRLEnv`
    (see below), which is where a custom `calculate_reward()` or any other
    change to the environment is defined.

    Other typical overrides are `def fit()`, `def train()` or `def predict()`
    for fine-tuned control over those steps, and `def data_cleaning_predict()`
    for fine-tuned control over the data handling pipeline.
    """

    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        """
        User customizable fit method
        :param data_dictionary: dict = common data dictionary containing all train/test
            features/labels/weights.
        :param dk: FreqaiDatakitchen = data kitchen for current pair.
        :return:
        model Any = trained model to be used for inference in dry/live/backtesting.
            This is the "best_model" checkpoint when one exists in `dk.data_path`
            after training, otherwise the model as it stands after `learn()`.
        """
        train_df = data_dictionary["train_features"]
        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)

        policy_kwargs = {"activation_fn": th.nn.ReLU, "net_arch": self.net_arch}

        resume_training = dk.pair in self.dd.model_dictionary and self.continual_learning
        if resume_training:
            logger.info('Continual training activated - starting training from previously '
                        'trained agent.')
            model = self.dd.model_dictionary[dk.pair]
            model.set_env(self.train_env)
        else:
            # tensorboard logs are grouped by the part of the pair before "/"
            tensorboard_dir = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0])
            model = self.MODELCLASS(
                self.policy_type,
                self.train_env,
                policy_kwargs=policy_kwargs,
                tensorboard_log=tensorboard_dir,
                **self.freqai_info['model_training_parameters']
            )

        model.learn(total_timesteps=int(total_timesteps), callback=self.eval_callback)

        best_model_file = Path(dk.data_path / "best_model.zip")
        if not best_model_file.is_file():
            logger.info('Couldnt find best model, using final model instead.')
            return model

        logger.info('Callback found a best model.')
        return self.MODELCLASS.load(dk.data_path / "best_model")

    class MyRLEnv(Base5ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            Example reward function - the one place where users will most likely
            want to plug in their own ideas.

            Reward scheme, evaluated in this order:
              * invalid action                         -> -2
              * entering (long/short) while neutral    -> 25
              * staying neutral while neutral          -> -1
              * holding (neutral action) in a position -> -trade_duration / max_trade_duration
              * valid exit of a long/short             -> pnl * factor, where factor is
                100 scaled by 1.5 (within `max_trade_duration_candles`) or 0.5 (beyond),
                and additionally by `win_reward_factor` when pnl exceeds
                `profit_aim * rr`
              * anything else                          -> 0

            :param action: int = The action made by the agent for the current candle.
            :return:
            float = the reward to give to the agent for current step (used for optimization
                of weights in NN)
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                return -2

            pnl = self.get_unrealized_profit()
            factor = 100.

            # reward agent for entering trades
            is_entry = action in (Actions.Long_enter.value, Actions.Short_enter.value)
            if is_entry and self._position == Positions.Neutral:
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            # trades within the duration limit are weighted up, longer ones down
            factor *= 1.5 if trade_duration <= max_trade_duration else 0.5

            # discourage sitting in position
            in_trade = self._position in (Positions.Short, Positions.Long)
            if in_trade and action == Actions.Neutral.value:
                return -1 * trade_duration / max_trade_duration

            # close long / close short share the same reward rule
            closes_long = (action == Actions.Long_exit.value
                           and self._position == Positions.Long)
            closes_short = (action == Actions.Short_exit.value
                            and self._position == Positions.Short)
            if closes_long or closes_short:
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config['model_reward_parameters'].get(
                        'win_reward_factor', 2)
                return float(pnl * factor)

            return 0.
|
@@ -0,0 +1,51 @@
|
||||
import logging
|
||||
from typing import Any, Dict # , Tuple
|
||||
|
||||
# import numpy.typing as npt
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearner_multiproc(ReinforcementLearner):
    """
    Demonstration of how to build vectorized environments
    """

    def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
                                        prices_train: DataFrame, prices_test: DataFrame,
                                        dk: FreqaiDataKitchen):
        """
        Build the vectorized training and evaluation environments plus the
        evaluation callback, and store them on `self.train_env`, `self.eval_env`
        and `self.eval_callback`.

        User can override this if they are using a custom MyRLEnv

        Environments left over from a previous call are closed before new ones
        are created. Without this, every retrain would leave another set of
        `SubprocVecEnv` worker processes behind.

        :param data_dictionary: dict = common data dictionary containing train and test
            features/labels/weights.
        :param prices_train/test: DataFrame = dataframe comprised of the prices to be used in
            the environment during training
            or testing
        :param dk: FreqaiDataKitchen = the datakitchen for the current pair
        """
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        # Release environments from a previous training run.
        # getattr keeps this safe on the very first call, should the attribute
        # not exist yet.
        for attr_name in ("train_env", "eval_env"):
            stale_env = getattr(self, attr_name, None)
            if stale_env is not None:
                stale_env.close()

        def build_vec_env(env_id: str, features: DataFrame, prices: DataFrame):
            # One environment factory per thread, each run in its own subprocess.
            # NOTE(review): positional args `i, 1` are passed to make_env exactly
            # as in the original - presumably rank and seed; confirm in make_env.
            return SubprocVecEnv([
                make_env(self.MyRLEnv, env_id, i, 1, features, prices,
                         self.reward_params, self.CONV_WIDTH, monitor=True,
                         config=self.config)
                for i in range(self.max_threads)
            ])

        self.train_env = build_vec_env("train_env", train_df, prices_train)
        self.eval_env = build_vec_env('eval_env', test_df, prices_test)

        # evaluate every len(train_df) callback steps; best model goes to dk.data_path
        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                          render=False, eval_freq=len(train_df),
                                          best_model_save_path=str(dk.data_path))
|
@@ -14,6 +14,7 @@ from freqtrade.data.history.history_utils import refresh_backtest_ohlcv_data
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.exchange import timeframe_to_seconds
|
||||
from freqtrade.exchange.exchange import market_is_active
|
||||
from freqtrade.freqai.data_drawer import FreqaiDataDrawer
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.plugins.pairlist.pairlist_helpers import dynamic_expand_pairlist
|
||||
|
||||
@@ -229,5 +230,6 @@ def get_timerange_backtest_live_models(config: Config) -> str:
|
||||
"""
|
||||
dk = FreqaiDataKitchen(config)
|
||||
models_path = dk.get_full_models_path(config)
|
||||
timerange, _ = dk.get_timerange_and_assets_end_dates_from_ready_models(models_path)
|
||||
dd = FreqaiDataDrawer(models_path, config)
|
||||
timerange = dd.get_timerange_from_live_historic_predictions()
|
||||
return timerange.timerange_str
|
||||
|
Reference in New Issue
Block a user