Merge branch 'develop' into backtest_fitlivepredictions
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
""" Freqtrade bot """
|
||||
__version__ = '2022.11.dev'
|
||||
__version__ = '2022.12.dev'
|
||||
|
||||
if 'dev' in __version__:
|
||||
try:
|
||||
|
@@ -512,6 +512,7 @@ CONF_SCHEMA = {
|
||||
'minimum': 0,
|
||||
'maximum': 65535
|
||||
},
|
||||
'secure': {'type': 'boolean', 'default': False},
|
||||
'ws_token': {'type': 'string'},
|
||||
},
|
||||
'required': ['name', 'host', 'ws_token']
|
||||
@@ -577,9 +578,27 @@ CONF_SCHEMA = {
|
||||
},
|
||||
},
|
||||
"model_training_parameters": {
|
||||
"type": "object"
|
||||
},
|
||||
"rl_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n_estimators": {"type": "integer", "default": 1000}
|
||||
"train_cycles": {"type": "integer"},
|
||||
"max_trade_duration_candles": {"type": "integer"},
|
||||
"add_state_info": {"type": "boolean", "default": False},
|
||||
"max_training_drawdown_pct": {"type": "number", "default": 0.02},
|
||||
"cpu_count": {"type": "integer", "default": 1},
|
||||
"model_type": {"type": "string", "default": "PPO"},
|
||||
"policy_type": {"type": "string", "default": "MlpPolicy"},
|
||||
"net_arch": {"type": "array", "default": [128, 128]},
|
||||
"randomize_startinng_position": {"type": "boolean", "default": False},
|
||||
"model_reward_parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"rr": {"type": "number", "default": 1},
|
||||
"profit_aim": {"type": "number", "default": 0.025}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ class Bybit(Exchange):
|
||||
"""
|
||||
|
||||
_ft_has: Dict = {
|
||||
"ohlcv_candle_limit": 200,
|
||||
"ohlcv_candle_limit": 1000,
|
||||
"ccxt_futures_name": "linear",
|
||||
"ohlcv_has_history": False,
|
||||
}
|
||||
|
@@ -218,3 +218,19 @@ class Kraken(Exchange):
|
||||
fees = sum(df['open_fund'] * df['open_mark'] * amount * time_in_ratio)
|
||||
|
||||
return fees if is_short else -fees
|
||||
|
||||
def _trades_contracts_to_amount(self, trades: List) -> List:
|
||||
"""
|
||||
Fix "last" id issue for kraken data downloads
|
||||
This whole override can probably be removed once the following
|
||||
issue is closed in ccxt: https://github.com/ccxt/ccxt/issues/15827
|
||||
"""
|
||||
super()._trades_contracts_to_amount(trades)
|
||||
if (
|
||||
len(trades) > 0
|
||||
and isinstance(trades[-1].get('info'), list)
|
||||
and len(trades[-1].get('info', [])) > 7
|
||||
):
|
||||
|
||||
trades[-1]['id'] = trades[-1].get('info', [])[-1]
|
||||
return trades
|
||||
|
135
freqtrade/freqai/RL/Base4ActionRLEnv.py
Normal file
135
freqtrade/freqai/RL/Base4ActionRLEnv.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
Neutral = 0
|
||||
Exit = 1
|
||||
Long_enter = 2
|
||||
Short_enter = 3
|
||||
|
||||
|
||||
class Base4ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 4 action environment
|
||||
"""
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
if self._current_tick == self._end_tick:
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
"""
|
||||
Action: Neutral, position: Long -> Close Long
|
||||
Action: Neutral, position: Short -> Close Short
|
||||
|
||||
Action: Long, position: Neutral -> Open Long
|
||||
Action: Long, position: Short -> Close Short and Open Long
|
||||
|
||||
Action: Short, position: Neutral -> Open Short
|
||||
Action: Short, position: Long -> Close Long and Open Short
|
||||
"""
|
||||
|
||||
if action == Actions.Neutral.value:
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
elif action == Actions.Long_enter.value:
|
||||
self._position = Positions.Long
|
||||
trade_type = "long"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Short_enter.value:
|
||||
self._position = Positions.Short
|
||||
trade_type = "short"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Exit.value:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
else:
|
||||
print("case not defined")
|
||||
|
||||
if trade_type is not None:
|
||||
self.trade_history.append(
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if self._total_profit < 1 - self.rl_config.get('max_training_drawdown_pct', 0.8):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
|
||||
(action == Actions.Neutral.value and self._position == Positions.Short) or
|
||||
(action == Actions.Neutral.value and self._position == Positions.Long) or
|
||||
(action == Actions.Short_enter.value and self._position == Positions.Short) or
|
||||
(action == Actions.Short_enter.value and self._position == Positions.Long) or
|
||||
(action == Actions.Exit.value and self._position == Positions.Neutral) or
|
||||
(action == Actions.Long_enter.value and self._position == Positions.Long) or
|
||||
(action == Actions.Long_enter.value and self._position == Positions.Short))
|
||||
|
||||
def _is_valid(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
# Agent should only try to exit if it is in position
|
||||
if action == Actions.Exit.value:
|
||||
if self._position not in (Positions.Short, Positions.Long):
|
||||
return False
|
||||
|
||||
# Agent should only try to enter if it is not in position
|
||||
if action in (Actions.Short_enter.value, Actions.Long_enter.value):
|
||||
if self._position != Positions.Neutral:
|
||||
return False
|
||||
|
||||
return True
|
145
freqtrade/freqai/RL/Base5ActionRLEnv.py
Normal file
145
freqtrade/freqai/RL/Base5ActionRLEnv.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Actions(Enum):
|
||||
Neutral = 0
|
||||
Long_enter = 1
|
||||
Long_exit = 2
|
||||
Short_enter = 3
|
||||
Short_exit = 4
|
||||
|
||||
|
||||
class Base5ActionRLEnv(BaseEnvironment):
|
||||
"""
|
||||
Base class for a 5 action environment
|
||||
"""
|
||||
|
||||
def set_action_space(self):
|
||||
self.action_space = spaces.Discrete(len(Actions))
|
||||
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Logic for a single step (incrementing one candle in time)
|
||||
by the agent
|
||||
:param: action: int = the action type that the agent plans
|
||||
to take for the current step.
|
||||
:returns:
|
||||
observation = current state of environment
|
||||
step_reward = the reward from `calculate_reward()`
|
||||
_done = if the agent "died" or if the candles finished
|
||||
info = dict passed back to openai gym lib
|
||||
"""
|
||||
self._done = False
|
||||
self._current_tick += 1
|
||||
|
||||
if self._current_tick == self._end_tick:
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
"""
|
||||
Action: Neutral, position: Long -> Close Long
|
||||
Action: Neutral, position: Short -> Close Short
|
||||
|
||||
Action: Long, position: Neutral -> Open Long
|
||||
Action: Long, position: Short -> Close Short and Open Long
|
||||
|
||||
Action: Short, position: Neutral -> Open Short
|
||||
Action: Short, position: Long -> Close Long and Open Short
|
||||
"""
|
||||
|
||||
if action == Actions.Neutral.value:
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
elif action == Actions.Long_enter.value:
|
||||
self._position = Positions.Long
|
||||
trade_type = "long"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Short_enter.value:
|
||||
self._position = Positions.Short
|
||||
trade_type = "short"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Long_exit.value:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
elif action == Actions.Short_exit.value:
|
||||
self._update_total_profit()
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
else:
|
||||
print("case not defined")
|
||||
|
||||
if trade_type is not None:
|
||||
self.trade_history.append(
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if (self._total_profit < self.max_drawdown or
|
||||
self._total_unrealized_profit < self.max_drawdown):
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value
|
||||
)
|
||||
|
||||
observation = self._get_observation()
|
||||
|
||||
self._update_history(info)
|
||||
|
||||
return observation, step_reward, self._done, info
|
||||
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
|
||||
(action == Actions.Neutral.value and self._position == Positions.Short) or
|
||||
(action == Actions.Neutral.value and self._position == Positions.Long) or
|
||||
(action == Actions.Short_enter.value and self._position == Positions.Short) or
|
||||
(action == Actions.Short_enter.value and self._position == Positions.Long) or
|
||||
(action == Actions.Short_exit.value and self._position == Positions.Long) or
|
||||
(action == Actions.Short_exit.value and self._position == Positions.Neutral) or
|
||||
(action == Actions.Long_enter.value and self._position == Positions.Long) or
|
||||
(action == Actions.Long_enter.value and self._position == Positions.Short) or
|
||||
(action == Actions.Long_exit.value and self._position == Positions.Short) or
|
||||
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
|
||||
|
||||
def _is_valid(self, action: int) -> bool:
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
# Agent should only try to exit if it is in position
|
||||
if action in (Actions.Short_exit.value, Actions.Long_exit.value):
|
||||
if self._position not in (Positions.Short, Positions.Long):
|
||||
return False
|
||||
|
||||
# Agent should only try to enter if it is not in position
|
||||
if action in (Actions.Short_enter.value, Actions.Long_enter.value):
|
||||
if self._position != Positions.Neutral:
|
||||
return False
|
||||
|
||||
return True
|
307
freqtrade/freqai/RL/BaseEnvironment.py
Normal file
307
freqtrade/freqai/RL/BaseEnvironment.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import logging
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Positions(Enum):
|
||||
Short = 0
|
||||
Long = 1
|
||||
Neutral = 0.5
|
||||
|
||||
def opposite(self):
|
||||
return Positions.Short if self == Positions.Long else Positions.Long
|
||||
|
||||
|
||||
class BaseEnvironment(gym.Env):
|
||||
"""
|
||||
Base class for environments. This class is agnostic to action count.
|
||||
Inherited classes customize this to include varying action counts/types,
|
||||
See RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py
|
||||
"""
|
||||
|
||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {},
|
||||
dp: Optional[DataProvider] = None):
|
||||
"""
|
||||
Initializes the training/eval environment.
|
||||
:param df: dataframe of features
|
||||
:param prices: dataframe of prices to be used in the training environment
|
||||
:param window_size: size of window (temporal) to pass to the agent
|
||||
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
||||
:param starting_point: start at edge of window or not
|
||||
:param id: string id of the environment (used in backend for multiprocessed env)
|
||||
:param seed: Sets the seed of the environment higher in the gym.Env object
|
||||
:param config: Typical user configuration file
|
||||
:param dp: dataprovider from freqtrade
|
||||
"""
|
||||
self.config = config
|
||||
self.rl_config = config['freqai']['rl_config']
|
||||
self.add_state_info = self.rl_config.get('add_state_info', False)
|
||||
self.id = id
|
||||
self.seed(seed)
|
||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||
self.compound_trades = config['stake_amount'] == 'unlimited'
|
||||
if self.config.get('fee', None) is not None:
|
||||
self.fee = self.config['fee']
|
||||
elif dp is not None:
|
||||
self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0]) # type: ignore
|
||||
else:
|
||||
self.fee = 0.0015
|
||||
|
||||
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
|
||||
reward_kwargs: dict, starting_point=True):
|
||||
"""
|
||||
Resets the environment when the agent fails (in our case, if the drawdown
|
||||
exceeds the user set max_training_drawdown_pct)
|
||||
:param df: dataframe of features
|
||||
:param prices: dataframe of prices to be used in the training environment
|
||||
:param window_size: size of window (temporal) to pass to the agent
|
||||
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
||||
:param starting_point: start at edge of window or not
|
||||
"""
|
||||
self.df = df
|
||||
self.signal_features = self.df
|
||||
self.prices = prices
|
||||
self.window_size = window_size
|
||||
self.starting_point = starting_point
|
||||
self.rr = reward_kwargs["rr"]
|
||||
self.profit_aim = reward_kwargs["profit_aim"]
|
||||
|
||||
# # spaces
|
||||
if self.add_state_info:
|
||||
self.total_features = self.signal_features.shape[1] + 3
|
||||
else:
|
||||
self.total_features = self.signal_features.shape[1]
|
||||
self.shape = (window_size, self.total_features)
|
||||
self.set_action_space()
|
||||
self.observation_space = spaces.Box(
|
||||
low=-1, high=1, shape=self.shape, dtype=np.float32)
|
||||
|
||||
# episode
|
||||
self._start_tick: int = self.window_size
|
||||
self._end_tick: int = len(self.prices) - 1
|
||||
self._done: bool = False
|
||||
self._current_tick: int = self._start_tick
|
||||
self._last_trade_tick: Optional[int] = None
|
||||
self._position = Positions.Neutral
|
||||
self._position_history: list = [None]
|
||||
self.total_reward: float = 0
|
||||
self._total_profit: float = 1
|
||||
self._total_unrealized_profit: float = 1
|
||||
self.history: dict = {}
|
||||
self.trade_history: list = []
|
||||
|
||||
@abstractmethod
|
||||
def set_action_space(self):
|
||||
"""
|
||||
Unique to the environment action count. Must be inherited.
|
||||
"""
|
||||
|
||||
def seed(self, seed: int = 1):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def reset(self):
|
||||
|
||||
self._done = False
|
||||
|
||||
if self.starting_point is True:
|
||||
if self.rl_config.get('randomize_starting_position', False):
|
||||
length_of_data = int(self._end_tick / 4)
|
||||
start_tick = random.randint(self.window_size + 1, length_of_data)
|
||||
self._start_tick = start_tick
|
||||
self._position_history = (self._start_tick * [None]) + [self._position]
|
||||
else:
|
||||
self._position_history = (self.window_size * [None]) + [self._position]
|
||||
|
||||
self._current_tick = self._start_tick
|
||||
self._last_trade_tick = None
|
||||
self._position = Positions.Neutral
|
||||
|
||||
self.total_reward = 0.
|
||||
self._total_profit = 1. # unit
|
||||
self.history = {}
|
||||
self.trade_history = []
|
||||
self.portfolio_log_returns = np.zeros(len(self.prices))
|
||||
|
||||
self._profits = [(self._start_tick, 1)]
|
||||
self.close_trade_profit = []
|
||||
self._total_unrealized_profit = 1
|
||||
|
||||
return self._get_observation()
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: int):
|
||||
"""
|
||||
Step depeneds on action types, this must be inherited.
|
||||
"""
|
||||
return
|
||||
|
||||
def _get_observation(self):
|
||||
"""
|
||||
This may or may not be independent of action types, user can inherit
|
||||
this in their custom "MyRLEnv"
|
||||
"""
|
||||
features_window = self.signal_features[(
|
||||
self._current_tick - self.window_size):self._current_tick]
|
||||
if self.add_state_info:
|
||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
||||
columns=['current_profit_pct',
|
||||
'position',
|
||||
'trade_duration'],
|
||||
index=features_window.index)
|
||||
|
||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||
features_and_state['position'] = self._position.value
|
||||
features_and_state['trade_duration'] = self.get_trade_duration()
|
||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||
return features_and_state
|
||||
else:
|
||||
return features_window
|
||||
|
||||
def get_trade_duration(self):
|
||||
"""
|
||||
Get the trade duration if the agent is in a trade
|
||||
"""
|
||||
if self._last_trade_tick is None:
|
||||
return 0
|
||||
else:
|
||||
return self._current_tick - self._last_trade_tick
|
||||
|
||||
def get_unrealized_profit(self):
|
||||
"""
|
||||
Get the unrealized profit if the agent is in a trade
|
||||
"""
|
||||
if self._last_trade_tick is None:
|
||||
return 0.
|
||||
|
||||
if self._position == Positions.Neutral:
|
||||
return 0.
|
||||
elif self._position == Positions.Short:
|
||||
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
|
||||
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||
return (last_trade_price - current_price) / last_trade_price
|
||||
elif self._position == Positions.Long:
|
||||
current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
|
||||
last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||
return (current_price - last_trade_price) / last_trade_price
|
||||
else:
|
||||
return 0.
|
||||
|
||||
@abstractmethod
|
||||
def is_tradesignal(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is a trade signal. This is
|
||||
unique to the actions in the environment, and therefore must be
|
||||
inherited.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _is_valid(self, action: int) -> bool:
|
||||
"""
|
||||
Determine if the signal is valid.This is
|
||||
unique to the actions in the environment, and therefore must be
|
||||
inherited.
|
||||
"""
|
||||
return True
|
||||
|
||||
def add_entry_fee(self, price):
|
||||
return price * (1 + self.fee)
|
||||
|
||||
def add_exit_fee(self, price):
|
||||
return price / (1 + self.fee)
|
||||
|
||||
def _update_history(self, info):
|
||||
if not self.history:
|
||||
self.history = {key: [] for key in info.keys()}
|
||||
|
||||
for key, value in info.items():
|
||||
self.history[key].append(value)
|
||||
|
||||
@abstractmethod
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:param action: int = The action made by the agent for the current candle.
|
||||
:return:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
|
||||
def _update_unrealized_total_profit(self):
|
||||
"""
|
||||
Update the unrealized total profit incase of episode end.
|
||||
"""
|
||||
if self._position in (Positions.Long, Positions.Short):
|
||||
pnl = self.get_unrealized_profit()
|
||||
if self.compound_trades:
|
||||
# assumes unit stake and compounding
|
||||
unrl_profit = self._total_profit * (1 + pnl)
|
||||
else:
|
||||
# assumes unit stake and no compounding
|
||||
unrl_profit = self._total_profit + pnl
|
||||
self._total_unrealized_profit = unrl_profit
|
||||
|
||||
def _update_total_profit(self):
|
||||
pnl = self.get_unrealized_profit()
|
||||
if self.compound_trades:
|
||||
# assumes unit stake and compounding
|
||||
self._total_profit = self._total_profit * (1 + pnl)
|
||||
else:
|
||||
# assumes unit stake and no compounding
|
||||
self._total_profit += pnl
|
||||
|
||||
def current_price(self) -> float:
|
||||
return self.prices.iloc[self._current_tick].open
|
||||
|
||||
# Keeping around incase we want to start building more complex environment
|
||||
# templates in the future.
|
||||
# def most_recent_return(self):
|
||||
# """
|
||||
# Calculate the tick to tick return if in a trade.
|
||||
# Return is generated from rising prices in Long
|
||||
# and falling prices in Short positions.
|
||||
# The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
|
||||
# """
|
||||
# # Long positions
|
||||
# if self._position == Positions.Long:
|
||||
# current_price = self.prices.iloc[self._current_tick].open
|
||||
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||
|
||||
# if (self._position_history[self._current_tick - 1] == Positions.Short
|
||||
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||
# previous_price = self.add_entry_fee(previous_price)
|
||||
|
||||
# return np.log(current_price) - np.log(previous_price)
|
||||
|
||||
# # Short positions
|
||||
# if self._position == Positions.Short:
|
||||
# current_price = self.prices.iloc[self._current_tick].open
|
||||
# previous_price = self.prices.iloc[self._current_tick - 1].open
|
||||
# if (self._position_history[self._current_tick - 1] == Positions.Long
|
||||
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
|
||||
# previous_price = self.add_exit_fee(previous_price)
|
||||
|
||||
# return np.log(previous_price) - np.log(current_price)
|
||||
|
||||
# return 0
|
||||
|
||||
# def update_portfolio_log_returns(self, action):
|
||||
# self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
|
396
freqtrade/freqai/RL/BaseReinforcementLearningModel.py
Normal file
396
freqtrade/freqai/RL/BaseReinforcementLearningModel.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import importlib
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pandas as pd
|
||||
import torch as th
|
||||
import torch.multiprocessing
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import Positions
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
|
||||
SB3_MODELS = ['PPO', 'A2C', 'DQN']
|
||||
SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
|
||||
|
||||
|
||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction class
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(config=kwargs['config'])
|
||||
self.max_threads = min(self.freqai_info['rl_config'].get(
|
||||
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
||||
th.set_num_threads(self.max_threads)
|
||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
self.train_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_callback: Optional[EvalCallback] = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||
if self.model_type in SB3_MODELS:
|
||||
import_str = 'stable_baselines3'
|
||||
elif self.model_type in SB3_CONTRIB_MODELS:
|
||||
import_str = 'sb3_contrib'
|
||||
else:
|
||||
raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
|
||||
f'sb3_contrib. please choose one of {SB3_MODELS} or '
|
||||
f'{SB3_CONTRIB_MODELS}')
|
||||
|
||||
mod = importlib.import_module(import_str, self.model_type)
|
||||
self.MODELCLASS = getattr(mod, self.model_type)
|
||||
self.policy_type = self.freqai_info['rl_config']['policy_type']
|
||||
self.unset_outlier_removal()
|
||||
self.net_arch = self.rl_config.get('net_arch', [128, 128])
|
||||
self.dd.model_type = "stable_baselines"
|
||||
|
||||
def unset_outlier_removal(self):
|
||||
"""
|
||||
If user has activated any function that may remove training points, this
|
||||
function will set them to false and warn them
|
||||
"""
|
||||
if self.ft_params.get('use_SVM_to_remove_outliers', False):
|
||||
self.ft_params.update({'use_SVM_to_remove_outliers': False})
|
||||
logger.warning('User tried to use SVM with RL. Deactivating SVM.')
|
||||
if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
|
||||
self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
|
||||
logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
|
||||
if self.freqai_info['data_split_parameters'].get('shuffle', False):
|
||||
self.freqai_info['data_split_parameters'].update({'shuffle': False})
|
||||
logger.warning('User tried to shuffle training data. Setting shuffle to False')
|
||||
|
||||
def train(
|
||||
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||
for storing, saving, loading, and analyzing the data.
|
||||
:param unfiltered_df: Full dataframe for the current training period
|
||||
:param metadata: pair metadata from strategy.
|
||||
:returns:
|
||||
:model: Trained model which can be used to inference (self.predict)
|
||||
"""
|
||||
|
||||
logger.info("--------------------Starting training " f"{pair} --------------------")
|
||||
|
||||
features_filtered, labels_filtered = dk.filter_features(
|
||||
unfiltered_df,
|
||||
dk.training_features_list,
|
||||
dk.label_list,
|
||||
training_filter=True,
|
||||
)
|
||||
|
||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||
features_filtered, labels_filtered)
|
||||
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||
|
||||
# normalize all data based on train_dataset only
|
||||
prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
|
||||
data_dictionary = dk.normalize_data(data_dictionary)
|
||||
|
||||
# data cleaning/analysis
|
||||
self.data_cleaning_train(dk)
|
||||
|
||||
logger.info(
|
||||
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
|
||||
f' features and {len(data_dictionary["train_features"])} data points'
|
||||
)
|
||||
|
||||
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
|
||||
|
||||
model = self.fit(data_dictionary, dk)
|
||||
|
||||
logger.info(f"--------------------done training {pair}--------------------")
|
||||
|
||||
return model
|
||||
|
||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
|
||||
prices_train: DataFrame, prices_test: DataFrame,
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
:param data_dictionary: dict = common data dictionary containing train and test
|
||||
features/labels/weights.
|
||||
:param prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
|
||||
environment during training or testing
|
||||
:param dk: FreqaiDataKitchen = the datakitchen for the current pair
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
self.train_env = self.MyRLEnv(df=train_df,
|
||||
prices=prices_train,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider)
|
||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
|
||||
prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
"""
|
||||
Agent customizations and abstract Reinforcement Learning customizations
|
||||
go in here. Abstract method, so this function must be overridden by
|
||||
user class.
|
||||
"""
|
||||
return
|
||||
|
||||
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
|
||||
"""
|
||||
State info during dry/live (not backtesting) which is fed back
|
||||
into the model.
|
||||
:param pair: str = COIN/STAKE to get the environment information for
|
||||
:return:
|
||||
:market_side: float = representing short, long, or neutral for
|
||||
pair
|
||||
:current_profit: float = unrealized profit of the current trade
|
||||
:trade_duration: int = the number of candles that the trade has
|
||||
been open for
|
||||
"""
|
||||
open_trades = Trade.get_trades_proxy(is_open=True)
|
||||
market_side = 0.5
|
||||
current_profit: float = 0
|
||||
trade_duration = 0
|
||||
for trade in open_trades:
|
||||
if trade.pair == pair:
|
||||
if self.data_provider._exchange is None: # type: ignore
|
||||
logger.error('No exchange available.')
|
||||
return 0, 0, 0
|
||||
else:
|
||||
current_rate = self.data_provider._exchange.get_rate( # type: ignore
|
||||
pair, refresh=False, side="exit", is_short=trade.is_short)
|
||||
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
|
||||
current_profit = trade.calc_profit_ratio(current_rate)
|
||||
|
||||
return market_side, current_profit, int(trade_duration)
|
||||
|
||||
def predict(
|
||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||
"""
|
||||
Filter the prediction features data and predict with it.
|
||||
:param unfiltered_dataframe: Full dataframe for the current backtest period.
|
||||
:return:
|
||||
:pred_df: dataframe containing the predictions
|
||||
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||
"""
|
||||
|
||||
dk.find_features(unfiltered_df)
|
||||
filtered_dataframe, _ = dk.filter_features(
|
||||
unfiltered_df, dk.training_features_list, training_filter=False
|
||||
)
|
||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||
|
||||
# optional additional data cleaning/analysis
|
||||
self.data_cleaning_predict(dk)
|
||||
|
||||
pred_df = self.rl_model_predict(
|
||||
dk.data_dictionary["prediction_features"], dk, self.model)
|
||||
pred_df.fillna(0, inplace=True)
|
||||
|
||||
return (pred_df, dk.do_predict)
|
||||
|
||||
def rl_model_predict(self, dataframe: DataFrame,
|
||||
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
|
||||
"""
|
||||
A helper function to make predictions in the Reinforcement learning module.
|
||||
:param dataframe: DataFrame = the dataframe of features to make the predictions on
|
||||
:param dk: FreqaiDatakitchen = data kitchen for the current pair
|
||||
:param model: Any = the trained model used to inference the features.
|
||||
"""
|
||||
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
||||
|
||||
def _predict(window):
|
||||
observations = dataframe.iloc[window.index]
|
||||
if self.live and self.rl_config.get('add_state_info', False):
|
||||
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
|
||||
observations['current_profit_pct'] = current_profit
|
||||
observations['position'] = market_side
|
||||
observations['trade_duration'] = trade_duration
|
||||
res, _ = model.predict(observations, deterministic=True)
|
||||
return res
|
||||
|
||||
output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
|
||||
|
||||
return output
|
||||
|
||||
def build_ohlc_price_dataframes(self, data_dictionary: dict,
|
||||
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
|
||||
DataFrame]:
|
||||
"""
|
||||
Builds the train prices and test prices for the environment.
|
||||
"""
|
||||
|
||||
pair = pair.replace(':', '')
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
# price data for model training and evaluation
|
||||
tf = self.config['timeframe']
|
||||
ohlc_list = [f'%-{pair}raw_open_{tf}', f'%-{pair}raw_low_{tf}',
|
||||
f'%-{pair}raw_high_{tf}', f'%-{pair}raw_close_{tf}']
|
||||
rename_dict = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
|
||||
f'%-{pair}raw_high_{tf}': ' high', f'%-{pair}raw_close_{tf}': 'close'}
|
||||
|
||||
prices_train = train_df.filter(ohlc_list, axis=1)
|
||||
if prices_train.empty:
|
||||
raise OperationalException('Reinforcement learning module didnt find the raw prices '
|
||||
'assigned in populate_any_indicators. Please assign them '
|
||||
'with:\n'
|
||||
'informative[f"%-{pair}raw_close"] = informative["close"]\n'
|
||||
'informative[f"%-{pair}raw_open"] = informative["open"]\n'
|
||||
'informative[f"%-{pair}raw_high"] = informative["high"]\n'
|
||||
'informative[f"%-{pair}raw_low"] = informative["low"]\n')
|
||||
prices_train.rename(columns=rename_dict, inplace=True)
|
||||
prices_train.reset_index(drop=True)
|
||||
|
||||
prices_test = test_df.filter(ohlc_list, axis=1)
|
||||
prices_test.rename(columns=rename_dict, inplace=True)
|
||||
prices_test.reset_index(drop=True)
|
||||
|
||||
return prices_train, prices_test
|
||||
|
||||
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
|
||||
"""
|
||||
Can be used by user if they are trying to limit_ram_usage *and*
|
||||
perform continual learning.
|
||||
For now, this is unused.
|
||||
"""
|
||||
exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
|
||||
if exists:
|
||||
model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
else:
|
||||
logger.info('No model file on disk to continue learning from.')
|
||||
|
||||
return model
|
||||
|
||||
def _on_stop(self):
|
||||
"""
|
||||
Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
|
||||
"""
|
||||
|
||||
if self.train_env:
|
||||
self.train_env.close()
|
||||
|
||||
if self.eval_env:
|
||||
self.eval_env.close()
|
||||
|
||||
# Nested class which can be overridden by user to customize further
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
"""
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:param action: int = The action made by the agent for the current candle.
|
||||
:return:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
if self._last_trade_tick:
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
else:
|
||||
trade_duration = 0
|
||||
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor *= 1.5
|
||||
elif trade_duration > max_trade_duration:
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float(pnl * factor)
|
||||
|
||||
# close short
|
||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float(pnl * factor)
|
||||
|
||||
return 0.
|
||||
|
||||
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
:param env_id: (str) the environment ID
|
||||
:param num_env: (int) the number of environment you wish to have in subprocesses
|
||||
:param seed: (int) the inital seed for RNG
|
||||
:param rank: (int) index of the subprocess
|
||||
:return: (Callable)
|
||||
"""
|
||||
|
||||
def _init() -> gym.Env:
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
0
freqtrade/freqai/RL/__init__.py
Normal file
0
freqtrade/freqai/RL/__init__.py
Normal file
@@ -1,4 +1,5 @@
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
@@ -99,6 +100,7 @@ class FreqaiDataDrawer:
|
||||
self.empty_pair_dict: pair_info = {
|
||||
"model_filename": "", "trained_timestamp": 0,
|
||||
"data_path": "", "extras": {}}
|
||||
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
|
||||
|
||||
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
|
||||
"""
|
||||
@@ -497,10 +499,12 @@ class FreqaiDataDrawer:
|
||||
save_path = Path(dk.data_path)
|
||||
|
||||
# Save the trained model
|
||||
if not dk.keras:
|
||||
if self.model_type == 'joblib':
|
||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif self.model_type == 'keras':
|
||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||
elif 'stable_baselines' in self.model_type:
|
||||
model.save(save_path / f"{dk.model_filename}_model.zip")
|
||||
|
||||
if dk.svm_model is not None:
|
||||
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@@ -527,11 +531,10 @@ class FreqaiDataDrawer:
|
||||
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
|
||||
)
|
||||
|
||||
# if self.live:
|
||||
# store as much in ram as possible to increase performance
|
||||
self.model_dictionary[coin] = model
|
||||
self.pair_dict[coin]["model_filename"] = dk.model_filename
|
||||
self.pair_dict[coin]["data_path"] = str(dk.data_path)
|
||||
|
||||
if coin not in self.meta_data_dictionary:
|
||||
self.meta_data_dictionary[coin] = {}
|
||||
self.meta_data_dictionary[coin]["train_df"] = dk.data_dictionary["train_features"]
|
||||
@@ -563,14 +566,6 @@ class FreqaiDataDrawer:
|
||||
if dk.live:
|
||||
dk.model_filename = self.pair_dict[coin]["model_filename"]
|
||||
dk.data_path = Path(self.pair_dict[coin]["data_path"])
|
||||
if self.freqai_info.get("follow_mode", False):
|
||||
# follower can be on a different system which is rsynced from the leader:
|
||||
dk.data_path = Path(
|
||||
self.config["user_data_dir"]
|
||||
/ "models"
|
||||
/ dk.data_path.parts[-2]
|
||||
/ dk.data_path.parts[-1]
|
||||
)
|
||||
|
||||
if coin in self.meta_data_dictionary:
|
||||
dk.data = self.meta_data_dictionary[coin]["meta_data"]
|
||||
@@ -589,12 +584,16 @@ class FreqaiDataDrawer:
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if dk.live and coin in self.model_dictionary:
|
||||
model = self.model_dictionary[coin]
|
||||
elif not dk.keras:
|
||||
elif self.model_type == 'joblib':
|
||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
||||
else:
|
||||
elif self.model_type == 'keras':
|
||||
from tensorflow import keras
|
||||
|
||||
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
|
||||
elif self.model_type == 'stable_baselines':
|
||||
mod = importlib.import_module(
|
||||
'stable_baselines3', self.freqai_info['rl_config']['model_type'])
|
||||
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
|
||||
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
|
||||
|
||||
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
|
||||
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
|
||||
@@ -604,6 +603,10 @@ class FreqaiDataDrawer:
|
||||
f"Unable to load model, ensure model exists at " f"{dk.data_path} "
|
||||
)
|
||||
|
||||
# load it into ram if it was loaded from disk
|
||||
if coin not in self.model_dictionary:
|
||||
self.model_dictionary[coin] = model
|
||||
|
||||
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
|
||||
dk.pca = cloudpickle.load(
|
||||
open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "rb")
|
||||
|
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from pandas import DataFrame, HDFStore
|
||||
from scipy import stats
|
||||
from sklearn import linear_model
|
||||
@@ -98,7 +99,10 @@ class FreqaiDataKitchen:
|
||||
)
|
||||
|
||||
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
|
||||
self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
|
||||
if not self.freqai_config.get("data_kitchen_thread_count", 0):
|
||||
self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||
else:
|
||||
self.thread_count = self.freqai_config["data_kitchen_thread_count"]
|
||||
self.train_dates: DataFrame = pd.DataFrame()
|
||||
self.unique_classes: Dict[str, list] = {}
|
||||
self.unique_class_list: list = []
|
||||
|
@@ -5,15 +5,17 @@ from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from numpy.typing import NDArray
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.configuration import TimeRange
|
||||
from freqtrade.constants import Config
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
from freqtrade.enums import RunMode
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.exchange import timeframe_to_seconds
|
||||
@@ -101,6 +103,8 @@ class IFreqaiModel(ABC):
|
||||
self._threads: List[threading.Thread] = []
|
||||
self._stop_event = threading.Event()
|
||||
self.metadata = self.dd.load_global_metadata_from_disk()
|
||||
self.data_provider: Optional[DataProvider] = None
|
||||
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||
|
||||
record_params(config, self.full_path)
|
||||
|
||||
@@ -129,6 +133,7 @@ class IFreqaiModel(ABC):
|
||||
|
||||
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||
self.dd.set_pair_dict_info(metadata)
|
||||
self.data_provider = strategy.dp
|
||||
|
||||
if self.live:
|
||||
self.inference_timer('start')
|
||||
@@ -175,6 +180,13 @@ class IFreqaiModel(ABC):
|
||||
self.model = None
|
||||
self.dk = None
|
||||
|
||||
def _on_stop(self):
|
||||
"""
|
||||
Callback for Subclasses to override to include logic for shutting down resources
|
||||
when SIGINT is sent.
|
||||
"""
|
||||
return
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Cleans up threads on Shutdown, set stop event. Join threads to wait
|
||||
@@ -183,6 +195,9 @@ class IFreqaiModel(ABC):
|
||||
logger.info("Stopping FreqAI")
|
||||
self._stop_event.set()
|
||||
|
||||
self.data_provider = None
|
||||
self._on_stop()
|
||||
|
||||
logger.info("Waiting on Training iteration")
|
||||
for _thread in self._threads:
|
||||
_thread.join()
|
||||
@@ -663,7 +678,7 @@ class IFreqaiModel(ABC):
|
||||
hist_preds_df['DI_values'] = 0
|
||||
|
||||
for return_str in dk.data['extra_returns_per_train']:
|
||||
hist_preds_df[return_str] = 0
|
||||
hist_preds_df[return_str] = dk.data['extra_returns_per_train'][return_str]
|
||||
|
||||
hist_preds_df['close_price'] = strat_df['close']
|
||||
hist_preds_df['date_pred'] = strat_df['date']
|
||||
|
141
freqtrade/freqai/prediction_models/ReinforcementLearner.py
Normal file
141
freqtrade/freqai/prediction_models/ReinforcementLearner.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch as th
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
"""
|
||||
Reinforcement Learning Model prediction model.
|
||||
|
||||
Users can inherit from this class to make their own RL model with custom
|
||||
environment/training controls. Define the file as follows:
|
||||
|
||||
```
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
|
||||
class MyCoolRLModel(ReinforcementLearner):
|
||||
```
|
||||
|
||||
Save the file to `user_data/freqaimodels`, then run it with:
|
||||
|
||||
freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
|
||||
|
||||
Here the users can override any of the functions
|
||||
available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
|
||||
is where the user overrides `MyRLEnv` (see below), to define custom
|
||||
`calculate_reward()` function, or to override any other parts of the environment.
|
||||
|
||||
This class also allows users to override any other part of the IFreqaiModel tree.
|
||||
For example, the user can override `def fit()` or `def train()` or `def predict()`
|
||||
to take fine-tuned control over these processes.
|
||||
|
||||
Another common override may be `def data_cleaning_predict()` where the user can
|
||||
take fine-tuned control over the data handling pipeline.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||
"""
|
||||
User customizable fit method
|
||||
:param data_dictionary: dict = common data dictionary containing all train/test
|
||||
features/labels/weights.
|
||||
:param dk: FreqaiDatakitchen = data kitchen for current pair.
|
||||
:return:
|
||||
model Any = trained model to be used for inference in dry/live/backtesting
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
|
||||
|
||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=self.net_arch)
|
||||
|
||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=Path(
|
||||
dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
|
||||
**self.freqai_info['model_training_parameters']
|
||||
)
|
||||
else:
|
||||
logger.info('Continual training activated - starting training from previously '
|
||||
'trained agent.')
|
||||
model = self.dd.model_dictionary[dk.pair]
|
||||
model.set_env(self.train_env)
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=self.eval_callback
|
||||
)
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
logger.info('Callback found a best model.')
|
||||
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
|
||||
return best_model
|
||||
|
||||
logger.info('Couldnt find best model, using final model instead.')
|
||||
|
||||
return model
|
||||
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
"""
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:param action: int = The action made by the agent for the current candle.
|
||||
:return:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
|
||||
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor *= 1.5
|
||||
elif trade_duration > max_trade_duration:
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
if action == Actions.Long_exit.value and self._position == Positions.Long:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float(pnl * factor)
|
||||
|
||||
# close short
|
||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float(pnl * factor)
|
||||
|
||||
return 0.
|
@@ -0,0 +1,51 @@
|
||||
import logging
|
||||
from typing import Any, Dict # , Tuple
|
||||
|
||||
# import numpy.typing as npt
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
"""
|
||||
Demonstration of how to build vectorized environments
|
||||
"""
|
||||
|
||||
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
|
||||
prices_train: DataFrame, prices_test: DataFrame,
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
:param data_dictionary: dict = common data dictionary containing train and test
|
||||
features/labels/weights.
|
||||
:param prices_train/test: DataFrame = dataframe comprised of the prices to be used in
|
||||
the environment during training
|
||||
or testing
|
||||
:param dk: FreqaiDataKitchen = the datakitchen for the current pair
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_id = "train_env"
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||
config=self.config) for i
|
||||
in range(self.max_threads)])
|
||||
|
||||
eval_env_id = 'eval_env'
|
||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||
test_df, prices_test,
|
||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||
config=self.config) for i
|
||||
in range(self.max_threads)])
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
@@ -191,10 +191,10 @@ class FreqtradeBot(LoggingMixin):
|
||||
# Check whether markets have to be reloaded and reload them when it's needed
|
||||
self.exchange.reload_markets()
|
||||
|
||||
self.update_closed_trades_without_assigned_fees()
|
||||
self.update_trades_without_assigned_fees()
|
||||
|
||||
# Query trades from persistence layer
|
||||
trades = Trade.get_open_trades()
|
||||
trades: List[Trade] = Trade.get_open_trades()
|
||||
|
||||
self.active_pair_whitelist = self._refresh_active_whitelist(trades)
|
||||
|
||||
@@ -354,7 +354,7 @@ class FreqtradeBot(LoggingMixin):
|
||||
if self.trading_mode == TradingMode.FUTURES:
|
||||
self._schedule.run_pending()
|
||||
|
||||
def update_closed_trades_without_assigned_fees(self) -> None:
|
||||
def update_trades_without_assigned_fees(self) -> None:
|
||||
"""
|
||||
Update closed trades without close fees assigned.
|
||||
Only acts when Orders are in the database, otherwise the last order-id is unknown.
|
||||
@@ -381,15 +381,16 @@ class FreqtradeBot(LoggingMixin):
|
||||
|
||||
trades = Trade.get_open_trades_without_assigned_fees()
|
||||
for trade in trades:
|
||||
if trade.is_open and not trade.fee_updated(trade.entry_side):
|
||||
order = trade.select_order(trade.entry_side, False)
|
||||
open_order = trade.select_order(trade.entry_side, True)
|
||||
if order and open_order is None:
|
||||
logger.info(
|
||||
f"Updating {trade.entry_side}-fee on trade {trade}"
|
||||
f"for order {order.order_id}."
|
||||
)
|
||||
self.update_trade_state(trade, order.order_id, send_msg=False)
|
||||
with self._exit_lock:
|
||||
if trade.is_open and not trade.fee_updated(trade.entry_side):
|
||||
order = trade.select_order(trade.entry_side, False)
|
||||
open_order = trade.select_order(trade.entry_side, True)
|
||||
if order and open_order is None:
|
||||
logger.info(
|
||||
f"Updating {trade.entry_side}-fee on trade {trade}"
|
||||
f"for order {order.order_id}."
|
||||
)
|
||||
self.update_trade_state(trade, order.order_id, send_msg=False)
|
||||
|
||||
def handle_insufficient_funds(self, trade: Trade):
|
||||
"""
|
||||
@@ -826,6 +827,8 @@ class FreqtradeBot(LoggingMixin):
|
||||
co = self.exchange.cancel_stoploss_order_with_result(
|
||||
trade.stoploss_order_id, trade.pair, trade.amount)
|
||||
trade.update_order(co)
|
||||
# Reset stoploss order id.
|
||||
trade.stoploss_order_id = None
|
||||
except InvalidOrderException:
|
||||
logger.exception(f"Could not cancel stoploss order {trade.stoploss_order_id}")
|
||||
return trade
|
||||
@@ -982,7 +985,7 @@ class FreqtradeBot(LoggingMixin):
|
||||
# SELL / exit positions / close trades logic and methods
|
||||
#
|
||||
|
||||
def exit_positions(self, trades: List[Any]) -> int:
|
||||
def exit_positions(self, trades: List[Trade]) -> int:
|
||||
"""
|
||||
Tries to execute exit orders for open trades (positions)
|
||||
"""
|
||||
@@ -1010,7 +1013,7 @@ class FreqtradeBot(LoggingMixin):
|
||||
|
||||
def handle_trade(self, trade: Trade) -> bool:
|
||||
"""
|
||||
Sells/exits_short the current pair if the threshold is reached and updates the trade record.
|
||||
Exits the current pair if the threshold is reached and updates the trade record.
|
||||
:return: True if trade has been sold/exited_short, False otherwise
|
||||
"""
|
||||
if not trade.is_open:
|
||||
@@ -1148,7 +1151,7 @@ class FreqtradeBot(LoggingMixin):
|
||||
stoploss = (
|
||||
self.edge.stoploss(pair=trade.pair)
|
||||
if self.edge else
|
||||
self.strategy.stoploss / trade.leverage
|
||||
trade.stop_loss_pct / trade.leverage
|
||||
)
|
||||
if trade.is_short:
|
||||
stop_price = trade.open_rate * (1 - stoploss)
|
||||
@@ -1167,7 +1170,6 @@ class FreqtradeBot(LoggingMixin):
|
||||
if self.create_stoploss_order(trade=trade, stop_price=trade.stoploss_or_liquidation):
|
||||
return False
|
||||
else:
|
||||
trade.stoploss_order_id = None
|
||||
logger.warning('Stoploss order was cancelled, but unable to recreate one.')
|
||||
|
||||
# Finally we check if stoploss on exchange should be moved up because of trailing.
|
||||
|
@@ -692,10 +692,11 @@ class Backtesting:
|
||||
trade.orders.append(order)
|
||||
return trade
|
||||
|
||||
def _get_exit_trade_entry(self, trade: LocalTrade, row: Tuple) -> Optional[LocalTrade]:
|
||||
def _get_exit_trade_entry(
|
||||
self, trade: LocalTrade, row: Tuple, is_first: bool) -> Optional[LocalTrade]:
|
||||
exit_candle_time: datetime = row[DATE_IDX].to_pydatetime()
|
||||
|
||||
if self.trading_mode == TradingMode.FUTURES:
|
||||
if is_first and self.trading_mode == TradingMode.FUTURES:
|
||||
trade.funding_fees = self.exchange.calculate_funding_fees(
|
||||
self.futures_data[trade.pair],
|
||||
amount=trade.amount,
|
||||
@@ -704,32 +705,7 @@ class Backtesting:
|
||||
close_date=exit_candle_time,
|
||||
)
|
||||
|
||||
if self.timeframe_detail and trade.pair in self.detail_data:
|
||||
exit_candle_end = exit_candle_time + timedelta(minutes=self.timeframe_min)
|
||||
|
||||
detail_data = self.detail_data[trade.pair]
|
||||
detail_data = detail_data.loc[
|
||||
(detail_data['date'] >= exit_candle_time) &
|
||||
(detail_data['date'] < exit_candle_end)
|
||||
].copy()
|
||||
if len(detail_data) == 0:
|
||||
# Fall back to "regular" data if no detail data was found for this candle
|
||||
return self._get_exit_trade_entry_for_candle(trade, row)
|
||||
detail_data.loc[:, 'enter_long'] = row[LONG_IDX]
|
||||
detail_data.loc[:, 'exit_long'] = row[ELONG_IDX]
|
||||
detail_data.loc[:, 'enter_short'] = row[SHORT_IDX]
|
||||
detail_data.loc[:, 'exit_short'] = row[ESHORT_IDX]
|
||||
detail_data.loc[:, 'enter_tag'] = row[ENTER_TAG_IDX]
|
||||
detail_data.loc[:, 'exit_tag'] = row[EXIT_TAG_IDX]
|
||||
for det_row in detail_data[HEADERS].values.tolist():
|
||||
res = self._get_exit_trade_entry_for_candle(trade, det_row)
|
||||
if res:
|
||||
return res
|
||||
|
||||
return None
|
||||
|
||||
else:
|
||||
return self._get_exit_trade_entry_for_candle(trade, row)
|
||||
return self._get_exit_trade_entry_for_candle(trade, row)
|
||||
|
||||
def get_valid_price_and_stake(
|
||||
self, pair: str, row: Tuple, propose_rate: float, stake_amount: float,
|
||||
@@ -1074,7 +1050,7 @@ class Backtesting:
|
||||
|
||||
def backtest_loop(
|
||||
self, row: Tuple, pair: str, current_time: datetime, end_date: datetime,
|
||||
max_open_trades: int, open_trade_count_start: int) -> int:
|
||||
max_open_trades: int, open_trade_count_start: int, is_first: bool = True) -> int:
|
||||
"""
|
||||
NOTE: This method is used by Hyperopt at each iteration. Please keep it optimized.
|
||||
|
||||
@@ -1092,9 +1068,11 @@ class Backtesting:
|
||||
# without positionstacking, we can only have one open trade per pair.
|
||||
# max_open_trades must be respected
|
||||
# don't open on the last row
|
||||
# We only open trades on the main candle, not on detail candles
|
||||
trade_dir = self.check_for_trade_entry(row)
|
||||
if (
|
||||
(self._position_stacking or len(LocalTrade.bt_trades_open_pp[pair]) == 0)
|
||||
and is_first
|
||||
and self.trade_slot_available(max_open_trades, open_trade_count_start)
|
||||
and current_time != end_date
|
||||
and trade_dir is not None
|
||||
@@ -1120,7 +1098,7 @@ class Backtesting:
|
||||
|
||||
# 4. Create exit orders (if any)
|
||||
if not trade.open_order_id:
|
||||
self._get_exit_trade_entry(trade, row) # Place exit order if necessary
|
||||
self._get_exit_trade_entry(trade, row, is_first) # Place exit order if necessary
|
||||
|
||||
# 5. Process exit orders.
|
||||
order = trade.select_order(trade.exit_side, is_open=True)
|
||||
@@ -1171,7 +1149,6 @@ class Backtesting:
|
||||
|
||||
self.progress.init_step(BacktestState.BACKTEST, int(
|
||||
(end_date - start_date) / timedelta(minutes=self.timeframe_min)))
|
||||
|
||||
# Loop timerange and get candle for each pair at that point in time
|
||||
while current_time <= end_date:
|
||||
open_trade_count_start = LocalTrade.bt_open_open_trade_count
|
||||
@@ -1185,9 +1162,37 @@ class Backtesting:
|
||||
row_index += 1
|
||||
indexes[pair] = row_index
|
||||
self.dataprovider._set_dataframe_max_index(row_index)
|
||||
current_detail_time: datetime = row[DATE_IDX].to_pydatetime()
|
||||
if self.timeframe_detail and pair in self.detail_data:
|
||||
exit_candle_end = current_detail_time + timedelta(minutes=self.timeframe_min)
|
||||
|
||||
open_trade_count_start = self.backtest_loop(
|
||||
row, pair, current_time, end_date, max_open_trades, open_trade_count_start)
|
||||
detail_data = self.detail_data[pair]
|
||||
detail_data = detail_data.loc[
|
||||
(detail_data['date'] >= current_detail_time) &
|
||||
(detail_data['date'] < exit_candle_end)
|
||||
].copy()
|
||||
if len(detail_data) == 0:
|
||||
# Fall back to "regular" data if no detail data was found for this candle
|
||||
open_trade_count_start = self.backtest_loop(
|
||||
row, pair, current_time, end_date, max_open_trades,
|
||||
open_trade_count_start)
|
||||
detail_data.loc[:, 'enter_long'] = row[LONG_IDX]
|
||||
detail_data.loc[:, 'exit_long'] = row[ELONG_IDX]
|
||||
detail_data.loc[:, 'enter_short'] = row[SHORT_IDX]
|
||||
detail_data.loc[:, 'exit_short'] = row[ESHORT_IDX]
|
||||
detail_data.loc[:, 'enter_tag'] = row[ENTER_TAG_IDX]
|
||||
detail_data.loc[:, 'exit_tag'] = row[EXIT_TAG_IDX]
|
||||
is_first = True
|
||||
current_time_det = current_time
|
||||
for det_row in detail_data[HEADERS].values.tolist():
|
||||
open_trade_count_start = self.backtest_loop(
|
||||
det_row, pair, current_time_det, end_date, max_open_trades,
|
||||
open_trade_count_start, is_first)
|
||||
current_time_det += timedelta(minutes=self.timeframe_detail_min)
|
||||
is_first = False
|
||||
else:
|
||||
open_trade_count_start = self.backtest_loop(
|
||||
row, pair, current_time, end_date, max_open_trades, open_trade_count_start)
|
||||
|
||||
# Move time one configured time_interval ahead.
|
||||
self.progress.increment()
|
||||
|
@@ -17,6 +17,7 @@ from freqtrade.enums import HyperoptState
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.misc import deep_merge_dicts, round_coin_value, round_dict, safe_value_fallback2
|
||||
from freqtrade.optimize.hyperopt_epoch_filters import hyperopt_filter_epochs
|
||||
from freqtrade.optimize.optimize_reports import generate_wins_draws_losses
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -325,8 +326,10 @@ class HyperoptTools():
|
||||
|
||||
# New mode, using backtest result for metrics
|
||||
trials['results_metrics.winsdrawslosses'] = trials.apply(
|
||||
lambda x: f"{x['results_metrics.wins']} {x['results_metrics.draws']:>4} "
|
||||
f"{x['results_metrics.losses']:>4}", axis=1)
|
||||
lambda x: generate_wins_draws_losses(
|
||||
x['results_metrics.wins'], x['results_metrics.draws'],
|
||||
x['results_metrics.losses']
|
||||
), axis=1)
|
||||
|
||||
trials = trials[['Best', 'current_epoch', 'results_metrics.total_trades',
|
||||
'results_metrics.winsdrawslosses',
|
||||
@@ -337,7 +340,7 @@ class HyperoptTools():
|
||||
'loss', 'is_initial_point', 'is_random', 'is_best']]
|
||||
|
||||
trials.columns = [
|
||||
'Best', 'Epoch', 'Trades', ' Win Draw Loss', 'Avg profit',
|
||||
'Best', 'Epoch', 'Trades', ' Win Draw Loss Win%', 'Avg profit',
|
||||
'Total profit', 'Profit', 'Avg duration', 'max_drawdown', 'max_drawdown_account',
|
||||
'max_drawdown_abs', 'Objective', 'is_initial_point', 'is_random', 'is_best'
|
||||
]
|
||||
@@ -467,9 +470,9 @@ class HyperoptTools():
|
||||
|
||||
base_metrics = ['Best', 'current_epoch', 'results_metrics.total_trades',
|
||||
'results_metrics.profit_mean', 'results_metrics.profit_median',
|
||||
'results_metrics.profit_total',
|
||||
'Stake currency',
|
||||
'results_metrics.profit_total', 'Stake currency',
|
||||
'results_metrics.profit_total_abs', 'results_metrics.holding_avg',
|
||||
'results_metrics.trade_count_long', 'results_metrics.trade_count_short',
|
||||
'loss', 'is_initial_point', 'is_best']
|
||||
perc_multi = 100
|
||||
|
||||
@@ -477,7 +480,9 @@ class HyperoptTools():
|
||||
trials = trials[base_metrics + param_metrics]
|
||||
|
||||
base_columns = ['Best', 'Epoch', 'Trades', 'Avg profit', 'Median profit', 'Total profit',
|
||||
'Stake currency', 'Profit', 'Avg duration', 'Objective',
|
||||
'Stake currency', 'Profit', 'Avg duration',
|
||||
'Trade count long', 'Trade count short',
|
||||
'Objective',
|
||||
'is_initial_point', 'is_best']
|
||||
param_columns = list(results[0]['params_dict'].keys())
|
||||
trials.columns = base_columns + param_columns
|
||||
|
@@ -86,7 +86,7 @@ def _get_line_header(first_column: str, stake_currency: str,
|
||||
'Win Draw Loss Win%']
|
||||
|
||||
|
||||
def _generate_wins_draws_losses(wins, draws, losses):
|
||||
def generate_wins_draws_losses(wins, draws, losses):
|
||||
if wins > 0 and losses == 0:
|
||||
wl_ratio = '100'
|
||||
elif wins == 0:
|
||||
@@ -600,7 +600,7 @@ def text_table_bt_results(pair_results: List[Dict[str, Any]], stake_currency: st
|
||||
output = [[
|
||||
t['key'], t['trades'], t['profit_mean_pct'], t['profit_sum_pct'], t['profit_total_abs'],
|
||||
t['profit_total_pct'], t['duration_avg'],
|
||||
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses'])
|
||||
generate_wins_draws_losses(t['wins'], t['draws'], t['losses'])
|
||||
] for t in pair_results]
|
||||
# Ignore type as floatfmt does allow tuples but mypy does not know that
|
||||
return tabulate(output, headers=headers,
|
||||
@@ -626,7 +626,7 @@ def text_table_exit_reason(exit_reason_stats: List[Dict[str, Any]], stake_curren
|
||||
|
||||
output = [[
|
||||
t.get('exit_reason', t.get('sell_reason')), t['trades'],
|
||||
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses']),
|
||||
generate_wins_draws_losses(t['wins'], t['draws'], t['losses']),
|
||||
t['profit_mean_pct'], t['profit_sum_pct'],
|
||||
round_coin_value(t['profit_total_abs'], stake_currency, False),
|
||||
t['profit_total_pct'],
|
||||
@@ -656,7 +656,7 @@ def text_table_tags(tag_type: str, tag_results: List[Dict[str, Any]], stake_curr
|
||||
t['profit_total_abs'],
|
||||
t['profit_total_pct'],
|
||||
t['duration_avg'],
|
||||
_generate_wins_draws_losses(
|
||||
generate_wins_draws_losses(
|
||||
t['wins'],
|
||||
t['draws'],
|
||||
t['losses'])] for t in tag_results]
|
||||
@@ -715,7 +715,7 @@ def text_table_strategy(strategy_results, stake_currency: str) -> str:
|
||||
output = [[
|
||||
t['key'], t['trades'], t['profit_mean_pct'], t['profit_sum_pct'], t['profit_total_abs'],
|
||||
t['profit_total_pct'], t['duration_avg'],
|
||||
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses']), drawdown]
|
||||
generate_wins_draws_losses(t['wins'], t['draws'], t['losses']), drawdown]
|
||||
for t, drawdown in zip(strategy_results, drawdown)]
|
||||
# Ignore type as floatfmt does allow tuples but mypy does not know that
|
||||
return tabulate(output, headers=headers,
|
||||
|
@@ -87,7 +87,7 @@ class PairLocks():
|
||||
Get the lock that expires the latest for the pair given.
|
||||
"""
|
||||
locks = PairLocks.get_pair_locks(pair, now, side=side)
|
||||
locks = sorted(locks, key=lambda l: l.lock_end_time, reverse=True)
|
||||
locks = sorted(locks, key=lambda lock: lock.lock_end_time, reverse=True)
|
||||
return locks[0] if locks else None
|
||||
|
||||
@staticmethod
|
||||
|
@@ -81,8 +81,6 @@ async def validate_ws_token(
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
# No checks passed, deny the connection
|
||||
logger.debug("Denying websocket request.")
|
||||
# If it doesn't match, close the websocket connection
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
|
@@ -1,16 +1,16 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
||||
from fastapi.websockets import WebSocket, WebSocketState
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.websockets import WebSocket
|
||||
from pydantic import ValidationError
|
||||
from websockets.exceptions import WebSocketException
|
||||
|
||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.channel import ChannelManager
|
||||
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
|
||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
|
||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||
WSRequestSchema, WSWhitelistMessage)
|
||||
from freqtrade.rpc.rpc import RPC
|
||||
@@ -22,23 +22,35 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def is_websocket_alive(ws: WebSocket) -> bool:
|
||||
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
|
||||
"""
|
||||
Check if a FastAPI Websocket is still open
|
||||
Iterate over the messages from the channel and process the request
|
||||
"""
|
||||
if (
|
||||
ws.application_state == WebSocketState.CONNECTED and
|
||||
ws.client_state == WebSocketState.CONNECTED
|
||||
):
|
||||
return True
|
||||
return False
|
||||
async for message in channel:
|
||||
await _process_consumer_request(message, channel, rpc)
|
||||
|
||||
|
||||
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
|
||||
"""
|
||||
Iterate over messages in the message stream and send them
|
||||
"""
|
||||
async for message, ts in message_stream:
|
||||
if channel.subscribed_to(message.get('type')):
|
||||
# Log a warning if this channel is behind
|
||||
# on the message stream by a lot
|
||||
if (time.time() - ts) > 60:
|
||||
logger.warning(f"Channel {channel} is behind MessageStream by 1 minute,"
|
||||
" this can cause a memory leak if you see this message"
|
||||
" often, consider reducing pair list size or amount of"
|
||||
" consumers.")
|
||||
|
||||
await channel.send(message, timeout=True)
|
||||
|
||||
|
||||
async def _process_consumer_request(
|
||||
request: Dict[str, Any],
|
||||
channel: WebSocketChannel,
|
||||
rpc: RPC,
|
||||
channel_manager: ChannelManager
|
||||
rpc: RPC
|
||||
):
|
||||
"""
|
||||
Validate and handle a request from a websocket consumer
|
||||
@@ -74,65 +86,29 @@ async def _process_consumer_request(
|
||||
|
||||
# Format response
|
||||
response = WSWhitelistMessage(data=whitelist)
|
||||
# Send it back
|
||||
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
elif type == RPCRequestType.ANALYZED_DF:
|
||||
limit = None
|
||||
|
||||
if data:
|
||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||
limit = max(data.get('limit', 1500), 1500)
|
||||
# Limit the amount of candles per dataframe to 'limit' or 1500
|
||||
limit = min(data.get('limit', 1500), 1500) if data else None
|
||||
|
||||
# For every pair in the generator, send a separate message
|
||||
for message in rpc._ws_request_analyzed_df(limit):
|
||||
# Format response
|
||||
response = WSAnalyzedDFMessage(data=message)
|
||||
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
|
||||
|
||||
@router.websocket("/message/ws")
|
||||
async def message_endpoint(
|
||||
ws: WebSocket,
|
||||
websocket: WebSocket,
|
||||
token: str = Depends(validate_ws_token),
|
||||
rpc: RPC = Depends(get_rpc),
|
||||
channel_manager=Depends(get_channel_manager),
|
||||
token: str = Depends(validate_ws_token)
|
||||
message_stream: MessageStream = Depends(get_message_stream)
|
||||
):
|
||||
"""
|
||||
Message WebSocket endpoint, facilitates sending RPC messages
|
||||
"""
|
||||
try:
|
||||
channel = await channel_manager.on_connect(ws)
|
||||
if await is_websocket_alive(ws):
|
||||
|
||||
logger.info(f"Consumer connected - {channel}")
|
||||
|
||||
# Keep connection open until explicitly closed, and process requests
|
||||
try:
|
||||
while not channel.is_closed():
|
||||
request = await channel.recv()
|
||||
|
||||
# Process the request here
|
||||
await _process_consumer_request(request, channel, rpc, channel_manager)
|
||||
|
||||
except (WebSocketDisconnect, WebSocketException):
|
||||
# Handle client disconnects
|
||||
logger.info(f"Consumer disconnected - {channel}")
|
||||
except RuntimeError:
|
||||
# Handle cases like -
|
||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.info(f"Consumer connection failed - {channel}: {e}")
|
||||
logger.debug(e, exc_info=e)
|
||||
|
||||
except RuntimeError:
|
||||
# WebSocket was closed
|
||||
# Do nothing
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to serve - {ws.client}")
|
||||
# Log tracebacks to keep track of what errors are happening
|
||||
logger.exception(e)
|
||||
finally:
|
||||
if channel:
|
||||
await channel_manager.on_disconnect(ws)
|
||||
if token:
|
||||
async with create_channel(websocket) as channel:
|
||||
await channel.run_channel_tasks(
|
||||
channel_reader(channel, rpc),
|
||||
channel_broadcaster(channel, message_stream)
|
||||
)
|
||||
|
@@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
|
||||
return ApiServer._exchange
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return ApiServer._ws_channel_manager
|
||||
def get_message_stream():
|
||||
return ApiServer._message_stream
|
||||
|
||||
|
||||
def is_webserver_mode(config=Depends(get_config)):
|
||||
|
@@ -1,22 +1,17 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
# Look into alternatives
|
||||
from janus import Queue as ThreadedQueue
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from freqtrade.constants import Config
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
|
||||
from freqtrade.rpc.api_server.ws import ChannelManager
|
||||
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
|
||||
|
||||
|
||||
@@ -50,10 +45,8 @@ class ApiServer(RPCHandler):
|
||||
_config: Config = {}
|
||||
# Exchange - only available in webserver mode.
|
||||
_exchange = None
|
||||
# websocket message queue stuff
|
||||
_ws_channel_manager: ChannelManager
|
||||
_ws_thread = None
|
||||
_ws_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
# websocket message stuff
|
||||
_message_stream: Optional[MessageStream] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
@@ -71,15 +64,11 @@ class ApiServer(RPCHandler):
|
||||
return
|
||||
self._standalone: bool = standalone
|
||||
self._server = None
|
||||
self._ws_queue: Optional[ThreadedQueue] = None
|
||||
self._ws_background_task = None
|
||||
|
||||
ApiServer.__initialized = True
|
||||
|
||||
api_config = self._config['api_server']
|
||||
|
||||
ApiServer._ws_channel_manager = ChannelManager()
|
||||
|
||||
self.app = FastAPI(title="Freqtrade API",
|
||||
docs_url='/docs' if api_config.get('enable_openapi', False) else None,
|
||||
redoc_url=None,
|
||||
@@ -105,21 +94,9 @@ class ApiServer(RPCHandler):
|
||||
del ApiServer._rpc
|
||||
if self._server and not self._standalone:
|
||||
logger.info("Stopping API Server")
|
||||
# self._server.force_exit, self._server.should_exit = True, True
|
||||
self._server.cleanup()
|
||||
|
||||
if self._ws_thread and self._ws_loop:
|
||||
logger.info("Stopping API Server background tasks")
|
||||
|
||||
if self._ws_background_task:
|
||||
# Cancel the queue task
|
||||
self._ws_background_task.cancel()
|
||||
|
||||
self._ws_thread.join()
|
||||
|
||||
self._ws_thread = None
|
||||
self._ws_loop = None
|
||||
self._ws_background_task = None
|
||||
|
||||
@classmethod
|
||||
def shutdown(cls):
|
||||
cls.__initialized = False
|
||||
@@ -129,9 +106,11 @@ class ApiServer(RPCHandler):
|
||||
cls._rpc = None
|
||||
|
||||
def send_msg(self, msg: Dict[str, Any]) -> None:
|
||||
if self._ws_queue:
|
||||
sync_q = self._ws_queue.sync_q
|
||||
sync_q.put(msg)
|
||||
"""
|
||||
Publish the message to the message stream
|
||||
"""
|
||||
if ApiServer._message_stream:
|
||||
ApiServer._message_stream.publish(msg)
|
||||
|
||||
def handle_rpc_exception(self, request, exc):
|
||||
logger.exception(f"API Error calling: {exc}")
|
||||
@@ -170,54 +149,30 @@ class ApiServer(RPCHandler):
|
||||
)
|
||||
|
||||
app.add_exception_handler(RPCException, self.handle_rpc_exception)
|
||||
app.add_event_handler(
|
||||
event_type="startup",
|
||||
func=self._api_startup_event
|
||||
)
|
||||
app.add_event_handler(
|
||||
event_type="shutdown",
|
||||
func=self._api_shutdown_event
|
||||
)
|
||||
|
||||
def start_message_queue(self):
|
||||
if self._ws_thread:
|
||||
return
|
||||
async def _api_startup_event(self):
|
||||
"""
|
||||
Creates the MessageStream class on startup
|
||||
so it has access to the same event loop
|
||||
as uvicorn
|
||||
"""
|
||||
if not ApiServer._message_stream:
|
||||
ApiServer._message_stream = MessageStream()
|
||||
|
||||
# Create a new loop, as it'll be just for the background thread
|
||||
self._ws_loop = asyncio.new_event_loop()
|
||||
|
||||
# Start the thread
|
||||
self._ws_thread = Thread(target=self._ws_loop.run_forever)
|
||||
self._ws_thread.start()
|
||||
|
||||
# Finally, submit the coro to the thread
|
||||
self._ws_background_task = asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_queue_data(), loop=self._ws_loop)
|
||||
|
||||
async def _broadcast_queue_data(self) -> None:
|
||||
# Instantiate the queue in this coroutine so it's attached to our loop
|
||||
self._ws_queue = ThreadedQueue()
|
||||
async_queue = self._ws_queue.async_q
|
||||
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Getting queue messages...")
|
||||
if (qsize := async_queue.qsize()) > 20:
|
||||
# If the queue becomes too big for too long, this may indicate a problem.
|
||||
logger.warning(f"Queue size now {qsize}")
|
||||
# Get data from queue
|
||||
message: WSMessageSchemaType = await async_queue.get()
|
||||
logger.debug(f"Found message of type: {message.get('type')}")
|
||||
async_queue.task_done()
|
||||
# Broadcast it
|
||||
await self._ws_channel_manager.broadcast(message)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# For testing, shouldn't happen when stable
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception happened in background task: {e}")
|
||||
|
||||
finally:
|
||||
# Disconnect channels and stop the loop on cancel
|
||||
await self._ws_channel_manager.disconnect_all()
|
||||
if self._ws_loop:
|
||||
self._ws_loop.stop()
|
||||
# Avoid adding more items to the queue if they aren't
|
||||
# going to get broadcasted.
|
||||
self._ws_queue = None
|
||||
async def _api_shutdown_event(self):
|
||||
"""
|
||||
Removes the MessageStream class on shutdown
|
||||
"""
|
||||
if ApiServer._message_stream:
|
||||
ApiServer._message_stream = None
|
||||
|
||||
def start_api(self):
|
||||
"""
|
||||
@@ -257,7 +212,6 @@ class ApiServer(RPCHandler):
|
||||
if self._standalone:
|
||||
self._server.run()
|
||||
else:
|
||||
self.start_message_queue()
|
||||
self._server.run_in_thread()
|
||||
except Exception:
|
||||
logger.exception("Api server failed to start.")
|
||||
|
@@ -3,4 +3,5 @@
|
||||
from freqtrade.rpc.api_server.ws.types import WebSocketType
|
||||
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
||||
from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
|
||||
from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
|
||||
|
@@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from threading import RLock
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import WebSocket as FastAPIWebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
||||
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
|
||||
@@ -21,31 +23,27 @@ class WebSocketChannel:
|
||||
"""
|
||||
Object to help facilitate managing a websocket connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocketType,
|
||||
channel_id: Optional[str] = None,
|
||||
drain_timeout: int = 3,
|
||||
throttle: float = 0.01,
|
||||
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
|
||||
):
|
||||
|
||||
self.channel_id = channel_id if channel_id else uuid4().hex[:8]
|
||||
|
||||
# The WebSocket object
|
||||
self._websocket = WebSocketProxy(websocket)
|
||||
|
||||
self.drain_timeout = drain_timeout
|
||||
self.throttle = throttle
|
||||
|
||||
self._subscriptions: List[str] = []
|
||||
# 32 is the size of the receiving queue in websockets package
|
||||
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
|
||||
self._relay_task = asyncio.create_task(self.relay())
|
||||
|
||||
# Internal event to signify a closed websocket
|
||||
self._closed = asyncio.Event()
|
||||
# The async tasks created for the channel
|
||||
self._channel_tasks: List[asyncio.Task] = []
|
||||
|
||||
# Deque for average send times
|
||||
self._send_times: Deque[float] = deque([], maxlen=10)
|
||||
# High limit defaults to 3 to start
|
||||
self._send_high_limit = 3
|
||||
|
||||
# The subscribed message types
|
||||
self._subscriptions: List[str] = []
|
||||
|
||||
# Wrap the WebSocket in the Serializing class
|
||||
self._wrapped_ws = serializer_cls(self._websocket)
|
||||
@@ -61,43 +59,58 @@ class WebSocketChannel:
|
||||
def remote_addr(self):
|
||||
return self._websocket.remote_addr
|
||||
|
||||
async def _send(self, data):
|
||||
"""
|
||||
Send data on the wrapped websocket
|
||||
"""
|
||||
await self._wrapped_ws.send(data)
|
||||
@property
|
||||
def avg_send_time(self):
|
||||
return sum(self._send_times) / len(self._send_times)
|
||||
|
||||
async def send(self, data) -> bool:
|
||||
def _calc_send_limit(self):
|
||||
"""
|
||||
Add the data to the queue to be sent.
|
||||
:returns: True if data added to queue, False otherwise
|
||||
Calculate the send high limit for this channel
|
||||
"""
|
||||
|
||||
# This block only runs if the queue is full, it will wait
|
||||
# until self.drain_timeout for the relay to drain the outgoing queue
|
||||
# We can't use asyncio.wait_for here because the queue may have been created with a
|
||||
# different eventloop
|
||||
if not self.is_closed():
|
||||
start = time.time()
|
||||
while self.queue.full():
|
||||
await asyncio.sleep(1)
|
||||
if (time.time() - start) > self.drain_timeout:
|
||||
return False
|
||||
# Only update if we have enough data
|
||||
if len(self._send_times) == self._send_times.maxlen:
|
||||
# At least 1s or twice the average of send times, with a
|
||||
# maximum of 3 seconds per message
|
||||
self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)
|
||||
|
||||
# If for some reason the queue is still full, just return False
|
||||
try:
|
||||
self.queue.put_nowait(data)
|
||||
except asyncio.QueueFull:
|
||||
return False
|
||||
async def send(
|
||||
self,
|
||||
message: Union[WSMessageSchemaType, Dict[str, Any]],
|
||||
timeout: bool = False
|
||||
):
|
||||
"""
|
||||
Send a message on the wrapped websocket. If the sending
|
||||
takes too long, it will raise a TimeoutError and
|
||||
disconnect the connection.
|
||||
|
||||
# If we got here everything is ok
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
:param message: The message to send
|
||||
:param timeout: Enforce send high limit, defaults to False
|
||||
"""
|
||||
try:
|
||||
_ = time.time()
|
||||
# If the send times out, it will raise
|
||||
# a TimeoutError and bubble up to the
|
||||
# message_endpoint to close the connection
|
||||
await asyncio.wait_for(
|
||||
self._wrapped_ws.send(message),
|
||||
timeout=self._send_high_limit if timeout else None
|
||||
)
|
||||
total_time = time.time() - _
|
||||
self._send_times.append(total_time)
|
||||
|
||||
self._calc_send_limit()
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"Connection for {self} timed out, disconnecting")
|
||||
raise
|
||||
|
||||
# Explicitly give control back to event loop as
|
||||
# websockets.send does not
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
Receive data on the wrapped websocket
|
||||
Receive a message on the wrapped websocket
|
||||
"""
|
||||
return await self._wrapped_ws.recv()
|
||||
|
||||
@@ -107,17 +120,27 @@ class WebSocketChannel:
|
||||
"""
|
||||
return await self._websocket.ping()
|
||||
|
||||
async def accept(self):
|
||||
"""
|
||||
Accept the underlying websocket connection,
|
||||
if the connection has been closed before we can
|
||||
accept, just close the channel.
|
||||
"""
|
||||
try:
|
||||
return await self._websocket.accept()
|
||||
except RuntimeError:
|
||||
await self.close()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close the WebSocketChannel
|
||||
"""
|
||||
|
||||
self._closed.set()
|
||||
self._relay_task.cancel()
|
||||
|
||||
try:
|
||||
await self.raw_websocket.close()
|
||||
except Exception:
|
||||
await self._websocket.close()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
@@ -142,99 +165,76 @@ class WebSocketChannel:
|
||||
"""
|
||||
return message_type in self._subscriptions
|
||||
|
||||
async def relay(self):
|
||||
async def run_channel_tasks(self, *tasks, **kwargs):
|
||||
"""
|
||||
Relay messages from the channel's queue and send them out. This is started
|
||||
as a task.
|
||||
Create and await on the channel tasks unless an exception
|
||||
was raised, then cancel them all.
|
||||
|
||||
:params *tasks: All coros or tasks to be run concurrently
|
||||
:param **kwargs: Any extra kwargs to pass to gather
|
||||
"""
|
||||
while not self._closed.is_set():
|
||||
message = await self.queue.get()
|
||||
|
||||
if not self.is_closed():
|
||||
# Wrap the coros into tasks if they aren't already
|
||||
self._channel_tasks = [
|
||||
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
try:
|
||||
await self._send(message)
|
||||
self.queue.task_done()
|
||||
return await asyncio.gather(*self._channel_tasks, **kwargs)
|
||||
except Exception:
|
||||
# If an exception occurred, cancel the rest of the tasks
|
||||
await self.cancel_channel_tasks()
|
||||
|
||||
# Limit messages per sec.
|
||||
# Could cause problems with queue size if too low, and
|
||||
# problems with network traffik if too high.
|
||||
# 0.01 = 100/s
|
||||
await asyncio.sleep(self.throttle)
|
||||
except RuntimeError:
|
||||
# The connection was closed, just exit the task
|
||||
return
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self):
|
||||
self.channels = dict()
|
||||
self._lock = RLock() # Re-entrant Lock
|
||||
|
||||
async def on_connect(self, websocket: WebSocketType):
|
||||
async def cancel_channel_tasks(self):
|
||||
"""
|
||||
Wrap websocket connection into Channel and add to list
|
||||
|
||||
:param websocket: The WebSocket object to attach to the Channel
|
||||
Cancel and wait on all channel tasks
|
||||
"""
|
||||
if isinstance(websocket, FastAPIWebSocket):
|
||||
for task in self._channel_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to finish cancelling
|
||||
try:
|
||||
await websocket.accept()
|
||||
except RuntimeError:
|
||||
# The connection was closed before we could accept it
|
||||
return
|
||||
await task
|
||||
except (
|
||||
asyncio.CancelledError,
|
||||
asyncio.TimeoutError,
|
||||
WebSocketDisconnect,
|
||||
ConnectionClosed,
|
||||
RuntimeError
|
||||
):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.info(f"Encountered unknown exception: {e}", exc_info=e)
|
||||
|
||||
ws_channel = WebSocketChannel(websocket)
|
||||
self._channel_tasks = []
|
||||
|
||||
with self._lock:
|
||||
self.channels[websocket] = ws_channel
|
||||
|
||||
return ws_channel
|
||||
|
||||
async def on_disconnect(self, websocket: WebSocketType):
|
||||
async def __aiter__(self):
|
||||
"""
|
||||
Call close on the channel if it's not, and remove from channel list
|
||||
Generator for received messages
|
||||
"""
|
||||
# We can not catch any errors here as websocket.recv is
|
||||
# the first to catch any disconnects and bubble it up
|
||||
# so the connection is garbage collected right away
|
||||
while not self.is_closed():
|
||||
yield await self.recv()
|
||||
|
||||
:param websocket: The WebSocket objet attached to the Channel
|
||||
"""
|
||||
with self._lock:
|
||||
channel = self.channels.get(websocket)
|
||||
if channel:
|
||||
logger.info(f"Disconnecting channel {channel}")
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
|
||||
del self.channels[websocket]
|
||||
@asynccontextmanager
|
||||
async def create_channel(
|
||||
websocket: WebSocketType,
|
||||
**kwargs
|
||||
) -> AsyncIterator[WebSocketChannel]:
|
||||
"""
|
||||
Context manager for safely opening and closing a WebSocketChannel
|
||||
"""
|
||||
channel = WebSocketChannel(websocket, **kwargs)
|
||||
try:
|
||||
await channel.accept()
|
||||
logger.info(f"Connected to channel - {channel}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""
|
||||
Disconnect all Channels
|
||||
"""
|
||||
with self._lock:
|
||||
for websocket in self.channels.copy().keys():
|
||||
await self.on_disconnect(websocket)
|
||||
|
||||
async def broadcast(self, message: WSMessageSchemaType):
|
||||
"""
|
||||
Broadcast a message on all Channels
|
||||
|
||||
:param message: The message to send
|
||||
"""
|
||||
with self._lock:
|
||||
for channel in self.channels.copy().values():
|
||||
if channel.subscribed_to(message.get('type')):
|
||||
await self.send_direct(channel, message)
|
||||
|
||||
async def send_direct(
|
||||
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
|
||||
"""
|
||||
Send a message directly through direct_channel only
|
||||
|
||||
:param direct_channel: The WebSocketChannel object to send the message through
|
||||
:param message: The message to send
|
||||
"""
|
||||
if not await channel.send(message):
|
||||
await self.on_disconnect(channel.raw_websocket)
|
||||
|
||||
def has_channels(self):
|
||||
"""
|
||||
Flag for more than 0 channels
|
||||
"""
|
||||
return len(self.channels) > 0
|
||||
yield channel
|
||||
finally:
|
||||
await channel.close()
|
||||
logger.info(f"Disconnected from channel - {channel}")
|
||||
|
31
freqtrade/rpc/api_server/ws/message_stream.py
Normal file
31
freqtrade/rpc/api_server/ws/message_stream.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class MessageStream:
|
||||
"""
|
||||
A message stream for consumers to subscribe to,
|
||||
and for producers to publish to.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._waiter = self._loop.create_future()
|
||||
|
||||
def publish(self, message):
|
||||
"""
|
||||
Publish a message to this MessageStream
|
||||
|
||||
:param message: The message to publish
|
||||
"""
|
||||
waiter, self._waiter = self._waiter, self._loop.create_future()
|
||||
waiter.set_result((message, time.time(), self._waiter))
|
||||
|
||||
async def __aiter__(self):
|
||||
"""
|
||||
Iterate over the messages in the message stream
|
||||
"""
|
||||
waiter = self._waiter
|
||||
while True:
|
||||
# Shield the future from being cancelled by a task waiting on it
|
||||
message, ts, waiter = await asyncio.shield(waiter)
|
||||
yield message, ts
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import orjson
|
||||
import rapidjson
|
||||
@@ -7,6 +8,7 @@ from pandas import DataFrame
|
||||
|
||||
from freqtrade.misc import dataframe_to_json, json_to_dataframe
|
||||
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
|
||||
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,17 +26,13 @@ class WebSocketSerializer(ABC):
|
||||
def _deserialize(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def send(self, data: bytes):
|
||||
async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
|
||||
await self._websocket.send(self._serialize(data))
|
||||
|
||||
async def recv(self) -> bytes:
|
||||
data = await self._websocket.recv()
|
||||
|
||||
return self._deserialize(data)
|
||||
|
||||
async def close(self, code: int = 1000):
|
||||
await self._websocket.close(code)
|
||||
|
||||
|
||||
class HybridJSONWebSocketSerializer(WebSocketSerializer):
|
||||
def _serialize(self, data) -> str:
|
||||
|
@@ -31,6 +31,7 @@ class Producer(TypedDict):
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
secure: bool
|
||||
ws_token: str
|
||||
|
||||
|
||||
@@ -180,7 +181,8 @@ class ExternalMessageConsumer:
|
||||
host, port = producer['host'], producer['port']
|
||||
token = producer['ws_token']
|
||||
name = producer['name']
|
||||
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
|
||||
scheme = 'wss' if producer.get('secure', False) else 'ws'
|
||||
ws_url = f"{scheme}://{host}:{port}/api/v1/message/ws?token={token}"
|
||||
|
||||
# This will raise InvalidURI if the url is bad
|
||||
async with websockets.connect(
|
||||
|
@@ -789,17 +789,18 @@ class RPC:
|
||||
if not order_type:
|
||||
order_type = self._freqtrade.strategy.order_types.get(
|
||||
'force_entry', self._freqtrade.strategy.order_types['entry'])
|
||||
if self._freqtrade.execute_entry(pair, stake_amount, price,
|
||||
ordertype=order_type, trade=trade,
|
||||
is_short=is_short,
|
||||
enter_tag=enter_tag,
|
||||
leverage_=leverage,
|
||||
):
|
||||
Trade.commit()
|
||||
trade = Trade.get_trades([Trade.is_open.is_(True), Trade.pair == pair]).first()
|
||||
return trade
|
||||
else:
|
||||
raise RPCException(f'Failed to enter position for {pair}.')
|
||||
with self._freqtrade._exit_lock:
|
||||
if self._freqtrade.execute_entry(pair, stake_amount, price,
|
||||
ordertype=order_type, trade=trade,
|
||||
is_short=is_short,
|
||||
enter_tag=enter_tag,
|
||||
leverage_=leverage,
|
||||
):
|
||||
Trade.commit()
|
||||
trade = Trade.get_trades([Trade.is_open.is_(True), Trade.pair == pair]).first()
|
||||
return trade
|
||||
else:
|
||||
raise RPCException(f'Failed to enter position for {pair}.')
|
||||
|
||||
def _rpc_delete(self, trade_id: int) -> Dict[str, Union[str, int]]:
|
||||
"""
|
||||
|
@@ -19,7 +19,7 @@ class FreqaiExampleHybridStrategy(IStrategy):
|
||||
|
||||
Launching this strategy would be:
|
||||
|
||||
freqtrade trade --strategy FreqaiExampleHyridStrategy --strategy-path freqtrade/templates
|
||||
freqtrade trade --strategy FreqaiExampleHybridStrategy --strategy-path freqtrade/templates
|
||||
--freqaimodel CatboostClassifier --config config_examples/config_freqai.example.json
|
||||
|
||||
or the user simply adds this to their config:
|
||||
@@ -86,7 +86,7 @@ class FreqaiExampleHybridStrategy(IStrategy):
|
||||
process_only_new_candles = True
|
||||
stoploss = -0.05
|
||||
use_exit_signal = True
|
||||
startup_candle_count: int = 300
|
||||
startup_candle_count: int = 30
|
||||
can_short = True
|
||||
|
||||
# Hyperoptable parameters
|
||||
|
@@ -328,7 +328,7 @@
|
||||
"# Show graph inline\n",
|
||||
"# graph.show()\n",
|
||||
"\n",
|
||||
"# Render graph in a seperate window\n",
|
||||
"# Render graph in a separate window\n",
|
||||
"graph.show(renderer=\"browser\")\n"
|
||||
]
|
||||
},
|
||||
|
Reference in New Issue
Block a user