@@ -88,7 +88,10 @@ def train_titanic_regression(interactions):
8888 return model , x_test , y_test
8989
9090
91- def train_bank_churners_multiclass_classification (encode_label = True ):
91+ def train_bank_churners_multiclass_classification (
92+ encode_label = True ,
93+ interactions = 0 ,
94+ ):
9295 df = pd .read_csv (
9396 os .path .join ('examples' ,'BankChurners.csv' ),
9497 )
@@ -105,7 +108,10 @@ def train_bank_churners_multiclass_classification(encode_label=True):
105108 y_enc = y
106109 x = df [feature_columns ]
107110 x_train , x_test , y_train , y_test = train_test_split (x , y_enc )
108- model = ExplainableBoostingClassifier (interactions = 0 , feature_types = feature_types )
111+ model = ExplainableBoostingClassifier (
112+ interactions = interactions ,
113+ feature_types = feature_types
114+ )
109115 model .fit (x_train , y_train )
110116
111117 return model , x_test , y_test
@@ -262,8 +268,12 @@ def test_predict_binary_classification_with_categorical(interactions, explain, o
262268
263269
264270@pytest .mark .parametrize ("encode_label" , [False , True ])
265- def test_predict_multiclass_classification (encode_label ):
266- model_ebm , x_test , y_test = train_bank_churners_multiclass_classification (encode_label = encode_label )
271+ @pytest .mark .parametrize ("interactions" , [0 , 2 , [(0 , 1 , 2 )]])
272+ def test_predict_multiclass_classification (encode_label , interactions ):
273+ model_ebm , x_test , y_test = train_bank_churners_multiclass_classification (
274+ encode_label = encode_label ,
275+ interactions = interactions ,
276+ )
267277 pred_ebm = model_ebm .predict (x_test )
268278
269279 model_onnx = ebm2onnx .to_onnx (
0 commit comments