Skip to content

Commit bff86f1

Browse files
committed
[SPARK-51060][ML][PYTHON][CONNECT][FOLLOW-UP] Fix the uid of Bucketizer fitted by QuantileDiscretizer
### What changes were proposed in this pull request?
Fix the `uid` of the `Bucketizer` fitted by `QuantileDiscretizer`.

### Why are the changes needed?
On Spark Connect, the `Bucketizer` fitted by a `QuantileDiscretizer` has the same `uid` as the `QuantileDiscretizer` that produced it. This is inconsistent with PySpark Classic, where the fitted `Bucketizer` gets its own `uid`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#49770 from zhengruifeng/ml_connect_qd_uid.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 3e8a0c1 commit bff86f1

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

python/pyspark/ml/tests/test_feature.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -998,7 +998,9 @@ def test_quantile_discretizer_single_column(self):
998998
bucketizer = qds.fit(df)
999999
self.assertIsInstance(bucketizer, Bucketizer)
10001000
# Bucketizer doesn't inherit uid from QuantileDiscretizer
1001-
# self.assertEqual(qds.uid, bucketizer.uid)
1001+
self.assertNotEqual(qds.uid, bucketizer.uid)
1002+
self.assertTrue(qds.uid.startswith("QuantileDiscretizer"))
1003+
self.assertTrue(bucketizer.uid.startswith("Bucketizer"))
10021004

10031005
# check model coefficients
10041006
self.assertEqual(bucketizer.getSplits(), [float("-inf"), 0.4, float("inf")])
@@ -1046,7 +1048,9 @@ def test_quantile_discretizer_multiple_columns(self):
10461048
bucketizer = qds.fit(df)
10471049
self.assertIsInstance(bucketizer, Bucketizer)
10481050
# Bucketizer doesn't inherit uid from QuantileDiscretizer
1049-
# self.assertEqual(qds.uid, bucketizer.uid)
1051+
self.assertNotEqual(qds.uid, bucketizer.uid)
1052+
self.assertTrue(qds.uid.startswith("QuantileDiscretizer"))
1053+
self.assertTrue(bucketizer.uid.startswith("Bucketizer"))
10501054

10511055
# check model coefficients
10521056
self.assertEqual(

python/pyspark/ml/util.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,8 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any:
136136
model_info = deserialize(properties)
137137
client.add_ml_cache(model_info.obj_ref.id)
138138
model = self._create_model(model_info.obj_ref.id)
139-
model._resetUid(self.uid)
139+
if model.__class__.__name__ not in ["Bucketizer"]:
140+
model._resetUid(self.uid)
140141
return self._copyValues(model)
141142
else:
142143
return f(self, dataset)

0 commit comments

Comments
 (0)