From 86aa875bc9d5edeba04f908fe45b011e52045c83 Mon Sep 17 00:00:00 2001 From: initrv Date: Fri, 16 Sep 2022 21:47:12 +0300 Subject: [PATCH] plot features as html instead of png --- freqtrade/freqai/utils.py | 62 ++++++++++++++++----------------------- requirements-plot.txt | 1 - 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/freqtrade/freqai/utils.py b/freqtrade/freqai/utils.py index 86d89d4b0..3f6b8b053 100644 --- a/freqtrade/freqai/utils.py +++ b/freqtrade/freqai/utils.py @@ -1,13 +1,9 @@ import logging from datetime import datetime, timezone -# for plot_feature_importance from pathlib import Path import numpy as np import pandas as pd -import plotly.graph_objects as go -import plotly.io as pio -from plotly.subplots import make_subplots from freqtrade.configuration import TimeRange from freqtrade.data.dataprovider import DataProvider @@ -142,64 +138,56 @@ def get_required_data_timerange( # ) -def plot_feature_importance(model, feature_names, pair, train_dir, count_max=25) -> None: +def plot_feature_importance(model, feature_names, pair, train_dir, count_max=50) -> None: """ Plot Best and Worst Features by importance for CatBoost model. Called once per sub-train. - - Required: pip install kaleido - Usage: plot_feature_importance( model=model, feature_names=dk.training_features_list, pair=pair, train_dir=dk.data_path) """ + try: + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + logger.exception("Module plotly not found \n Please install using `pip3 install plotly`") + exit(1) + + from freqtrade.plot.plotting import store_plot_file # Gather feature importance from model if "catboost.core" in str(model.__class__): - fi = model.get_feature_importance() - + feature_importance = model.get_feature_importance() elif "lightgbm.sklearn" in str(model.__class__): - fi = model.feature_importances_ - + feature_importance = model.feature_importances_ else: raise NotImplementedError(f"Cannot extract feature importance for {model.__class__}") # Data preparation fi_df = pd.DataFrame({ "feature_names": np.array(feature_names), - "feature_importance": np.array(fi) + "feature_importance": np.array(feature_importance) }) fi_df_top = fi_df.nlargest(count_max, "feature_importance")[::-1] fi_df_worst = fi_df.nsmallest(count_max, "feature_importance")[::-1] # Plotting + def add_feature_trace(fig, fi_df, col): + return fig.add_trace( + go.Bar( + x=fi_df["feature_importance"], + y=fi_df["feature_names"], + orientation='h', showlegend=False + ), row=1, col=col + ) fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.5) - fig.add_trace( - go.Bar( - x=fi_df_top["feature_importance"], - y=fi_df_top["feature_names"], - orientation='h', showlegend=False - ), row=1, col=1 - ) - fig.add_trace( - go.Bar( - x=fi_df_worst["feature_importance"], - y=fi_df_worst["feature_names"], - orientation='h', showlegend=False - ), row=1, col=2 - ) - fig.update_layout( - title_text=f"Best and Worst Features {pair}", - width=1000, height=600 - ) + fig = add_feature_trace(fig, fi_df_top, 1) + fig = add_feature_trace(fig, fi_df_worst, 2) + fig.update_layout(title_text=f"Best and Worst Features {pair}") - # Create directory and save image + # Store plot file model_dir, train_name = str(train_dir).rsplit("/", 1) fi_dir = Path(f"{model_dir}/feature_importance/{pair.split('/')[0]}") - fi_dir.mkdir(parents=True, exist_ok=True) - - pio.write_image(fig, f"{fi_dir}/{train_name}.png", format="png") - - logger.info(f"Freqai saving feature importance plot {fi_dir}/{train_name}.png") + store_plot_file(fig, f"{train_name}.html", fi_dir) diff --git a/requirements-plot.txt b/requirements-plot.txt index ef3cf9f24..80cd3f4f2 100644 --- a/requirements-plot.txt +++ b/requirements-plot.txt @@ -2,4 +2,3 @@ -r requirements.txt plotly==5.10.0 -kaleido==0.2.1