@@ -103,6 +103,11 @@ def __init__(
103103 else :
104104 self .max_points = 200
105105
106+ if "display_interaction_plot" in self .config .keys ():
107+ self .display_interaction_plot = config ["display_interaction_plot" ]
108+ else :
109+ self .display_interaction_plot = False
110+
106111 if "nb_top_interactions" in self .config .keys ():
107112 self .nb_top_interactions = config ["nb_top_interactions" ]
108113 else :
@@ -427,46 +432,47 @@ def display_model_explainability(self):
427432 )
428433
429434 # Interaction Plot
430- explain_contrib_data_interaction = list ()
431- list_ind , _ = self .explainer .plot ._select_indices_interactions_plot (
432- selection = None , max_points = self .max_points
433- )
434- interaction_values = self .explainer .get_interaction_values (selection = list_ind )
435- sorted_top_features_indices = compute_sorted_variables_interactions_list_indices (interaction_values )
436- indices_to_plot = sorted_top_features_indices [: self .nb_top_interactions ]
435+ if self .display_interaction_plot :
436+ explain_contrib_data_interaction = list ()
437+ list_ind , _ = self .explainer .plot ._select_indices_interactions_plot (
438+ selection = None , max_points = self .max_points
439+ )
440+ interaction_values = self .explainer .get_interaction_values (selection = list_ind )
441+ sorted_top_features_indices = compute_sorted_variables_interactions_list_indices (interaction_values )
442+ indices_to_plot = sorted_top_features_indices [: self .nb_top_interactions ]
437443
438- for i , ids in enumerate (indices_to_plot ):
439- id0 , id1 = ids
444+ for i , ids in enumerate (indices_to_plot ):
445+ id0 , id1 = ids
440446
441- fig_one_interaction = self .explainer .plot .interactions_plot (
442- col1 = self .explainer .columns_dict [id0 ],
443- col2 = self .explainer .columns_dict [id1 ],
444- max_points = self .max_points ,
445- )
447+ fig_one_interaction = self .explainer .plot .interactions_plot (
448+ col1 = self .explainer .columns_dict [id0 ],
449+ col2 = self .explainer .columns_dict [id1 ],
450+ max_points = self .max_points ,
451+ )
446452
447- explain_contrib_data_interaction .append (
453+ explain_contrib_data_interaction .append (
454+ {
455+ "feature_index" : i ,
456+ "name" : self .explainer .columns_dict [id0 ] + " / " + self .explainer .columns_dict [id1 ],
457+ "description" : self .explainer .features_dict [self .explainer .columns_dict [id0 ]]
458+ + " / "
459+ + self .explainer .features_dict [self .explainer .columns_dict [id1 ]],
460+ "plot" : plotly .io .to_html (fig_one_interaction , include_plotlyjs = False , full_html = False ),
461+ }
462+ )
463+
464+ # Aggregating the data
465+ explain_data .append (
448466 {
449- "feature_index" : i ,
450- "name" : self .explainer .columns_dict [id0 ] + " / " + self .explainer .columns_dict [id1 ],
451- "description" : self .explainer .features_dict [self .explainer .columns_dict [id0 ]]
452- + " / "
453- + self .explainer .features_dict [self .explainer .columns_dict [id1 ]],
454- "plot" : plotly .io .to_html (fig_one_interaction , include_plotlyjs = False , full_html = False ),
467+ "index" : index_label ,
468+ "name" : label_value ,
469+ "feature_importance_plot" : plotly .io .to_html (
470+ fig_features_importance , include_plotlyjs = False , full_html = False
471+ ),
472+ "features" : explain_contrib_data ,
473+ "features_interaction" : explain_contrib_data_interaction ,
455474 }
456475 )
457-
458- # Aggregating the data
459- explain_data .append (
460- {
461- "index" : index_label ,
462- "name" : label_value ,
463- "feature_importance_plot" : plotly .io .to_html (
464- fig_features_importance , include_plotlyjs = False , full_html = False
465- ),
466- "features" : explain_contrib_data ,
467- "features_interaction" : explain_contrib_data_interaction ,
468- }
469- )
470476 print_html (explainability_template .render (labels = explain_data ))
471477 print_md ("---" )
472478
0 commit comments