Merge branch 'develop' into backtest_fitlivepredictions

This commit is contained in:
Wagner Costa
2022-11-29 09:39:15 -03:00
63 changed files with 3805 additions and 1804 deletions

View File

@@ -1,5 +1,5 @@
""" Freqtrade bot """
__version__ = '2022.11.dev'
__version__ = '2022.12.dev'
if 'dev' in __version__:
try:

View File

@@ -512,6 +512,7 @@ CONF_SCHEMA = {
'minimum': 0,
'maximum': 65535
},
'secure': {'type': 'boolean', 'default': False},
'ws_token': {'type': 'string'},
},
'required': ['name', 'host', 'ws_token']
@@ -577,9 +578,27 @@ CONF_SCHEMA = {
},
},
"model_training_parameters": {
"type": "object"
},
"rl_config": {
"type": "object",
"properties": {
"n_estimators": {"type": "integer", "default": 1000}
"train_cycles": {"type": "integer"},
"max_trade_duration_candles": {"type": "integer"},
"add_state_info": {"type": "boolean", "default": False},
"max_training_drawdown_pct": {"type": "number", "default": 0.02},
"cpu_count": {"type": "integer", "default": 1},
"model_type": {"type": "string", "default": "PPO"},
"policy_type": {"type": "string", "default": "MlpPolicy"},
"net_arch": {"type": "array", "default": [128, 128]},
"randomize_startinng_position": {"type": "boolean", "default": False},
"model_reward_parameters": {
"type": "object",
"properties": {
"rr": {"type": "number", "default": 1},
"profit_aim": {"type": "number", "default": 0.025}
}
}
},
},
},

File diff suppressed because it is too large Load Diff

View File

@@ -20,7 +20,7 @@ class Bybit(Exchange):
"""
_ft_has: Dict = {
"ohlcv_candle_limit": 200,
"ohlcv_candle_limit": 1000,
"ccxt_futures_name": "linear",
"ohlcv_has_history": False,
}

View File

@@ -218,3 +218,19 @@ class Kraken(Exchange):
fees = sum(df['open_fund'] * df['open_mark'] * amount * time_in_ratio)
return fees if is_short else -fees
def _trades_contracts_to_amount(self, trades: List) -> List:
"""
Fix "last" id issue for kraken data downloads
This whole override can probably be removed once the following
issue is closed in ccxt: https://github.com/ccxt/ccxt/issues/15827
"""
super()._trades_contracts_to_amount(trades)
if (
len(trades) > 0
and isinstance(trades[-1].get('info'), list)
and len(trades[-1].get('info', [])) > 7
):
trades[-1]['id'] = trades[-1].get('info', [])[-1]
return trades

View File

@@ -0,0 +1,135 @@
import logging
from enum import Enum
from gym import spaces
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
logger = logging.getLogger(__name__)
class Actions(Enum):
Neutral = 0
Exit = 1
Long_enter = 2
Short_enter = 3
class Base4ActionRLEnv(BaseEnvironment):
"""
Base class for a 4 action environment
"""
def set_action_space(self):
self.action_space = spaces.Discrete(len(Actions))
def step(self, action: int):
"""
Logic for a single step (incrementing one candle in time)
by the agent
:param: action: int = the action type that the agent plans
to take for the current step.
:returns:
observation = current state of environment
step_reward = the reward from `calculate_reward()`
_done = if the agent "died" or if the candles finished
info = dict passed back to openai gym lib
"""
self._done = False
self._current_tick += 1
if self._current_tick == self._end_tick:
self._done = True
self._update_unrealized_total_profit()
step_reward = self.calculate_reward(action)
self.total_reward += step_reward
trade_type = None
if self.is_tradesignal(action):
"""
Action: Neutral, position: Long -> Close Long
Action: Neutral, position: Short -> Close Short
Action: Long, position: Neutral -> Open Long
Action: Long, position: Short -> Close Short and Open Long
Action: Short, position: Neutral -> Open Short
Action: Short, position: Long -> Close Long and Open Short
"""
if action == Actions.Neutral.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif action == Actions.Long_enter.value:
self._position = Positions.Long
trade_type = "long"
self._last_trade_tick = self._current_tick
elif action == Actions.Short_enter.value:
self._position = Positions.Short
trade_type = "short"
self._last_trade_tick = self._current_tick
elif action == Actions.Exit.value:
self._update_total_profit()
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
else:
print("case not defined")
if trade_type is not None:
self.trade_history.append(
{'price': self.current_price(), 'index': self._current_tick,
'type': trade_type})
if self._total_profit < 1 - self.rl_config.get('max_training_drawdown_pct', 0.8):
self._done = True
self._position_history.append(self._position)
info = dict(
tick=self._current_tick,
total_reward=self.total_reward,
total_profit=self._total_profit,
position=self._position.value
)
observation = self._get_observation()
self._update_history(info)
return observation, step_reward, self._done, info
def is_tradesignal(self, action: int) -> bool:
"""
Determine if the signal is a trade signal
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
"""
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
(action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Short_enter.value and self._position == Positions.Short) or
(action == Actions.Short_enter.value and self._position == Positions.Long) or
(action == Actions.Exit.value and self._position == Positions.Neutral) or
(action == Actions.Long_enter.value and self._position == Positions.Long) or
(action == Actions.Long_enter.value and self._position == Positions.Short))
def _is_valid(self, action: int) -> bool:
"""
Determine if the signal is valid.
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
"""
# Agent should only try to exit if it is in position
if action == Actions.Exit.value:
if self._position not in (Positions.Short, Positions.Long):
return False
# Agent should only try to enter if it is not in position
if action in (Actions.Short_enter.value, Actions.Long_enter.value):
if self._position != Positions.Neutral:
return False
return True

View File

@@ -0,0 +1,145 @@
import logging
from enum import Enum
from gym import spaces
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
logger = logging.getLogger(__name__)
class Actions(Enum):
Neutral = 0
Long_enter = 1
Long_exit = 2
Short_enter = 3
Short_exit = 4
class Base5ActionRLEnv(BaseEnvironment):
"""
Base class for a 5 action environment
"""
def set_action_space(self):
self.action_space = spaces.Discrete(len(Actions))
def step(self, action: int):
"""
Logic for a single step (incrementing one candle in time)
by the agent
:param: action: int = the action type that the agent plans
to take for the current step.
:returns:
observation = current state of environment
step_reward = the reward from `calculate_reward()`
_done = if the agent "died" or if the candles finished
info = dict passed back to openai gym lib
"""
self._done = False
self._current_tick += 1
if self._current_tick == self._end_tick:
self._done = True
self._update_unrealized_total_profit()
step_reward = self.calculate_reward(action)
self.total_reward += step_reward
trade_type = None
if self.is_tradesignal(action):
"""
Action: Neutral, position: Long -> Close Long
Action: Neutral, position: Short -> Close Short
Action: Long, position: Neutral -> Open Long
Action: Long, position: Short -> Close Short and Open Long
Action: Short, position: Neutral -> Open Short
Action: Short, position: Long -> Close Long and Open Short
"""
if action == Actions.Neutral.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif action == Actions.Long_enter.value:
self._position = Positions.Long
trade_type = "long"
self._last_trade_tick = self._current_tick
elif action == Actions.Short_enter.value:
self._position = Positions.Short
trade_type = "short"
self._last_trade_tick = self._current_tick
elif action == Actions.Long_exit.value:
self._update_total_profit()
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif action == Actions.Short_exit.value:
self._update_total_profit()
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
else:
print("case not defined")
if trade_type is not None:
self.trade_history.append(
{'price': self.current_price(), 'index': self._current_tick,
'type': trade_type})
if (self._total_profit < self.max_drawdown or
self._total_unrealized_profit < self.max_drawdown):
self._done = True
self._position_history.append(self._position)
info = dict(
tick=self._current_tick,
total_reward=self.total_reward,
total_profit=self._total_profit,
position=self._position.value
)
observation = self._get_observation()
self._update_history(info)
return observation, step_reward, self._done, info
def is_tradesignal(self, action: int) -> bool:
"""
Determine if the signal is a trade signal
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
"""
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
(action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Short_enter.value and self._position == Positions.Short) or
(action == Actions.Short_enter.value and self._position == Positions.Long) or
(action == Actions.Short_exit.value and self._position == Positions.Long) or
(action == Actions.Short_exit.value and self._position == Positions.Neutral) or
(action == Actions.Long_enter.value and self._position == Positions.Long) or
(action == Actions.Long_enter.value and self._position == Positions.Short) or
(action == Actions.Long_exit.value and self._position == Positions.Short) or
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
def _is_valid(self, action: int) -> bool:
# trade signal
"""
Determine if the signal is valid.
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
"""
# Agent should only try to exit if it is in position
if action in (Actions.Short_exit.value, Actions.Long_exit.value):
if self._position not in (Positions.Short, Positions.Long):
return False
# Agent should only try to enter if it is not in position
if action in (Actions.Short_enter.value, Actions.Long_enter.value):
if self._position != Positions.Neutral:
return False
return True

View File

@@ -0,0 +1,307 @@
import logging
import random
from abc import abstractmethod
from enum import Enum
from typing import Optional
import gym
import numpy as np
import pandas as pd
from gym import spaces
from gym.utils import seeding
from pandas import DataFrame
from freqtrade.data.dataprovider import DataProvider
logger = logging.getLogger(__name__)
class Positions(Enum):
Short = 0
Long = 1
Neutral = 0.5
def opposite(self):
return Positions.Short if self == Positions.Long else Positions.Long
class BaseEnvironment(gym.Env):
"""
Base class for environments. This class is agnostic to action count.
Inherited classes customize this to include varying action counts/types,
See RL/Base5ActionRLEnv.py and RL/Base4ActionRLEnv.py
"""
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
reward_kwargs: dict = {}, window_size=10, starting_point=True,
id: str = 'baseenv-1', seed: int = 1, config: dict = {},
dp: Optional[DataProvider] = None):
"""
Initializes the training/eval environment.
:param df: dataframe of features
:param prices: dataframe of prices to be used in the training environment
:param window_size: size of window (temporal) to pass to the agent
:param reward_kwargs: extra config settings assigned by user in `rl_config`
:param starting_point: start at edge of window or not
:param id: string id of the environment (used in backend for multiprocessed env)
:param seed: Sets the seed of the environment higher in the gym.Env object
:param config: Typical user configuration file
:param dp: dataprovider from freqtrade
"""
self.config = config
self.rl_config = config['freqai']['rl_config']
self.add_state_info = self.rl_config.get('add_state_info', False)
self.id = id
self.seed(seed)
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
self.compound_trades = config['stake_amount'] == 'unlimited'
if self.config.get('fee', None) is not None:
self.fee = self.config['fee']
elif dp is not None:
self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0]) # type: ignore
else:
self.fee = 0.0015
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
reward_kwargs: dict, starting_point=True):
"""
Resets the environment when the agent fails (in our case, if the drawdown
exceeds the user set max_training_drawdown_pct)
:param df: dataframe of features
:param prices: dataframe of prices to be used in the training environment
:param window_size: size of window (temporal) to pass to the agent
:param reward_kwargs: extra config settings assigned by user in `rl_config`
:param starting_point: start at edge of window or not
"""
self.df = df
self.signal_features = self.df
self.prices = prices
self.window_size = window_size
self.starting_point = starting_point
self.rr = reward_kwargs["rr"]
self.profit_aim = reward_kwargs["profit_aim"]
# # spaces
if self.add_state_info:
self.total_features = self.signal_features.shape[1] + 3
else:
self.total_features = self.signal_features.shape[1]
self.shape = (window_size, self.total_features)
self.set_action_space()
self.observation_space = spaces.Box(
low=-1, high=1, shape=self.shape, dtype=np.float32)
# episode
self._start_tick: int = self.window_size
self._end_tick: int = len(self.prices) - 1
self._done: bool = False
self._current_tick: int = self._start_tick
self._last_trade_tick: Optional[int] = None
self._position = Positions.Neutral
self._position_history: list = [None]
self.total_reward: float = 0
self._total_profit: float = 1
self._total_unrealized_profit: float = 1
self.history: dict = {}
self.trade_history: list = []
@abstractmethod
def set_action_space(self):
"""
Unique to the environment action count. Must be inherited.
"""
def seed(self, seed: int = 1):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def reset(self):
self._done = False
if self.starting_point is True:
if self.rl_config.get('randomize_starting_position', False):
length_of_data = int(self._end_tick / 4)
start_tick = random.randint(self.window_size + 1, length_of_data)
self._start_tick = start_tick
self._position_history = (self._start_tick * [None]) + [self._position]
else:
self._position_history = (self.window_size * [None]) + [self._position]
self._current_tick = self._start_tick
self._last_trade_tick = None
self._position = Positions.Neutral
self.total_reward = 0.
self._total_profit = 1. # unit
self.history = {}
self.trade_history = []
self.portfolio_log_returns = np.zeros(len(self.prices))
self._profits = [(self._start_tick, 1)]
self.close_trade_profit = []
self._total_unrealized_profit = 1
return self._get_observation()
@abstractmethod
def step(self, action: int):
"""
Step depeneds on action types, this must be inherited.
"""
return
def _get_observation(self):
"""
This may or may not be independent of action types, user can inherit
this in their custom "MyRLEnv"
"""
features_window = self.signal_features[(
self._current_tick - self.window_size):self._current_tick]
if self.add_state_info:
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
columns=['current_profit_pct',
'position',
'trade_duration'],
index=features_window.index)
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
features_and_state['position'] = self._position.value
features_and_state['trade_duration'] = self.get_trade_duration()
features_and_state = pd.concat([features_window, features_and_state], axis=1)
return features_and_state
else:
return features_window
def get_trade_duration(self):
"""
Get the trade duration if the agent is in a trade
"""
if self._last_trade_tick is None:
return 0
else:
return self._current_tick - self._last_trade_tick
def get_unrealized_profit(self):
"""
Get the unrealized profit if the agent is in a trade
"""
if self._last_trade_tick is None:
return 0.
if self._position == Positions.Neutral:
return 0.
elif self._position == Positions.Short:
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
return (last_trade_price - current_price) / last_trade_price
elif self._position == Positions.Long:
current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
return (current_price - last_trade_price) / last_trade_price
else:
return 0.
@abstractmethod
def is_tradesignal(self, action: int) -> bool:
"""
Determine if the signal is a trade signal. This is
unique to the actions in the environment, and therefore must be
inherited.
"""
return True
def _is_valid(self, action: int) -> bool:
"""
Determine if the signal is valid.This is
unique to the actions in the environment, and therefore must be
inherited.
"""
return True
def add_entry_fee(self, price):
return price * (1 + self.fee)
def add_exit_fee(self, price):
return price / (1 + self.fee)
def _update_history(self, info):
if not self.history:
self.history = {key: [] for key in info.keys()}
for key, value in info.items():
self.history[key].append(value)
@abstractmethod
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:param action: int = The action made by the agent for the current candle.
:return:
float = the reward to give to the agent for current step (used for optimization
of weights in NN)
"""
def _update_unrealized_total_profit(self):
"""
Update the unrealized total profit incase of episode end.
"""
if self._position in (Positions.Long, Positions.Short):
pnl = self.get_unrealized_profit()
if self.compound_trades:
# assumes unit stake and compounding
unrl_profit = self._total_profit * (1 + pnl)
else:
# assumes unit stake and no compounding
unrl_profit = self._total_profit + pnl
self._total_unrealized_profit = unrl_profit
def _update_total_profit(self):
pnl = self.get_unrealized_profit()
if self.compound_trades:
# assumes unit stake and compounding
self._total_profit = self._total_profit * (1 + pnl)
else:
# assumes unit stake and no compounding
self._total_profit += pnl
def current_price(self) -> float:
return self.prices.iloc[self._current_tick].open
# Keeping around incase we want to start building more complex environment
# templates in the future.
# def most_recent_return(self):
# """
# Calculate the tick to tick return if in a trade.
# Return is generated from rising prices in Long
# and falling prices in Short positions.
# The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
# """
# # Long positions
# if self._position == Positions.Long:
# current_price = self.prices.iloc[self._current_tick].open
# previous_price = self.prices.iloc[self._current_tick - 1].open
# if (self._position_history[self._current_tick - 1] == Positions.Short
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
# previous_price = self.add_entry_fee(previous_price)
# return np.log(current_price) - np.log(previous_price)
# # Short positions
# if self._position == Positions.Short:
# current_price = self.prices.iloc[self._current_tick].open
# previous_price = self.prices.iloc[self._current_tick - 1].open
# if (self._position_history[self._current_tick - 1] == Positions.Long
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
# previous_price = self.add_exit_fee(previous_price)
# return np.log(previous_price) - np.log(current_price)
# return 0
# def update_portfolio_log_returns(self, action):
# self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)

View File

@@ -0,0 +1,396 @@
import importlib
import logging
from abc import abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import gym
import numpy as np
import numpy.typing as npt
import pandas as pd
import torch as th
import torch.multiprocessing
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import Positions
from freqtrade.persistence import Trade
logger = logging.getLogger(__name__)
torch.multiprocessing.set_sharing_strategy('file_system')
SB3_MODELS = ['PPO', 'A2C', 'DQN']
SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
class BaseReinforcementLearningModel(IFreqaiModel):
"""
User created Reinforcement Learning Model prediction class
"""
def __init__(self, **kwargs) -> None:
super().__init__(config=kwargs['config'])
self.max_threads = min(self.freqai_info['rl_config'].get(
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
th.set_num_threads(self.max_threads)
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_callback: Optional[EvalCallback] = None
self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config']
self.continual_learning = self.freqai_info.get('continual_learning', False)
if self.model_type in SB3_MODELS:
import_str = 'stable_baselines3'
elif self.model_type in SB3_CONTRIB_MODELS:
import_str = 'sb3_contrib'
else:
raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
f'sb3_contrib. please choose one of {SB3_MODELS} or '
f'{SB3_CONTRIB_MODELS}')
mod = importlib.import_module(import_str, self.model_type)
self.MODELCLASS = getattr(mod, self.model_type)
self.policy_type = self.freqai_info['rl_config']['policy_type']
self.unset_outlier_removal()
self.net_arch = self.rl_config.get('net_arch', [128, 128])
self.dd.model_type = "stable_baselines"
def unset_outlier_removal(self):
"""
If user has activated any function that may remove training points, this
function will set them to false and warn them
"""
if self.ft_params.get('use_SVM_to_remove_outliers', False):
self.ft_params.update({'use_SVM_to_remove_outliers': False})
logger.warning('User tried to use SVM with RL. Deactivating SVM.')
if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
if self.freqai_info['data_split_parameters'].get('shuffle', False):
self.freqai_info['data_split_parameters'].update({'shuffle': False})
logger.warning('User tried to shuffle training data. Setting shuffle to False')
def train(
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
) -> Any:
"""
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
for storing, saving, loading, and analyzing the data.
:param unfiltered_df: Full dataframe for the current training period
:param metadata: pair metadata from strategy.
:returns:
:model: Trained model which can be used to inference (self.predict)
"""
logger.info("--------------------Starting training " f"{pair} --------------------")
features_filtered, labels_filtered = dk.filter_features(
unfiltered_df,
dk.training_features_list,
dk.label_list,
training_filter=True,
)
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
features_filtered, labels_filtered)
dk.fit_labels() # FIXME useless for now, but just satiating append methods
# normalize all data based on train_dataset only
prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk)
data_dictionary = dk.normalize_data(data_dictionary)
# data cleaning/analysis
self.data_cleaning_train(dk)
logger.info(
f'Training model on {len(dk.data_dictionary["train_features"].columns)}'
f' features and {len(data_dictionary["train_features"])} data points'
)
self.set_train_and_eval_environments(data_dictionary, prices_train, prices_test, dk)
model = self.fit(data_dictionary, dk)
logger.info(f"--------------------done training {pair}--------------------")
return model
def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
User can override this if they are using a custom MyRLEnv
:param data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
:param prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
environment during training or testing
:param dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
self.train_env = self.MyRLEnv(df=train_df,
prices=prices_train,
window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params,
config=self.config,
dp=self.data_provider)
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
prices=prices_test,
window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params,
config=self.config,
dp=self.data_provider))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
@abstractmethod
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
"""
Agent customizations and abstract Reinforcement Learning customizations
go in here. Abstract method, so this function must be overridden by
user class.
"""
return
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
"""
State info during dry/live (not backtesting) which is fed back
into the model.
:param pair: str = COIN/STAKE to get the environment information for
:return:
:market_side: float = representing short, long, or neutral for
pair
:current_profit: float = unrealized profit of the current trade
:trade_duration: int = the number of candles that the trade has
been open for
"""
open_trades = Trade.get_trades_proxy(is_open=True)
market_side = 0.5
current_profit: float = 0
trade_duration = 0
for trade in open_trades:
if trade.pair == pair:
if self.data_provider._exchange is None: # type: ignore
logger.error('No exchange available.')
return 0, 0, 0
else:
current_rate = self.data_provider._exchange.get_rate( # type: ignore
pair, refresh=False, side="exit", is_short=trade.is_short)
now = datetime.now(timezone.utc).timestamp()
trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
current_profit = trade.calc_profit_ratio(current_rate)
return market_side, current_profit, int(trade_duration)
def predict(
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
"""
Filter the prediction features data and predict with it.
:param unfiltered_dataframe: Full dataframe for the current backtest period.
:return:
:pred_df: dataframe containing the predictions
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index)
"""
dk.find_features(unfiltered_df)
filtered_dataframe, _ = dk.filter_features(
unfiltered_df, dk.training_features_list, training_filter=False
)
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
dk.data_dictionary["prediction_features"] = filtered_dataframe
# optional additional data cleaning/analysis
self.data_cleaning_predict(dk)
pred_df = self.rl_model_predict(
dk.data_dictionary["prediction_features"], dk, self.model)
pred_df.fillna(0, inplace=True)
return (pred_df, dk.do_predict)
def rl_model_predict(self, dataframe: DataFrame,
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
"""
A helper function to make predictions in the Reinforcement learning module.
:param dataframe: DataFrame = the dataframe of features to make the predictions on
:param dk: FreqaiDatakitchen = data kitchen for the current pair
:param model: Any = the trained model used to inference the features.
"""
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
def _predict(window):
observations = dataframe.iloc[window.index]
if self.live and self.rl_config.get('add_state_info', False):
market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations['current_profit_pct'] = current_profit
observations['position'] = market_side
observations['trade_duration'] = trade_duration
res, _ = model.predict(observations, deterministic=True)
return res
output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
return output
def build_ohlc_price_dataframes(self, data_dictionary: dict,
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
DataFrame]:
"""
Builds the train prices and test prices for the environment.
"""
pair = pair.replace(':', '')
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
# price data for model training and evaluation
tf = self.config['timeframe']
ohlc_list = [f'%-{pair}raw_open_{tf}', f'%-{pair}raw_low_{tf}',
f'%-{pair}raw_high_{tf}', f'%-{pair}raw_close_{tf}']
rename_dict = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
f'%-{pair}raw_high_{tf}': ' high', f'%-{pair}raw_close_{tf}': 'close'}
prices_train = train_df.filter(ohlc_list, axis=1)
if prices_train.empty:
raise OperationalException('Reinforcement learning module didnt find the raw prices '
'assigned in populate_any_indicators. Please assign them '
'with:\n'
'informative[f"%-{pair}raw_close"] = informative["close"]\n'
'informative[f"%-{pair}raw_open"] = informative["open"]\n'
'informative[f"%-{pair}raw_high"] = informative["high"]\n'
'informative[f"%-{pair}raw_low"] = informative["low"]\n')
prices_train.rename(columns=rename_dict, inplace=True)
prices_train.reset_index(drop=True)
prices_test = test_df.filter(ohlc_list, axis=1)
prices_test.rename(columns=rename_dict, inplace=True)
prices_test.reset_index(drop=True)
return prices_train, prices_test
def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any:
"""
Can be used by user if they are trying to limit_ram_usage *and*
perform continual learning.
For now, this is unused.
"""
exists = Path(dk.data_path / f"{dk.model_filename}_model").is_file()
if exists:
model = self.MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
else:
logger.info('No model file on disk to continue learning from.')
return model
def _on_stop(self):
"""
Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
"""
if self.train_env:
self.train_env.close()
if self.eval_env:
self.eval_env.close()
# Nested class which can be overridden by user to customize further
class MyRLEnv(Base5ActionRLEnv):
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:param action: int = The action made by the agent for the current candle.
:return:
float = the reward to give to the agent for current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
factor = 100.
# reward agent for entering trades
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
if self._last_trade_tick:
trade_duration = self._current_tick - self._last_trade_tick
else:
trade_duration = 0
if trade_duration <= max_trade_duration:
factor *= 1.5
elif trade_duration > max_trade_duration:
factor *= 0.5
# discourage sitting in position
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(pnl * factor)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(pnl * factor)
return 0.
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame,
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
config: Dict[str, Any] = {}) -> Callable:
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environment you wish to have in subprocesses
:param seed: (int) the inital seed for RNG
:param rank: (int) index of the subprocess
:return: (Callable)
"""
def _init() -> gym.Env:
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
if monitor:
env = Monitor(env)
return env
set_random_seed(seed)
return _init

View File

View File

@@ -1,4 +1,5 @@
import collections
import importlib
import logging
import re
import shutil
@@ -99,6 +100,7 @@ class FreqaiDataDrawer:
self.empty_pair_dict: pair_info = {
"model_filename": "", "trained_timestamp": 0,
"data_path": "", "extras": {}}
self.model_type = self.freqai_info.get('model_save_type', 'joblib')
def update_metric_tracker(self, metric: str, value: float, pair: str) -> None:
"""
@@ -497,10 +499,12 @@ class FreqaiDataDrawer:
save_path = Path(dk.data_path)
# Save the trained model
if not dk.keras:
if self.model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib")
else:
elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5")
elif 'stable_baselines' in self.model_type:
model.save(save_path / f"{dk.model_filename}_model.zip")
if dk.svm_model is not None:
dump(dk.svm_model, save_path / f"{dk.model_filename}_svm_model.joblib")
@@ -527,11 +531,10 @@ class FreqaiDataDrawer:
dk.pca, open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "wb")
)
# if self.live:
# store as much in ram as possible to increase performance
self.model_dictionary[coin] = model
self.pair_dict[coin]["model_filename"] = dk.model_filename
self.pair_dict[coin]["data_path"] = str(dk.data_path)
if coin not in self.meta_data_dictionary:
self.meta_data_dictionary[coin] = {}
self.meta_data_dictionary[coin]["train_df"] = dk.data_dictionary["train_features"]
@@ -563,14 +566,6 @@ class FreqaiDataDrawer:
if dk.live:
dk.model_filename = self.pair_dict[coin]["model_filename"]
dk.data_path = Path(self.pair_dict[coin]["data_path"])
if self.freqai_info.get("follow_mode", False):
# follower can be on a different system which is rsynced from the leader:
dk.data_path = Path(
self.config["user_data_dir"]
/ "models"
/ dk.data_path.parts[-2]
/ dk.data_path.parts[-1]
)
if coin in self.meta_data_dictionary:
dk.data = self.meta_data_dictionary[coin]["meta_data"]
@@ -589,12 +584,16 @@ class FreqaiDataDrawer:
# try to access model in memory instead of loading object from disk to save time
if dk.live and coin in self.model_dictionary:
model = self.model_dictionary[coin]
elif not dk.keras:
elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
else:
elif self.model_type == 'keras':
from tensorflow import keras
model = keras.models.load_model(dk.data_path / f"{dk.model_filename}_model.h5")
elif self.model_type == 'stable_baselines':
mod = importlib.import_module(
'stable_baselines3', self.freqai_info['rl_config']['model_type'])
MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
model = MODELCLASS.load(dk.data_path / f"{dk.model_filename}_model")
if Path(dk.data_path / f"{dk.model_filename}_svm_model.joblib").is_file():
dk.svm_model = load(dk.data_path / f"{dk.model_filename}_svm_model.joblib")
@@ -604,6 +603,10 @@ class FreqaiDataDrawer:
f"Unable to load model, ensure model exists at " f"{dk.data_path} "
)
# load it into ram if it was loaded from disk
if coin not in self.model_dictionary:
self.model_dictionary[coin] = model
if self.config["freqai"]["feature_parameters"]["principal_component_analysis"]:
dk.pca = cloudpickle.load(
open(dk.data_path / f"{dk.model_filename}_pca_object.pkl", "rb")

View File

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import numpy.typing as npt
import pandas as pd
import psutil
from pandas import DataFrame, HDFStore
from scipy import stats
from sklearn import linear_model
@@ -98,7 +99,10 @@ class FreqaiDataKitchen:
)
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
if not self.freqai_config.get("data_kitchen_thread_count", 0):
self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
else:
self.thread_count = self.freqai_config["data_kitchen_thread_count"]
self.train_dates: DataFrame = pd.DataFrame()
self.unique_classes: Dict[str, list] = {}
self.unique_class_list: list = []

View File

@@ -5,15 +5,17 @@ from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple
import numpy as np
import pandas as pd
import psutil
from numpy.typing import NDArray
from pandas import DataFrame
from freqtrade.configuration import TimeRange
from freqtrade.constants import Config
from freqtrade.data.dataprovider import DataProvider
from freqtrade.enums import RunMode
from freqtrade.exceptions import OperationalException
from freqtrade.exchange import timeframe_to_seconds
@@ -101,6 +103,8 @@ class IFreqaiModel(ABC):
self._threads: List[threading.Thread] = []
self._stop_event = threading.Event()
self.metadata = self.dd.load_global_metadata_from_disk()
self.data_provider: Optional[DataProvider] = None
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
record_params(config, self.full_path)
@@ -129,6 +133,7 @@ class IFreqaiModel(ABC):
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
self.dd.set_pair_dict_info(metadata)
self.data_provider = strategy.dp
if self.live:
self.inference_timer('start')
@@ -175,6 +180,13 @@ class IFreqaiModel(ABC):
self.model = None
self.dk = None
def _on_stop(self):
"""
Callback for Subclasses to override to include logic for shutting down resources
when SIGINT is sent.
"""
return
def shutdown(self):
"""
Cleans up threads on Shutdown, set stop event. Join threads to wait
@@ -183,6 +195,9 @@ class IFreqaiModel(ABC):
logger.info("Stopping FreqAI")
self._stop_event.set()
self.data_provider = None
self._on_stop()
logger.info("Waiting on Training iteration")
for _thread in self._threads:
_thread.join()
@@ -663,7 +678,7 @@ class IFreqaiModel(ABC):
hist_preds_df['DI_values'] = 0
for return_str in dk.data['extra_returns_per_train']:
hist_preds_df[return_str] = 0
hist_preds_df[return_str] = dk.data['extra_returns_per_train'][return_str]
hist_preds_df['close_price'] = strat_df['close']
hist_preds_df['date_pred'] = strat_df['date']

View File

@@ -0,0 +1,141 @@
import logging
from pathlib import Path
from typing import Any, Dict
import torch as th
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
logger = logging.getLogger(__name__)
class ReinforcementLearner(BaseReinforcementLearningModel):
"""
Reinforcement Learning Model prediction model.
Users can inherit from this class to make their own RL model with custom
environment/training controls. Define the file as follows:
```
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
class MyCoolRLModel(ReinforcementLearner):
```
Save the file to `user_data/freqaimodels`, then run it with:
freqtrade trade --freqaimodel MyCoolRLModel --config config.json --strategy SomeCoolStrat
Here the users can override any of the functions
available in the `IFreqaiModel` inheritance tree. Most importantly for RL, this
is where the user overrides `MyRLEnv` (see below), to define custom
`calculate_reward()` function, or to override any other parts of the environment.
This class also allows users to override any other part of the IFreqaiModel tree.
For example, the user can override `def fit()` or `def train()` or `def predict()`
to take fine-tuned control over these processes.
Another common override may be `def data_cleaning_predict()` where the user can
take fine-tuned control over the data handling pipeline.
"""
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
"""
User customizable fit method
:param data_dictionary: dict = common data dictionary containing all train/test
features/labels/weights.
:param dk: FreqaiDatakitchen = data kitchen for current pair.
:return:
model Any = trained model to be used for inference in dry/live/backtesting
"""
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=self.net_arch)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
tensorboard_log=Path(
dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
**self.freqai_info['model_training_parameters']
)
else:
logger.info('Continual training activated - starting training from previously '
'trained agent.')
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
model.learn(
total_timesteps=int(total_timesteps),
callback=self.eval_callback
)
if Path(dk.data_path / "best_model.zip").is_file():
logger.info('Callback found a best model.')
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
return best_model
logger.info('Couldnt find best model, using final model instead.')
return model
class MyRLEnv(Base5ActionRLEnv):
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:param action: int = The action made by the agent for the current candle.
:return:
float = the reward to give to the agent for current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
factor = 100.
# reward agent for entering trades
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
if trade_duration <= max_trade_duration:
factor *= 1.5
elif trade_duration > max_trade_duration:
factor *= 0.5
# discourage sitting in position
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long
if action == Actions.Long_exit.value and self._position == Positions.Long:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(pnl * factor)
# close short
if action == Actions.Short_exit.value and self._position == Positions.Short:
if pnl > self.profit_aim * self.rr:
factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
return float(pnl * factor)
return 0.

View File

@@ -0,0 +1,51 @@
import logging
from typing import Any, Dict # , Tuple
# import numpy.typing as npt
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
logger = logging.getLogger(__name__)
class ReinforcementLearner_multiproc(ReinforcementLearner):
"""
Demonstration of how to build vectorized environments
"""
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
User can override this if they are using a custom MyRLEnv
:param data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
:param prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training
or testing
:param dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
env_id = "train_env"
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config) for i
in range(self.max_threads)])
eval_env_id = 'eval_env'
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
test_df, prices_test,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config) for i
in range(self.max_threads)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))

View File

@@ -191,10 +191,10 @@ class FreqtradeBot(LoggingMixin):
# Check whether markets have to be reloaded and reload them when it's needed
self.exchange.reload_markets()
self.update_closed_trades_without_assigned_fees()
self.update_trades_without_assigned_fees()
# Query trades from persistence layer
trades = Trade.get_open_trades()
trades: List[Trade] = Trade.get_open_trades()
self.active_pair_whitelist = self._refresh_active_whitelist(trades)
@@ -354,7 +354,7 @@ class FreqtradeBot(LoggingMixin):
if self.trading_mode == TradingMode.FUTURES:
self._schedule.run_pending()
def update_closed_trades_without_assigned_fees(self) -> None:
def update_trades_without_assigned_fees(self) -> None:
"""
Update closed trades without close fees assigned.
Only acts when Orders are in the database, otherwise the last order-id is unknown.
@@ -381,15 +381,16 @@ class FreqtradeBot(LoggingMixin):
trades = Trade.get_open_trades_without_assigned_fees()
for trade in trades:
if trade.is_open and not trade.fee_updated(trade.entry_side):
order = trade.select_order(trade.entry_side, False)
open_order = trade.select_order(trade.entry_side, True)
if order and open_order is None:
logger.info(
f"Updating {trade.entry_side}-fee on trade {trade}"
f"for order {order.order_id}."
)
self.update_trade_state(trade, order.order_id, send_msg=False)
with self._exit_lock:
if trade.is_open and not trade.fee_updated(trade.entry_side):
order = trade.select_order(trade.entry_side, False)
open_order = trade.select_order(trade.entry_side, True)
if order and open_order is None:
logger.info(
f"Updating {trade.entry_side}-fee on trade {trade}"
f"for order {order.order_id}."
)
self.update_trade_state(trade, order.order_id, send_msg=False)
def handle_insufficient_funds(self, trade: Trade):
"""
@@ -826,6 +827,8 @@ class FreqtradeBot(LoggingMixin):
co = self.exchange.cancel_stoploss_order_with_result(
trade.stoploss_order_id, trade.pair, trade.amount)
trade.update_order(co)
# Reset stoploss order id.
trade.stoploss_order_id = None
except InvalidOrderException:
logger.exception(f"Could not cancel stoploss order {trade.stoploss_order_id}")
return trade
@@ -982,7 +985,7 @@ class FreqtradeBot(LoggingMixin):
# SELL / exit positions / close trades logic and methods
#
def exit_positions(self, trades: List[Any]) -> int:
def exit_positions(self, trades: List[Trade]) -> int:
"""
Tries to execute exit orders for open trades (positions)
"""
@@ -1010,7 +1013,7 @@ class FreqtradeBot(LoggingMixin):
def handle_trade(self, trade: Trade) -> bool:
"""
Sells/exits_short the current pair if the threshold is reached and updates the trade record.
Exits the current pair if the threshold is reached and updates the trade record.
:return: True if trade has been sold/exited_short, False otherwise
"""
if not trade.is_open:
@@ -1148,7 +1151,7 @@ class FreqtradeBot(LoggingMixin):
stoploss = (
self.edge.stoploss(pair=trade.pair)
if self.edge else
self.strategy.stoploss / trade.leverage
trade.stop_loss_pct / trade.leverage
)
if trade.is_short:
stop_price = trade.open_rate * (1 - stoploss)
@@ -1167,7 +1170,6 @@ class FreqtradeBot(LoggingMixin):
if self.create_stoploss_order(trade=trade, stop_price=trade.stoploss_or_liquidation):
return False
else:
trade.stoploss_order_id = None
logger.warning('Stoploss order was cancelled, but unable to recreate one.')
# Finally we check if stoploss on exchange should be moved up because of trailing.

View File

@@ -692,10 +692,11 @@ class Backtesting:
trade.orders.append(order)
return trade
def _get_exit_trade_entry(self, trade: LocalTrade, row: Tuple) -> Optional[LocalTrade]:
def _get_exit_trade_entry(
self, trade: LocalTrade, row: Tuple, is_first: bool) -> Optional[LocalTrade]:
exit_candle_time: datetime = row[DATE_IDX].to_pydatetime()
if self.trading_mode == TradingMode.FUTURES:
if is_first and self.trading_mode == TradingMode.FUTURES:
trade.funding_fees = self.exchange.calculate_funding_fees(
self.futures_data[trade.pair],
amount=trade.amount,
@@ -704,32 +705,7 @@ class Backtesting:
close_date=exit_candle_time,
)
if self.timeframe_detail and trade.pair in self.detail_data:
exit_candle_end = exit_candle_time + timedelta(minutes=self.timeframe_min)
detail_data = self.detail_data[trade.pair]
detail_data = detail_data.loc[
(detail_data['date'] >= exit_candle_time) &
(detail_data['date'] < exit_candle_end)
].copy()
if len(detail_data) == 0:
# Fall back to "regular" data if no detail data was found for this candle
return self._get_exit_trade_entry_for_candle(trade, row)
detail_data.loc[:, 'enter_long'] = row[LONG_IDX]
detail_data.loc[:, 'exit_long'] = row[ELONG_IDX]
detail_data.loc[:, 'enter_short'] = row[SHORT_IDX]
detail_data.loc[:, 'exit_short'] = row[ESHORT_IDX]
detail_data.loc[:, 'enter_tag'] = row[ENTER_TAG_IDX]
detail_data.loc[:, 'exit_tag'] = row[EXIT_TAG_IDX]
for det_row in detail_data[HEADERS].values.tolist():
res = self._get_exit_trade_entry_for_candle(trade, det_row)
if res:
return res
return None
else:
return self._get_exit_trade_entry_for_candle(trade, row)
return self._get_exit_trade_entry_for_candle(trade, row)
def get_valid_price_and_stake(
self, pair: str, row: Tuple, propose_rate: float, stake_amount: float,
@@ -1074,7 +1050,7 @@ class Backtesting:
def backtest_loop(
self, row: Tuple, pair: str, current_time: datetime, end_date: datetime,
max_open_trades: int, open_trade_count_start: int) -> int:
max_open_trades: int, open_trade_count_start: int, is_first: bool = True) -> int:
"""
NOTE: This method is used by Hyperopt at each iteration. Please keep it optimized.
@@ -1092,9 +1068,11 @@ class Backtesting:
# without positionstacking, we can only have one open trade per pair.
# max_open_trades must be respected
# don't open on the last row
# We only open trades on the main candle, not on detail candles
trade_dir = self.check_for_trade_entry(row)
if (
(self._position_stacking or len(LocalTrade.bt_trades_open_pp[pair]) == 0)
and is_first
and self.trade_slot_available(max_open_trades, open_trade_count_start)
and current_time != end_date
and trade_dir is not None
@@ -1120,7 +1098,7 @@ class Backtesting:
# 4. Create exit orders (if any)
if not trade.open_order_id:
self._get_exit_trade_entry(trade, row) # Place exit order if necessary
self._get_exit_trade_entry(trade, row, is_first) # Place exit order if necessary
# 5. Process exit orders.
order = trade.select_order(trade.exit_side, is_open=True)
@@ -1171,7 +1149,6 @@ class Backtesting:
self.progress.init_step(BacktestState.BACKTEST, int(
(end_date - start_date) / timedelta(minutes=self.timeframe_min)))
# Loop timerange and get candle for each pair at that point in time
while current_time <= end_date:
open_trade_count_start = LocalTrade.bt_open_open_trade_count
@@ -1185,9 +1162,37 @@ class Backtesting:
row_index += 1
indexes[pair] = row_index
self.dataprovider._set_dataframe_max_index(row_index)
current_detail_time: datetime = row[DATE_IDX].to_pydatetime()
if self.timeframe_detail and pair in self.detail_data:
exit_candle_end = current_detail_time + timedelta(minutes=self.timeframe_min)
open_trade_count_start = self.backtest_loop(
row, pair, current_time, end_date, max_open_trades, open_trade_count_start)
detail_data = self.detail_data[pair]
detail_data = detail_data.loc[
(detail_data['date'] >= current_detail_time) &
(detail_data['date'] < exit_candle_end)
].copy()
if len(detail_data) == 0:
# Fall back to "regular" data if no detail data was found for this candle
open_trade_count_start = self.backtest_loop(
row, pair, current_time, end_date, max_open_trades,
open_trade_count_start)
detail_data.loc[:, 'enter_long'] = row[LONG_IDX]
detail_data.loc[:, 'exit_long'] = row[ELONG_IDX]
detail_data.loc[:, 'enter_short'] = row[SHORT_IDX]
detail_data.loc[:, 'exit_short'] = row[ESHORT_IDX]
detail_data.loc[:, 'enter_tag'] = row[ENTER_TAG_IDX]
detail_data.loc[:, 'exit_tag'] = row[EXIT_TAG_IDX]
is_first = True
current_time_det = current_time
for det_row in detail_data[HEADERS].values.tolist():
open_trade_count_start = self.backtest_loop(
det_row, pair, current_time_det, end_date, max_open_trades,
open_trade_count_start, is_first)
current_time_det += timedelta(minutes=self.timeframe_detail_min)
is_first = False
else:
open_trade_count_start = self.backtest_loop(
row, pair, current_time, end_date, max_open_trades, open_trade_count_start)
# Move time one configured time_interval ahead.
self.progress.increment()

View File

@@ -17,6 +17,7 @@ from freqtrade.enums import HyperoptState
from freqtrade.exceptions import OperationalException
from freqtrade.misc import deep_merge_dicts, round_coin_value, round_dict, safe_value_fallback2
from freqtrade.optimize.hyperopt_epoch_filters import hyperopt_filter_epochs
from freqtrade.optimize.optimize_reports import generate_wins_draws_losses
logger = logging.getLogger(__name__)
@@ -325,8 +326,10 @@ class HyperoptTools():
# New mode, using backtest result for metrics
trials['results_metrics.winsdrawslosses'] = trials.apply(
lambda x: f"{x['results_metrics.wins']} {x['results_metrics.draws']:>4} "
f"{x['results_metrics.losses']:>4}", axis=1)
lambda x: generate_wins_draws_losses(
x['results_metrics.wins'], x['results_metrics.draws'],
x['results_metrics.losses']
), axis=1)
trials = trials[['Best', 'current_epoch', 'results_metrics.total_trades',
'results_metrics.winsdrawslosses',
@@ -337,7 +340,7 @@ class HyperoptTools():
'loss', 'is_initial_point', 'is_random', 'is_best']]
trials.columns = [
'Best', 'Epoch', 'Trades', ' Win Draw Loss', 'Avg profit',
'Best', 'Epoch', 'Trades', ' Win Draw Loss Win%', 'Avg profit',
'Total profit', 'Profit', 'Avg duration', 'max_drawdown', 'max_drawdown_account',
'max_drawdown_abs', 'Objective', 'is_initial_point', 'is_random', 'is_best'
]
@@ -467,9 +470,9 @@ class HyperoptTools():
base_metrics = ['Best', 'current_epoch', 'results_metrics.total_trades',
'results_metrics.profit_mean', 'results_metrics.profit_median',
'results_metrics.profit_total',
'Stake currency',
'results_metrics.profit_total', 'Stake currency',
'results_metrics.profit_total_abs', 'results_metrics.holding_avg',
'results_metrics.trade_count_long', 'results_metrics.trade_count_short',
'loss', 'is_initial_point', 'is_best']
perc_multi = 100
@@ -477,7 +480,9 @@ class HyperoptTools():
trials = trials[base_metrics + param_metrics]
base_columns = ['Best', 'Epoch', 'Trades', 'Avg profit', 'Median profit', 'Total profit',
'Stake currency', 'Profit', 'Avg duration', 'Objective',
'Stake currency', 'Profit', 'Avg duration',
'Trade count long', 'Trade count short',
'Objective',
'is_initial_point', 'is_best']
param_columns = list(results[0]['params_dict'].keys())
trials.columns = base_columns + param_columns

View File

@@ -86,7 +86,7 @@ def _get_line_header(first_column: str, stake_currency: str,
'Win Draw Loss Win%']
def _generate_wins_draws_losses(wins, draws, losses):
def generate_wins_draws_losses(wins, draws, losses):
if wins > 0 and losses == 0:
wl_ratio = '100'
elif wins == 0:
@@ -600,7 +600,7 @@ def text_table_bt_results(pair_results: List[Dict[str, Any]], stake_currency: st
output = [[
t['key'], t['trades'], t['profit_mean_pct'], t['profit_sum_pct'], t['profit_total_abs'],
t['profit_total_pct'], t['duration_avg'],
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses'])
generate_wins_draws_losses(t['wins'], t['draws'], t['losses'])
] for t in pair_results]
# Ignore type as floatfmt does allow tuples but mypy does not know that
return tabulate(output, headers=headers,
@@ -626,7 +626,7 @@ def text_table_exit_reason(exit_reason_stats: List[Dict[str, Any]], stake_curren
output = [[
t.get('exit_reason', t.get('sell_reason')), t['trades'],
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses']),
generate_wins_draws_losses(t['wins'], t['draws'], t['losses']),
t['profit_mean_pct'], t['profit_sum_pct'],
round_coin_value(t['profit_total_abs'], stake_currency, False),
t['profit_total_pct'],
@@ -656,7 +656,7 @@ def text_table_tags(tag_type: str, tag_results: List[Dict[str, Any]], stake_curr
t['profit_total_abs'],
t['profit_total_pct'],
t['duration_avg'],
_generate_wins_draws_losses(
generate_wins_draws_losses(
t['wins'],
t['draws'],
t['losses'])] for t in tag_results]
@@ -715,7 +715,7 @@ def text_table_strategy(strategy_results, stake_currency: str) -> str:
output = [[
t['key'], t['trades'], t['profit_mean_pct'], t['profit_sum_pct'], t['profit_total_abs'],
t['profit_total_pct'], t['duration_avg'],
_generate_wins_draws_losses(t['wins'], t['draws'], t['losses']), drawdown]
generate_wins_draws_losses(t['wins'], t['draws'], t['losses']), drawdown]
for t, drawdown in zip(strategy_results, drawdown)]
# Ignore type as floatfmt does allow tuples but mypy does not know that
return tabulate(output, headers=headers,

View File

@@ -87,7 +87,7 @@ class PairLocks():
Get the lock that expires the latest for the pair given.
"""
locks = PairLocks.get_pair_locks(pair, now, side=side)
locks = sorted(locks, key=lambda l: l.lock_end_time, reverse=True)
locks = sorted(locks, key=lambda lock: lock.lock_end_time, reverse=True)
return locks[0] if locks else None
@staticmethod

View File

@@ -81,8 +81,6 @@ async def validate_ws_token(
except HTTPException:
pass
# No checks passed, deny the connection
logger.debug("Denying websocket request.")
# If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION)

View File

@@ -1,16 +1,16 @@
import logging
import time
from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket
from pydantic import ValidationError
from websockets.exceptions import WebSocketException
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel, create_channel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC
@@ -22,23 +22,35 @@ logger = logging.getLogger(__name__)
router = APIRouter()
async def is_websocket_alive(ws: WebSocket) -> bool:
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
"""
Check if a FastAPI Websocket is still open
Iterate over the messages from the channel and process the request
"""
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
):
return True
return False
async for message in channel:
await _process_consumer_request(message, channel, rpc)
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
"""
Iterate over messages in the message stream and send them
"""
async for message, ts in message_stream:
if channel.subscribed_to(message.get('type')):
# Log a warning if this channel is behind
# on the message stream by a lot
if (time.time() - ts) > 60:
logger.warning(f"Channel {channel} is behind MessageStream by 1 minute,"
" this can cause a memory leak if you see this message"
" often, consider reducing pair list size or amount of"
" consumers.")
await channel.send(message, timeout=True)
async def _process_consumer_request(
request: Dict[str, Any],
channel: WebSocketChannel,
rpc: RPC,
channel_manager: ChannelManager
rpc: RPC
):
"""
Validate and handle a request from a websocket consumer
@@ -74,65 +86,29 @@ async def _process_consumer_request(
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
if data:
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = max(data.get('limit', 1500), 1500)
# Limit the amount of candles per dataframe to 'limit' or 1500
limit = min(data.get('limit', 1500), 1500) if data else None
# For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit):
# Format response
response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
websocket: WebSocket,
token: str = Depends(validate_ws_token),
rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager),
token: str = Depends(validate_ws_token)
message_stream: MessageStream = Depends(get_message_stream)
):
"""
Message WebSocket endpoint, facilitates sending RPC messages
"""
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}")
# Keep connection open until explicitly closed, and process requests
try:
while not channel.is_closed():
request = await channel.recv()
# Process the request here
await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
except RuntimeError:
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
except RuntimeError:
# WebSocket was closed
# Do nothing
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
if channel:
await channel_manager.on_disconnect(ws)
if token:
async with create_channel(websocket) as channel:
await channel.run_channel_tasks(
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)

View File

@@ -41,8 +41,8 @@ def get_exchange(config=Depends(get_config)):
return ApiServer._exchange
def get_channel_manager():
return ApiServer._ws_channel_manager
def get_message_stream():
return ApiServer._message_stream
def is_webserver_mode(config=Depends(get_config)):

View File

@@ -1,22 +1,17 @@
import asyncio
import logging
from ipaddress import IPv4Address
from threading import Thread
from typing import Any, Dict, Optional
import orjson
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Look into alternatives
from janus import Queue as ThreadedQueue
from starlette.responses import JSONResponse
from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws import ChannelManager
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@@ -50,10 +45,8 @@ class ApiServer(RPCHandler):
_config: Config = {}
# Exchange - only available in webserver mode.
_exchange = None
# websocket message queue stuff
_ws_channel_manager: ChannelManager
_ws_thread = None
_ws_loop: Optional[asyncio.AbstractEventLoop] = None
# websocket message stuff
_message_stream: Optional[MessageStream] = None
def __new__(cls, *args, **kwargs):
"""
@@ -71,15 +64,11 @@ class ApiServer(RPCHandler):
return
self._standalone: bool = standalone
self._server = None
self._ws_queue: Optional[ThreadedQueue] = None
self._ws_background_task = None
ApiServer.__initialized = True
api_config = self._config['api_server']
ApiServer._ws_channel_manager = ChannelManager()
self.app = FastAPI(title="Freqtrade API",
docs_url='/docs' if api_config.get('enable_openapi', False) else None,
redoc_url=None,
@@ -105,21 +94,9 @@ class ApiServer(RPCHandler):
del ApiServer._rpc
if self._server and not self._standalone:
logger.info("Stopping API Server")
# self._server.force_exit, self._server.should_exit = True, True
self._server.cleanup()
if self._ws_thread and self._ws_loop:
logger.info("Stopping API Server background tasks")
if self._ws_background_task:
# Cancel the queue task
self._ws_background_task.cancel()
self._ws_thread.join()
self._ws_thread = None
self._ws_loop = None
self._ws_background_task = None
@classmethod
def shutdown(cls):
cls.__initialized = False
@@ -129,9 +106,11 @@ class ApiServer(RPCHandler):
cls._rpc = None
def send_msg(self, msg: Dict[str, Any]) -> None:
if self._ws_queue:
sync_q = self._ws_queue.sync_q
sync_q.put(msg)
"""
Publish the message to the message stream
"""
if ApiServer._message_stream:
ApiServer._message_stream.publish(msg)
def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}")
@@ -170,54 +149,30 @@ class ApiServer(RPCHandler):
)
app.add_exception_handler(RPCException, self.handle_rpc_exception)
app.add_event_handler(
event_type="startup",
func=self._api_startup_event
)
app.add_event_handler(
event_type="shutdown",
func=self._api_shutdown_event
)
def start_message_queue(self):
if self._ws_thread:
return
async def _api_startup_event(self):
"""
Creates the MessageStream class on startup
so it has access to the same event loop
as uvicorn
"""
if not ApiServer._message_stream:
ApiServer._message_stream = MessageStream()
# Create a new loop, as it'll be just for the background thread
self._ws_loop = asyncio.new_event_loop()
# Start the thread
self._ws_thread = Thread(target=self._ws_loop.run_forever)
self._ws_thread.start()
# Finally, submit the coro to the thread
self._ws_background_task = asyncio.run_coroutine_threadsafe(
self._broadcast_queue_data(), loop=self._ws_loop)
async def _broadcast_queue_data(self) -> None:
# Instantiate the queue in this coroutine so it's attached to our loop
self._ws_queue = ThreadedQueue()
async_queue = self._ws_queue.async_q
try:
while True:
logger.debug("Getting queue messages...")
if (qsize := async_queue.qsize()) > 20:
# If the queue becomes too big for too long, this may indicate a problem.
logger.warning(f"Queue size now {qsize}")
# Get data from queue
message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}")
async_queue.task_done()
# Broadcast it
await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError:
pass
# For testing, shouldn't happen when stable
except Exception as e:
logger.exception(f"Exception happened in background task: {e}")
finally:
# Disconnect channels and stop the loop on cancel
await self._ws_channel_manager.disconnect_all()
if self._ws_loop:
self._ws_loop.stop()
# Avoid adding more items to the queue if they aren't
# going to get broadcasted.
self._ws_queue = None
async def _api_shutdown_event(self):
"""
Removes the MessageStream class on shutdown
"""
if ApiServer._message_stream:
ApiServer._message_stream = None
def start_api(self):
"""
@@ -257,7 +212,6 @@ class ApiServer(RPCHandler):
if self._standalone:
self._server.run()
else:
self.start_message_queue()
self._server.run_in_thread()
except Exception:
logger.exception("Api server failed to start.")

View File

@@ -3,4 +3,5 @@
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import HybridJSONWebSocketSerializer
from freqtrade.rpc.api_server.ws.channel import ChannelManager, WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.message_stream import MessageStream

View File

@@ -1,11 +1,13 @@
import asyncio
import logging
import time
from threading import RLock
from typing import Any, Dict, List, Optional, Type, Union
from collections import deque
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
from fastapi import WebSocketDisconnect
from websockets.exceptions import ConnectionClosed
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
@@ -21,31 +23,27 @@ class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
"""
def __init__(
self,
websocket: WebSocketType,
channel_id: Optional[str] = None,
drain_timeout: int = 3,
throttle: float = 0.01,
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
):
self.channel_id = channel_id if channel_id else uuid4().hex[:8]
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
self.drain_timeout = drain_timeout
self.throttle = throttle
self._subscriptions: List[str] = []
# 32 is the size of the receiving queue in websockets package
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = asyncio.Event()
# The async tasks created for the channel
self._channel_tasks: List[asyncio.Task] = []
# Deque for average send times
self._send_times: Deque[float] = deque([], maxlen=10)
# High limit defaults to 3 to start
self._send_high_limit = 3
# The subscribed message types
self._subscriptions: List[str] = []
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket)
@@ -61,43 +59,58 @@ class WebSocketChannel:
def remote_addr(self):
return self._websocket.remote_addr
async def _send(self, data):
"""
Send data on the wrapped websocket
"""
await self._wrapped_ws.send(data)
@property
def avg_send_time(self):
return sum(self._send_times) / len(self._send_times)
async def send(self, data) -> bool:
def _calc_send_limit(self):
"""
Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
Calculate the send high limit for this channel
"""
# This block only runs if the queue is full, it will wait
# until self.drain_timeout for the relay to drain the outgoing queue
# We can't use asyncio.wait_for here because the queue may have been created with a
# different eventloop
if not self.is_closed():
start = time.time()
while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False
# Only update if we have enough data
if len(self._send_times) == self._send_times.maxlen:
# At least 1s or twice the average of send times, with a
# maximum of 3 seconds per message
self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)
# If for some reason the queue is still full, just return False
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
async def send(
self,
message: Union[WSMessageSchemaType, Dict[str, Any]],
timeout: bool = False
):
"""
Send a message on the wrapped websocket. If the sending
takes too long, it will raise a TimeoutError and
disconnect the connection.
# If we got here everything is ok
return True
else:
return False
:param message: The message to send
:param timeout: Enforce send high limit, defaults to False
"""
try:
_ = time.time()
# If the send times out, it will raise
# a TimeoutError and bubble up to the
# message_endpoint to close the connection
await asyncio.wait_for(
self._wrapped_ws.send(message),
timeout=self._send_high_limit if timeout else None
)
total_time = time.time() - _
self._send_times.append(total_time)
self._calc_send_limit()
except asyncio.TimeoutError:
logger.info(f"Connection for {self} timed out, disconnecting")
raise
# Explicitly give control back to event loop as
# websockets.send does not
await asyncio.sleep(0.01)
async def recv(self):
"""
Receive data on the wrapped websocket
Receive a message on the wrapped websocket
"""
return await self._wrapped_ws.recv()
@@ -107,17 +120,27 @@ class WebSocketChannel:
"""
return await self._websocket.ping()
async def accept(self):
"""
Accept the underlying websocket connection,
if the connection has been closed before we can
accept, just close the channel.
"""
try:
return await self._websocket.accept()
except RuntimeError:
await self.close()
async def close(self):
"""
Close the WebSocketChannel
"""
self._closed.set()
self._relay_task.cancel()
try:
await self.raw_websocket.close()
except Exception:
await self._websocket.close()
except RuntimeError:
pass
def is_closed(self) -> bool:
@@ -142,99 +165,76 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def relay(self):
async def run_channel_tasks(self, *tasks, **kwargs):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
Create and await on the channel tasks unless an exception
was raised, then cancel them all.
:params *tasks: All coros or tasks to be run concurrently
:param **kwargs: Any extra kwargs to pass to gather
"""
while not self._closed.is_set():
message = await self.queue.get()
if not self.is_closed():
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
await self._send(message)
self.queue.task_done()
return await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks
await self.cancel_channel_tasks()
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
# 0.01 = 100/s
await asyncio.sleep(self.throttle)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
async def cancel_channel_tasks(self):
"""
Wrap websocket connection into Channel and add to list
:param websocket: The WebSocket object to attach to the Channel
Cancel and wait on all channel tasks
"""
if isinstance(websocket, FastAPIWebSocket):
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await websocket.accept()
except RuntimeError:
# The connection was closed before we could accept it
return
await task
except (
asyncio.CancelledError,
asyncio.TimeoutError,
WebSocketDisconnect,
ConnectionClosed,
RuntimeError
):
pass
except Exception as e:
logger.info(f"Encountered unknown exception: {e}", exc_info=e)
ws_channel = WebSocketChannel(websocket)
self._channel_tasks = []
with self._lock:
self.channels[websocket] = ws_channel
return ws_channel
async def on_disconnect(self, websocket: WebSocketType):
async def __aiter__(self):
"""
Call close on the channel if it's not, and remove from channel list
Generator for received messages
"""
# We can not catch any errors here as websocket.recv is
# the first to catch any disconnects and bubble it up
# so the connection is garbage collected right away
while not self.is_closed():
yield await self.recv()
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.info(f"Disconnecting channel {channel}")
if not channel.is_closed():
await channel.close()
del self.channels[websocket]
@asynccontextmanager
async def create_channel(
websocket: WebSocketType,
**kwargs
) -> AsyncIterator[WebSocketChannel]:
"""
Context manager for safely opening and closing a WebSocketChannel
"""
channel = WebSocketChannel(websocket, **kwargs)
try:
await channel.accept()
logger.info(f"Connected to channel - {channel}")
async def disconnect_all(self):
"""
Disconnect all Channels
"""
with self._lock:
for websocket in self.channels.copy().keys():
await self.on_disconnect(websocket)
async def broadcast(self, message: WSMessageSchemaType):
"""
Broadcast a message on all Channels
:param message: The message to send
"""
with self._lock:
for channel in self.channels.copy().values():
if channel.subscribed_to(message.get('type')):
await self.send_direct(channel, message)
async def send_direct(
self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
"""
Send a message directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send the message through
:param message: The message to send
"""
if not await channel.send(message):
await self.on_disconnect(channel.raw_websocket)
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0
yield channel
finally:
await channel.close()
logger.info(f"Disconnected from channel - {channel}")

View File

@@ -0,0 +1,31 @@
import asyncio
import time
class MessageStream:
"""
A message stream for consumers to subscribe to,
and for producers to publish to.
"""
def __init__(self):
self._loop = asyncio.get_running_loop()
self._waiter = self._loop.create_future()
def publish(self, message):
"""
Publish a message to this MessageStream
:param message: The message to publish
"""
waiter, self._waiter = self._waiter, self._loop.create_future()
waiter.set_result((message, time.time(), self._waiter))
async def __aiter__(self):
"""
Iterate over the messages in the message stream
"""
waiter = self._waiter
while True:
# Shield the future from being cancelled by a task waiting on it
message, ts, waiter = await asyncio.shield(waiter)
yield message, ts

View File

@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
import orjson
import rapidjson
@@ -7,6 +8,7 @@ from pandas import DataFrame
from freqtrade.misc import dataframe_to_json, json_to_dataframe
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__)
@@ -24,17 +26,13 @@ class WebSocketSerializer(ABC):
def _deserialize(self, data):
raise NotImplementedError()
async def send(self, data: bytes):
async def send(self, data: Union[WSMessageSchemaType, Dict[str, Any]]):
await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes:
data = await self._websocket.recv()
return self._deserialize(data)
async def close(self, code: int = 1000):
await self._websocket.close(code)
class HybridJSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data) -> str:

View File

@@ -31,6 +31,7 @@ class Producer(TypedDict):
name: str
host: str
port: int
secure: bool
ws_token: str
@@ -180,7 +181,8 @@ class ExternalMessageConsumer:
host, port = producer['host'], producer['port']
token = producer['ws_token']
name = producer['name']
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
scheme = 'wss' if producer.get('secure', False) else 'ws'
ws_url = f"{scheme}://{host}:{port}/api/v1/message/ws?token={token}"
# This will raise InvalidURI if the url is bad
async with websockets.connect(

View File

@@ -789,17 +789,18 @@ class RPC:
if not order_type:
order_type = self._freqtrade.strategy.order_types.get(
'force_entry', self._freqtrade.strategy.order_types['entry'])
if self._freqtrade.execute_entry(pair, stake_amount, price,
ordertype=order_type, trade=trade,
is_short=is_short,
enter_tag=enter_tag,
leverage_=leverage,
):
Trade.commit()
trade = Trade.get_trades([Trade.is_open.is_(True), Trade.pair == pair]).first()
return trade
else:
raise RPCException(f'Failed to enter position for {pair}.')
with self._freqtrade._exit_lock:
if self._freqtrade.execute_entry(pair, stake_amount, price,
ordertype=order_type, trade=trade,
is_short=is_short,
enter_tag=enter_tag,
leverage_=leverage,
):
Trade.commit()
trade = Trade.get_trades([Trade.is_open.is_(True), Trade.pair == pair]).first()
return trade
else:
raise RPCException(f'Failed to enter position for {pair}.')
def _rpc_delete(self, trade_id: int) -> Dict[str, Union[str, int]]:
"""

View File

@@ -19,7 +19,7 @@ class FreqaiExampleHybridStrategy(IStrategy):
Launching this strategy would be:
freqtrade trade --strategy FreqaiExampleHyridStrategy --strategy-path freqtrade/templates
freqtrade trade --strategy FreqaiExampleHybridStrategy --strategy-path freqtrade/templates
--freqaimodel CatboostClassifier --config config_examples/config_freqai.example.json
or the user simply adds this to their config:
@@ -86,7 +86,7 @@ class FreqaiExampleHybridStrategy(IStrategy):
process_only_new_candles = True
stoploss = -0.05
use_exit_signal = True
startup_candle_count: int = 300
startup_candle_count: int = 30
can_short = True
# Hyperoptable parameters

View File

@@ -328,7 +328,7 @@
"# Show graph inline\n",
"# graph.show()\n",
"\n",
"# Render graph in a seperate window\n",
"# Render graph in a separate window\n",
"graph.show(renderer=\"browser\")\n"
]
},