From 9e26ddd56a6f15ff3405608544e619916f9825b2 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 14 Feb 2025 13:53:08 +0800 Subject: [PATCH] nit --- python/pyspark/ml/feature.py | 18 ++++++------------ python/pyspark/ml/util.py | 7 +++++++ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 6a4a9dc998753..d669fab27d505 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -54,10 +54,9 @@ JavaMLReadable, JavaMLWritable, try_remote_attribute_relation, - ML_CONNECT_HELPER_ID, + invoke_helper_attr, ) from pyspark.ml.wrapper import ( - JavaWrapper, JavaEstimator, JavaModel, JavaParams, @@ -1225,8 +1224,7 @@ def from_vocabulary( if is_remote(): model = CountVectorizerModel() - helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) - model._java_obj = helper._call_java( + model._java_obj = invoke_helper_attr( "countVectorizerModelFromVocabulary", model.uid, list(vocabulary), @@ -4845,8 +4843,7 @@ def from_labels( """ if is_remote(): model = StringIndexerModel() - helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) - model._java_obj = helper._call_java( + model._java_obj = invoke_helper_attr( "stringIndexerModelFromLabels", model.uid, (list(labels), ArrayType(StringType())), @@ -4885,8 +4882,7 @@ def from_arrays_of_labels( """ if is_remote(): model = StringIndexerModel() - helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) - model._java_obj = helper._call_java( + model._java_obj = invoke_helper_attr( "stringIndexerModelFromLabelsArray", model.uid, ( @@ -5142,8 +5138,7 @@ def __init__( "org.apache.spark.ml.feature.StopWordsRemover", self.uid ) if is_remote(): - helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) - locale = helper._call_java("stopWordsRemoverGetDefaultOrUS") + locale = invoke_helper_attr("stopWordsRemoverGetDefaultOrUS") else: locale = self._java_obj.getLocale() @@ -5274,8 +5269,7 @@ def loadDefaultStopWords(language: str) -> List[str]: italian, norwegian, portuguese, russian, spanish, swedish, turkish """ if is_remote(): - helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) - stopWords = helper._call_java("stopWordsRemoverLoadDefaultStopWords", language) + stopWords = invoke_helper_attr("stopWordsRemoverLoadDefaultStopWords", language) return list(stopWords) else: diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9eab45239b8f5..50b98ab12ce8d 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -78,6 +78,13 @@ def wrapped(self: "JavaWrapper") -> Any: return cast(FuncT, wrapped) +def invoke_helper_attr(method: str, *args: Any) -> Any: + from pyspark.ml.wrapper import JavaWrapper + + helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID) + return helper._call_java(method, *args) + + def invoke_helper_relation(method: str, *args: Any) -> "ConnectDataFrame": from pyspark.ml.wrapper import JavaWrapper