Skip to content

Commit b8eaa11

Browse files
committed
fix multiclass classification with interactions
Fixes #24
1 parent b540e5b commit b8eaa11

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

ebm2onnx/ebm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _get_bin_score_2d(g):
104104

105105
init_reshape = graph.create_initializer(
106106
g, "score_reshape", onnx.TensorProto.INT64,
107-
[3], [-1, 1, 1],
107+
[3], [-1, 1, bin_scores.shape[-1]],
108108
)
109109

110110
g = ops.concat(axis=1)(g)

tests/test_convert.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ def train_titanic_regression(interactions):
8888
return model, x_test, y_test
8989

9090

91-
def train_bank_churners_multiclass_classification(encode_label=True):
91+
def train_bank_churners_multiclass_classification(
92+
encode_label=True,
93+
interactions=0,
94+
):
9295
df = pd.read_csv(
9396
os.path.join('examples','BankChurners.csv'),
9497
)
@@ -105,7 +108,10 @@ def train_bank_churners_multiclass_classification(encode_label=True):
105108
y_enc = y
106109
x = df[feature_columns]
107110
x_train, x_test, y_train, y_test = train_test_split(x, y_enc)
108-
model = ExplainableBoostingClassifier(interactions=0, feature_types=feature_types)
111+
model = ExplainableBoostingClassifier(
112+
interactions=interactions,
113+
feature_types=feature_types
114+
)
109115
model.fit(x_train, y_train)
110116

111117
return model, x_test, y_test
@@ -262,8 +268,12 @@ def test_predict_binary_classification_with_categorical(interactions, explain, o
262268

263269

264270
@pytest.mark.parametrize("encode_label", [False, True])
265-
def test_predict_multiclass_classification(encode_label):
266-
model_ebm, x_test, y_test = train_bank_churners_multiclass_classification(encode_label=encode_label)
271+
@pytest.mark.parametrize("interactions", [0, 2, [(0, 1, 2)]])
272+
def test_predict_multiclass_classification(encode_label, interactions):
273+
model_ebm, x_test, y_test = train_bank_churners_multiclass_classification(
274+
encode_label=encode_label,
275+
interactions=interactions,
276+
)
267277
pred_ebm = model_ebm.predict(x_test)
268278

269279
model_onnx = ebm2onnx.to_onnx(

0 commit comments

Comments
 (0)