Merge branch 'develop' into list-models
This commit is contained in:
commit
dc50186d5b
@ -30,6 +30,14 @@ class CatboostClassifier(BaseClassifierModel):
|
|||||||
label=data_dictionary["train_labels"],
|
label=data_dictionary["train_labels"],
|
||||||
weight=data_dictionary["train_weights"],
|
weight=data_dictionary["train_weights"],
|
||||||
)
|
)
|
||||||
|
if self.freqai_info.get("data_split_parameters", {}).get("test_size", 0.1) == 0:
|
||||||
|
test_data = None
|
||||||
|
else:
|
||||||
|
test_data = Pool(
|
||||||
|
data=data_dictionary["test_features"],
|
||||||
|
label=data_dictionary["test_labels"],
|
||||||
|
weight=data_dictionary["test_weights"],
|
||||||
|
)
|
||||||
|
|
||||||
cbr = CatBoostClassifier(
|
cbr = CatBoostClassifier(
|
||||||
allow_writing_files=True,
|
allow_writing_files=True,
|
||||||
@ -40,6 +48,6 @@ class CatboostClassifier(BaseClassifierModel):
|
|||||||
|
|
||||||
init_model = self.get_init_model(dk.pair)
|
init_model = self.get_init_model(dk.pair)
|
||||||
|
|
||||||
cbr.fit(train_data, init_model=init_model)
|
cbr.fit(X=train_data, eval_set=test_data, init_model=init_model)
|
||||||
|
|
||||||
return cbr
|
return cbr
|
||||||
|
85
freqtrade/freqai/prediction_models/XGBoostRFClassifier.py
Normal file
85
freqtrade/freqai/prediction_models/XGBoostRFClassifier.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import pandas as pd
|
||||||
|
from pandas import DataFrame
|
||||||
|
from pandas.api.types import is_integer_dtype
|
||||||
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
from xgboost import XGBRFClassifier
|
||||||
|
|
||||||
|
from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostRFClassifier(BaseClassifierModel):
|
||||||
|
"""
|
||||||
|
User created prediction model. The class needs to override three necessary
|
||||||
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||||
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
User sets up the training and test data to fit their desired model here
|
||||||
|
:params:
|
||||||
|
:data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
all the training and test data/labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
X = data_dictionary["train_features"].to_numpy()
|
||||||
|
y = data_dictionary["train_labels"].to_numpy()[:, 0]
|
||||||
|
|
||||||
|
le = LabelEncoder()
|
||||||
|
if not is_integer_dtype(y):
|
||||||
|
y = pd.Series(le.fit_transform(y), dtype="int64")
|
||||||
|
|
||||||
|
if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0:
|
||||||
|
eval_set = None
|
||||||
|
else:
|
||||||
|
test_features = data_dictionary["test_features"].to_numpy()
|
||||||
|
test_labels = data_dictionary["test_labels"].to_numpy()[:, 0]
|
||||||
|
|
||||||
|
if not is_integer_dtype(test_labels):
|
||||||
|
test_labels = pd.Series(le.transform(test_labels), dtype="int64")
|
||||||
|
|
||||||
|
eval_set = [(test_features, test_labels)]
|
||||||
|
|
||||||
|
train_weights = data_dictionary["train_weights"]
|
||||||
|
|
||||||
|
init_model = self.get_init_model(dk.pair)
|
||||||
|
|
||||||
|
model = XGBRFClassifier(**self.model_training_parameters)
|
||||||
|
|
||||||
|
model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights,
|
||||||
|
xgb_model=init_model)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
|
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
|
"""
|
||||||
|
Filter the prediction features data and predict with it.
|
||||||
|
:param: unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
:return:
|
||||||
|
:pred_df: dataframe containing the predictions
|
||||||
|
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||||
|
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||||
|
"""
|
||||||
|
|
||||||
|
(pred_df, dk.do_predict) = super().predict(unfiltered_df, dk, **kwargs)
|
||||||
|
|
||||||
|
le = LabelEncoder()
|
||||||
|
label = dk.label_list[0]
|
||||||
|
labels_before = list(dk.data['labels_std'].keys())
|
||||||
|
labels_after = le.fit_transform(labels_before).tolist()
|
||||||
|
pred_df[label] = le.inverse_transform(pred_df[label])
|
||||||
|
pred_df = pred_df.rename(
|
||||||
|
columns={labels_after[i]: labels_before[i] for i in range(len(labels_before))})
|
||||||
|
|
||||||
|
return (pred_df, dk.do_predict)
|
45
freqtrade/freqai/prediction_models/XGBoostRFRegressor.py
Normal file
45
freqtrade/freqai/prediction_models/XGBoostRFRegressor.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from xgboost import XGBRFRegressor
|
||||||
|
|
||||||
|
from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostRFRegressor(BaseRegressionModel):
|
||||||
|
"""
|
||||||
|
User created prediction model. The class needs to override three necessary
|
||||||
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||||
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
User sets up the training and test data to fit their desired model here
|
||||||
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
all the training and test data/labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
X = data_dictionary["train_features"]
|
||||||
|
y = data_dictionary["train_labels"]
|
||||||
|
|
||||||
|
if self.freqai_info.get("data_split_parameters", {}).get("test_size", 0.1) == 0:
|
||||||
|
eval_set = None
|
||||||
|
else:
|
||||||
|
eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
|
||||||
|
eval_weights = [data_dictionary['test_weights']]
|
||||||
|
|
||||||
|
sample_weight = data_dictionary["train_weights"]
|
||||||
|
|
||||||
|
xgb_model = self.get_init_model(dk.pair)
|
||||||
|
|
||||||
|
model = XGBRFRegressor(**self.model_training_parameters)
|
||||||
|
|
||||||
|
model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set,
|
||||||
|
sample_weight_eval_set=eval_weights, xgb_model=xgb_model)
|
||||||
|
|
||||||
|
return model
|
@ -919,30 +919,23 @@ class Backtesting:
|
|||||||
return trade
|
return trade
|
||||||
|
|
||||||
def handle_left_open(self, open_trades: Dict[str, List[LocalTrade]],
|
def handle_left_open(self, open_trades: Dict[str, List[LocalTrade]],
|
||||||
data: Dict[str, List[Tuple]]) -> List[LocalTrade]:
|
data: Dict[str, List[Tuple]]) -> None:
|
||||||
"""
|
"""
|
||||||
Handling of left open trades at the end of backtesting
|
Handling of left open trades at the end of backtesting
|
||||||
"""
|
"""
|
||||||
trades = []
|
|
||||||
for pair in open_trades.keys():
|
for pair in open_trades.keys():
|
||||||
if len(open_trades[pair]) > 0:
|
for trade in open_trades[pair]:
|
||||||
for trade in open_trades[pair]:
|
if trade.open_order_id and trade.nr_of_successful_entries == 0:
|
||||||
if trade.open_order_id and trade.nr_of_successful_entries == 0:
|
# Ignore trade if entry-order did not fill yet
|
||||||
# Ignore trade if entry-order did not fill yet
|
continue
|
||||||
continue
|
exit_row = data[pair][-1]
|
||||||
exit_row = data[pair][-1]
|
self._exit_trade(trade, exit_row, exit_row[OPEN_IDX], trade.amount)
|
||||||
self._exit_trade(trade, exit_row, exit_row[OPEN_IDX], trade.amount)
|
trade.orders[-1].close_bt_order(exit_row[DATE_IDX].to_pydatetime(), trade)
|
||||||
trade.orders[-1].close_bt_order(exit_row[DATE_IDX].to_pydatetime(), trade)
|
|
||||||
|
|
||||||
trade.close_date = exit_row[DATE_IDX].to_pydatetime()
|
trade.close_date = exit_row[DATE_IDX].to_pydatetime()
|
||||||
trade.exit_reason = ExitType.FORCE_EXIT.value
|
trade.exit_reason = ExitType.FORCE_EXIT.value
|
||||||
trade.close(exit_row[OPEN_IDX], show_msg=False)
|
trade.close(exit_row[OPEN_IDX], show_msg=False)
|
||||||
LocalTrade.close_bt_trade(trade)
|
LocalTrade.close_bt_trade(trade)
|
||||||
# Deepcopy object to have wallets update correctly
|
|
||||||
trade1 = deepcopy(trade)
|
|
||||||
trade1.is_open = True
|
|
||||||
trades.append(trade1)
|
|
||||||
return trades
|
|
||||||
|
|
||||||
def trade_slot_available(self, max_open_trades: int, open_trade_count: int) -> bool:
|
def trade_slot_available(self, max_open_trades: int, open_trade_count: int) -> bool:
|
||||||
# Always allow trades when max_open_trades is enabled.
|
# Always allow trades when max_open_trades is enabled.
|
||||||
@ -1094,7 +1087,6 @@ class Backtesting:
|
|||||||
:param enable_protections: Should protections be enabled?
|
:param enable_protections: Should protections be enabled?
|
||||||
:return: DataFrame with trades (results of backtesting)
|
:return: DataFrame with trades (results of backtesting)
|
||||||
"""
|
"""
|
||||||
trades: List[LocalTrade] = []
|
|
||||||
self.prepare_backtest(enable_protections)
|
self.prepare_backtest(enable_protections)
|
||||||
# Ensure wallets are uptodate (important for --strategy-list)
|
# Ensure wallets are uptodate (important for --strategy-list)
|
||||||
self.wallets.update()
|
self.wallets.update()
|
||||||
@ -1188,7 +1180,6 @@ class Backtesting:
|
|||||||
open_trade_count -= 1
|
open_trade_count -= 1
|
||||||
open_trades[pair].remove(trade)
|
open_trades[pair].remove(trade)
|
||||||
LocalTrade.close_bt_trade(trade)
|
LocalTrade.close_bt_trade(trade)
|
||||||
trades.append(trade)
|
|
||||||
self.wallets.update()
|
self.wallets.update()
|
||||||
self.run_protections(
|
self.run_protections(
|
||||||
enable_protections, pair, current_time, trade.trade_direction)
|
enable_protections, pair, current_time, trade.trade_direction)
|
||||||
@ -1197,10 +1188,10 @@ class Backtesting:
|
|||||||
self.progress.increment()
|
self.progress.increment()
|
||||||
current_time += timedelta(minutes=self.timeframe_min)
|
current_time += timedelta(minutes=self.timeframe_min)
|
||||||
|
|
||||||
trades += self.handle_left_open(open_trades, data=data)
|
self.handle_left_open(open_trades, data=data)
|
||||||
self.wallets.update()
|
self.wallets.update()
|
||||||
|
|
||||||
results = trade_list_to_dataframe(trades)
|
results = trade_list_to_dataframe(LocalTrade.trades)
|
||||||
return {
|
return {
|
||||||
'results': results,
|
'results': results,
|
||||||
'config': self.strategy.config,
|
'config': self.strategy.config,
|
||||||
|
@ -408,10 +408,10 @@ def generate_strategy_stats(pairlist: List[str],
|
|||||||
|
|
||||||
exit_reason_stats = generate_exit_reason_stats(max_open_trades=max_open_trades,
|
exit_reason_stats = generate_exit_reason_stats(max_open_trades=max_open_trades,
|
||||||
results=results)
|
results=results)
|
||||||
left_open_results = generate_pair_metrics(pairlist, stake_currency=stake_currency,
|
left_open_results = generate_pair_metrics(
|
||||||
starting_balance=start_balance,
|
pairlist, stake_currency=stake_currency, starting_balance=start_balance,
|
||||||
results=results.loc[results['is_open']],
|
results=results.loc[results['exit_reason'] == 'force_exit'], skip_nan=True)
|
||||||
skip_nan=True)
|
|
||||||
daily_stats = generate_daily_stats(results)
|
daily_stats = generate_daily_stats(results)
|
||||||
trade_stats = generate_trading_stats(results)
|
trade_stats = generate_trading_stats(results)
|
||||||
best_pair = max([pair for pair in pair_results if pair['key'] != 'TOTAL'],
|
best_pair = max([pair for pair in pair_results if pair['key'] != 'TOTAL'],
|
||||||
|
@ -30,6 +30,7 @@ def is_mac() -> bool:
|
|||||||
@pytest.mark.parametrize('model', [
|
@pytest.mark.parametrize('model', [
|
||||||
'LightGBMRegressor',
|
'LightGBMRegressor',
|
||||||
'XGBoostRegressor',
|
'XGBoostRegressor',
|
||||||
|
'XGBoostRFRegressor',
|
||||||
'CatboostRegressor',
|
'CatboostRegressor',
|
||||||
])
|
])
|
||||||
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
|
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
|
||||||
@ -55,6 +56,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
|
|||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180125-20180130")
|
data_load_timerange = TimeRange.parse_timerange("20180125-20180130")
|
||||||
new_timerange = TimeRange.parse_timerange("20180127-20180130")
|
new_timerange = TimeRange.parse_timerange("20180127-20180130")
|
||||||
|
freqai.dk.set_paths('ADA/BTC', None)
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(
|
freqai.extract_data_and_train_model(
|
||||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||||
@ -93,6 +95,7 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model):
|
|||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||||
|
freqai.dk.set_paths('ADA/BTC', None)
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(
|
freqai.extract_data_and_train_model(
|
||||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||||
@ -111,6 +114,7 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model):
|
|||||||
'LightGBMClassifier',
|
'LightGBMClassifier',
|
||||||
'CatboostClassifier',
|
'CatboostClassifier',
|
||||||
'XGBoostClassifier',
|
'XGBoostClassifier',
|
||||||
|
'XGBoostRFClassifier',
|
||||||
])
|
])
|
||||||
def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
||||||
if is_arm() and model == 'CatboostClassifier':
|
if is_arm() and model == 'CatboostClassifier':
|
||||||
@ -134,6 +138,7 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
|||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||||
|
freqai.dk.set_paths('ADA/BTC', None)
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
||||||
strategy, freqai.dk, data_load_timerange)
|
strategy, freqai.dk, data_load_timerange)
|
||||||
|
Loading…
Reference in New Issue
Block a user