simplify plot_feature_importance call

This commit is contained in:
initrv 2022-09-17 18:53:43 +03:00
parent 86aa875bc9
commit 1c92734f39
2 changed files with 20 additions and 29 deletions

View File

@ -556,14 +556,6 @@ class IFreqaiModel(ABC):
model = self.train(unfiltered_dataframe, pair, dk) model = self.train(unfiltered_dataframe, pair, dk)
if self.freqai_info["feature_parameters"].get("plot_feature_importance", False):
plot_feature_importance(
model=model,
feature_names=dk.training_features_list,
pair=pair,
train_dir=dk.data_path
)
self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts self.dd.pair_dict[pair]["trained_timestamp"] = new_trained_timerange.stopts
dk.set_new_model_names(pair, new_trained_timerange) dk.set_new_model_names(pair, new_trained_timerange)
self.dd.pair_dict[pair]["first"] = False self.dd.pair_dict[pair]["first"] = False
@ -571,6 +563,9 @@ class IFreqaiModel(ABC):
self.dd.pair_to_end_of_training_queue(pair) self.dd.pair_to_end_of_training_queue(pair)
self.dd.save_data(model, pair, dk) self.dd.save_data(model, pair, dk)
if self.freqai_info["feature_parameters"].get("plot_feature_importance", False):
plot_feature_importance(model, pair, dk)
if self.freqai_info.get("purge_old_models", False): if self.freqai_info.get("purge_old_models", False):
self.dd.purge_old_models() self.dd.purge_old_models()

View File

@ -1,6 +1,7 @@
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -11,6 +12,7 @@ from freqtrade.data.history.history_utils import refresh_backtest_ohlcv_data
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.exchange import timeframe_to_seconds from freqtrade.exchange import timeframe_to_seconds
from freqtrade.exchange.exchange import market_is_active from freqtrade.exchange.exchange import market_is_active
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.plugins.pairlist.pairlist_helpers import dynamic_expand_pairlist from freqtrade.plugins.pairlist.pairlist_helpers import dynamic_expand_pairlist
@ -138,36 +140,30 @@ def get_required_data_timerange(
# ) # )
def plot_feature_importance(model, feature_names, pair, train_dir, count_max=50) -> None: def plot_feature_importance(model: Any, pair: str, dk: FreqaiDataKitchen,
count_max: int = 25) -> None:
""" """
Plot Best and Worst Features by importance for CatBoost model. Plot Best and worst features by importance for a single sub-train.
Called once per sub-train. :param model: Any = A model which was `fit` using a common library
Usage: plot_feature_importance( such as catboost or lightgbm
model=model, :param pair: str = pair e.g. BTC/USD
feature_names=dk.training_features_list, :param dk: FreqaiDataKitchen = non-persistent data container for current coin/loop
pair=pair, :param count_max: int = the amount of features to be loaded per column
train_dir=dk.data_path)
""" """
try: from freqtrade.plot.plotting import go, make_subplots, store_plot_file
import plotly.graph_objects as go
from plotly.subplots import make_subplots
except ImportError:
logger.exception("Module plotly not found \n Please install using `pip3 install plotly`")
exit(1)
from freqtrade.plot.plotting import store_plot_file # Extract feature importance from model
# Gather feature importance from model
if "catboost.core" in str(model.__class__): if "catboost.core" in str(model.__class__):
feature_importance = model.get_feature_importance() feature_importance = model.get_feature_importance()
elif "lightgbm.sklearn" in str(model.__class__): elif "lightgbm.sklearn" in str(model.__class__):
feature_importance = model.feature_importances_ feature_importance = model.feature_importances_
else: else:
raise NotImplementedError(f"Cannot extract feature importance for {model.__class__}") # TODO: Add support for more libraries
raise NotImplementedError(f"Cannot extract feature importance from {model.__class__}")
# Data preparation # Data preparation
fi_df = pd.DataFrame({ fi_df = pd.DataFrame({
"feature_names": np.array(feature_names), "feature_names": np.array(dk.training_features_list),
"feature_importance": np.array(feature_importance) "feature_importance": np.array(feature_importance)
}) })
fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1] fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1]
@ -185,9 +181,9 @@ def plot_feature_importance(model, feature_names, pair, train_dir, count_max=50)
fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5) fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5)
fig = add_feature_trace(fig, fi_df_top, 1) fig = add_feature_trace(fig, fi_df_top, 1)
fig = add_feature_trace(fig, fi_df_worst, 2) fig = add_feature_trace(fig, fi_df_worst, 2)
fig.update_layout(title_text=f"Best and Worst Features {pair}") fig.update_layout(title_text=f"Best and worst features by importance {pair}")
# Store plot file # Store plot file
model_dir, train_name = str(train_dir).rsplit("/", 1) model_dir, train_name = str(dk.data_path).rsplit("/", 1)
fi_dir = Path(f"{model_dir}/feature_importance/{pair.split('/')[0]}") fi_dir = Path(f"{model_dir}/feature_importance/{pair.split('/')[0]}")
store_plot_file(fig, f"{train_name}.html", fi_dir) store_plot_file(fig, f"{train_name}.html", fi_dir)