allow plot to plot multitargets, add test

This commit is contained in:
robcaulk 2022-09-17 19:17:44 +02:00
parent 1c92734f39
commit 2c23effbf2
4 changed files with 75 additions and 35 deletions

View File

@ -563,7 +563,7 @@ class IFreqaiModel(ABC):
self.dd.pair_to_end_of_training_queue(pair) self.dd.pair_to_end_of_training_queue(pair)
self.dd.save_data(model, pair, dk) self.dd.save_data(model, pair, dk)
if self.freqai_info["feature_parameters"].get("plot_feature_importance", False): if self.freqai_info["feature_parameters"].get("plot_feature_importance", True):
plot_feature_importance(model, pair, dk) plot_feature_importance(model, pair, dk)
if self.freqai_info.get("purge_old_models", False): if self.freqai_info.get("purge_old_models", False):

View File

@ -1,6 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
@ -153,37 +152,42 @@ def plot_feature_importance(model: Any, pair: str, dk: FreqaiDataKitchen,
from freqtrade.plot.plotting import go, make_subplots, store_plot_file from freqtrade.plot.plotting import go, make_subplots, store_plot_file
# Extract feature importance from model # Extract feature importance from model
if "catboost.core" in str(model.__class__): models = {}
feature_importance = model.get_feature_importance() if 'FreqaiMultiOutputRegressor' in str(model.__class__):
elif "lightgbm.sklearn" in str(model.__class__): for estimator, label in zip(model.estimators_, dk.label_list):
feature_importance = model.feature_importances_ models[label] = estimator
else:
# TODO: Add support for more libraries
raise NotImplementedError(f"Cannot extract feature importance from {model.__class__}")
# Data preparation for label in models:
fi_df = pd.DataFrame({ mdl = models[label]
"feature_names": np.array(dk.training_features_list), if "catboost.core" in str(mdl.__class__):
"feature_importance": np.array(feature_importance) feature_importance = mdl.get_feature_importance()
}) elif "lightgbm.sklearn" or "xgb" in str(mdl.__class__):
fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1] feature_importance = mdl.feature_importances_
fi_df_worst = fi_df.nsmallest(count_max, "feature_importance")[::-1] else:
# TODO: Add support for more libraries
raise NotImplementedError(f"Cannot extract feature importance from {mdl.__class__}")
# Plotting # Data preparation
def add_feature_trace(fig, fi_df, col): fi_df = pd.DataFrame({
return fig.add_trace( "feature_names": np.array(dk.training_features_list),
go.Bar( "feature_importance": np.array(feature_importance)
x=fi_df["feature_importance"], })
y=fi_df["feature_names"], fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1]
orientation='h', showlegend=False fi_df_worst = fi_df.nsmallest(count_max, "feature_importance")[::-1]
), row=1, col=col
)
fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5)
fig = add_feature_trace(fig, fi_df_top, 1)
fig = add_feature_trace(fig, fi_df_worst, 2)
fig.update_layout(title_text=f"Best and worst features by importance {pair}")
# Store plot file # Plotting
model_dir, train_name = str(dk.data_path).rsplit("/", 1) def add_feature_trace(fig, fi_df, col):
fi_dir = Path(f"{model_dir}/feature_importance/{pair.split('/')[0]}") return fig.add_trace(
store_plot_file(fig, f"{train_name}.html", fi_dir) go.Bar(
x=fi_df["feature_importance"],
y=fi_df["feature_names"],
orientation='h', showlegend=False
), row=1, col=col
)
fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5)
fig = add_feature_trace(fig, fi_df_top, 1)
fig = add_feature_trace(fig, fi_df_worst, 2)
fig.update_layout(title_text=f"Best and worst features by importance {pair}")
store_plot_file(fig, f"{dk.model_filename}-{label}.html", dk.data_path,
include_plotlyjs="cdn")

View File

@ -601,7 +601,8 @@ def generate_plot_filename(pair: str, timeframe: str) -> str:
return file_name return file_name
def store_plot_file(fig, filename: str, directory: Path, auto_open: bool = False) -> None: def store_plot_file(fig, filename: str, directory: Path,
auto_open: bool = False, include_plotlyjs=True) -> None:
""" """
Generate a plot html file from pre populated fig plotly object Generate a plot html file from pre populated fig plotly object
:param fig: Plotly Figure to plot :param fig: Plotly Figure to plot
@ -614,7 +615,7 @@ def store_plot_file(fig, filename: str, directory: Path, auto_open: bool = False
_filename = directory.joinpath(filename) _filename = directory.joinpath(filename)
plot(fig, filename=str(_filename), plot(fig, filename=str(_filename),
auto_open=auto_open) auto_open=auto_open, include_plotlyjs=include_plotlyjs)
logger.info(f"Stored plot as {_filename}") logger.info(f"Stored plot as {_filename}")

View File

@ -315,3 +315,38 @@ def test_principal_component_analysis(mocker, freqai_conf):
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_pca_object.pkl") assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_pca_object.pkl")
shutil.rmtree(Path(freqai.dk.full_path)) shutil.rmtree(Path(freqai.dk.full_path))
def test_plot_feature_importance(mocker, freqai_conf):
from freqtrade.freqai.utils import plot_feature_importance
freqai_conf.update({"timerange": "20180110-20180130"})
freqai_conf.get("freqai", {}).get("feature_parameters", {}).update(
{"princpial_component_analysis": "true"})
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
strategy.dp = DataProvider(freqai_conf, exchange)
strategy.freqai_info = freqai_conf.get("freqai", {})
freqai = strategy.freqai
freqai.live = True
freqai.dk = FreqaiDataKitchen(freqai_conf)
timerange = TimeRange.parse_timerange("20180110-20180130")
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
freqai.dd.pair_dict = MagicMock()
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
new_timerange = TimeRange.parse_timerange("20180120-20180130")
freqai.extract_data_and_train_model(
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
model = freqai.dd.load_data("ADA/BTC", freqai.dk)
plot_feature_importance(model, "ADA/BTC", freqai.dk)
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}.html")
shutil.rmtree(Path(freqai.dk.full_path))