Skip to content

Commit fb813cf

Browse files
committed
In xgb classifier, only use a label encoder if it was created
Signed-off-by: Avi Shinnar <[email protected]>
1 parent 787b032 commit fb813cf

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

lale/lib/xgboost/xgb_classifier.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def validate_hyperparams(cls, **hyperparams):
7373
def __init__(self, **hyperparams):
7474
self.validate_hyperparams(**hyperparams)
7575
self._wrapped_model = xgboost.XGBClassifier(**hyperparams)
76+
self._label_encoder = None
7677

7778
def fit(self, X, y, **fit_params):
7879
renamed_X = _rename_all_features(X)
@@ -93,7 +94,10 @@ def fit(self, X, y, **fit_params):
9394
trained_le = trainable_le.fit(y)
9495
self._label_encoder = trained_le
9596

96-
numeric_y = self._label_encoder.transform(y)
97+
if self._label_encoder is not None:
98+
numeric_y = self._label_encoder.transform(y)
99+
else:
100+
numeric_y = y
97101
self._wrapped_model.fit(renamed_X, numeric_y, **fit_params)
98102
return self
99103

@@ -109,8 +113,10 @@ def predict(self, X, **predict_params):
109113
with warnings.catch_warnings():
110114
warnings.filterwarnings("ignore", category=FutureWarning)
111115
numeric_result = self._wrapped_model.predict(renamed_X, **predict_params)
112-
result = self._label_encoder.inverse_transform(numeric_result)
113-
return result
116+
if self._label_encoder is not None:
117+
return self._label_encoder.inverse_transform(numeric_result)
118+
else:
119+
return numeric_result
114120

115121
def predict_proba(self, X):
116122
return self._wrapped_model.predict_proba(X)

0 commit comments

Comments
 (0)