Skip to content

Commit 8e35964

Browse files
authored
Merge pull request #621 from guillaume-vignal/feature/plotly_interactionplot
Add Option to Display Interaction Plot
2 parents daedb63 + 96fa8f3 commit 8e35964

File tree

3 files changed

+47
-36
lines changed

3 files changed

+47
-36
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "shapash"
7-
version = "2.7.6"
7+
version = "2.7.7"
88
authors = [
99
{name = "Yann Golhen"},
1010
{name = "Sebastien Bidault"},
@@ -29,7 +29,7 @@ classifiers = [
2929
"Operating System :: OS Independent",
3030
]
3131
dependencies = [
32-
"plotly>=5.0.0",
32+
"plotly>=5.0.0,<6.0.0",
3333
"matplotlib>=3.2.0",
3434
"numpy>1.18.0,<2",
3535
"pandas>=2.1.0",

shapash/explainer/smart_explainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,7 @@ def generate_report(
12061206
notebook_path=None,
12071207
kernel_name=None,
12081208
max_points=200,
1209+
display_interaction_plot=False,
12091210
nb_top_interactions=5,
12101211
):
12111212
"""
@@ -1251,6 +1252,9 @@ def generate_report(
12511252
by default.
12521253
max_points : int, optional
12531254
number of maximum points in the contribution plot
1255+
display_interaction_plot: bool, optional
1256+
Whether to display the interaction plot. This can be computationally expensive,
1257+
so it is set to False by default to optimize performance.
12541258
nb_top_interactions : int
12551259
Number of top interactions to display.
12561260
Examples
@@ -1305,6 +1309,7 @@ def generate_report(
13051309
title_description=title_description,
13061310
metrics=metrics,
13071311
max_points=max_points,
1312+
display_interaction_plot=display_interaction_plot,
13081313
nb_top_interactions=nb_top_interactions,
13091314
),
13101315
notebook_path=notebook_path,

shapash/report/project_report.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def __init__(
103103
else:
104104
self.max_points = 200
105105

106+
if "display_interaction_plot" in self.config.keys():
107+
self.display_interaction_plot = config["display_interaction_plot"]
108+
else:
109+
self.display_interaction_plot = False
110+
106111
if "nb_top_interactions" in self.config.keys():
107112
self.nb_top_interactions = config["nb_top_interactions"]
108113
else:
@@ -427,46 +432,47 @@ def display_model_explainability(self):
427432
)
428433

429434
# Interaction Plot
430-
explain_contrib_data_interaction = list()
431-
list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
432-
selection=None, max_points=self.max_points
433-
)
434-
interaction_values = self.explainer.get_interaction_values(selection=list_ind)
435-
sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
436-
indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]
435+
if self.display_interaction_plot:
436+
explain_contrib_data_interaction = list()
437+
list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
438+
selection=None, max_points=self.max_points
439+
)
440+
interaction_values = self.explainer.get_interaction_values(selection=list_ind)
441+
sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
442+
indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]
437443

438-
for i, ids in enumerate(indices_to_plot):
439-
id0, id1 = ids
444+
for i, ids in enumerate(indices_to_plot):
445+
id0, id1 = ids
440446

441-
fig_one_interaction = self.explainer.plot.interactions_plot(
442-
col1=self.explainer.columns_dict[id0],
443-
col2=self.explainer.columns_dict[id1],
444-
max_points=self.max_points,
445-
)
447+
fig_one_interaction = self.explainer.plot.interactions_plot(
448+
col1=self.explainer.columns_dict[id0],
449+
col2=self.explainer.columns_dict[id1],
450+
max_points=self.max_points,
451+
)
446452

447-
explain_contrib_data_interaction.append(
453+
explain_contrib_data_interaction.append(
454+
{
455+
"feature_index": i,
456+
"name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
457+
"description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
458+
+ " / "
459+
+ self.explainer.features_dict[self.explainer.columns_dict[id1]],
460+
"plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
461+
}
462+
)
463+
464+
# Aggregating the data
465+
explain_data.append(
448466
{
449-
"feature_index": i,
450-
"name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
451-
"description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
452-
+ " / "
453-
+ self.explainer.features_dict[self.explainer.columns_dict[id1]],
454-
"plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
467+
"index": index_label,
468+
"name": label_value,
469+
"feature_importance_plot": plotly.io.to_html(
470+
fig_features_importance, include_plotlyjs=False, full_html=False
471+
),
472+
"features": explain_contrib_data,
473+
"features_interaction": explain_contrib_data_interaction,
455474
}
456475
)
457-
458-
# Aggregating the data
459-
explain_data.append(
460-
{
461-
"index": index_label,
462-
"name": label_value,
463-
"feature_importance_plot": plotly.io.to_html(
464-
fig_features_importance, include_plotlyjs=False, full_html=False
465-
),
466-
"features": explain_contrib_data,
467-
"features_interaction": explain_contrib_data_interaction,
468-
}
469-
)
470476
print_html(explainability_template.render(labels=explain_data))
471477
print_md("---")
472478

0 commit comments

Comments
 (0)