Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Feb 14, 2025
1 parent 55c65d9 commit 9e26ddd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
18 changes: 6 additions & 12 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@
JavaMLReadable,
JavaMLWritable,
try_remote_attribute_relation,
ML_CONNECT_HELPER_ID,
invoke_helper_attr,
)
from pyspark.ml.wrapper import (
JavaWrapper,
JavaEstimator,
JavaModel,
JavaParams,
Expand Down Expand Up @@ -1225,8 +1224,7 @@ def from_vocabulary(

if is_remote():
model = CountVectorizerModel()
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
model._java_obj = helper._call_java(
model._java_obj = invoke_helper_attr(
"countVectorizerModelFromVocabulary",
model.uid,
list(vocabulary),
Expand Down Expand Up @@ -4845,8 +4843,7 @@ def from_labels(
"""
if is_remote():
model = StringIndexerModel()
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
model._java_obj = helper._call_java(
model._java_obj = invoke_helper_attr(
"stringIndexerModelFromLabels",
model.uid,
(list(labels), ArrayType(StringType())),
Expand Down Expand Up @@ -4885,8 +4882,7 @@ def from_arrays_of_labels(
"""
if is_remote():
model = StringIndexerModel()
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
model._java_obj = helper._call_java(
model._java_obj = invoke_helper_attr(
"stringIndexerModelFromLabelsArray",
model.uid,
(
Expand Down Expand Up @@ -5142,8 +5138,7 @@ def __init__(
"org.apache.spark.ml.feature.StopWordsRemover", self.uid
)
if is_remote():
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
locale = helper._call_java("stopWordsRemoverGetDefaultOrUS")
locale = invoke_helper_attr("stopWordsRemoverGetDefaultOrUS")
else:
locale = self._java_obj.getLocale()

Expand Down Expand Up @@ -5274,8 +5269,7 @@ def loadDefaultStopWords(language: str) -> List[str]:
italian, norwegian, portuguese, russian, spanish, swedish, turkish
"""
if is_remote():
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
stopWords = helper._call_java("stopWordsRemoverLoadDefaultStopWords", language)
stopWords = invoke_helper_attr("stopWordsRemoverLoadDefaultStopWords", language)
return list(stopWords)

else:
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def wrapped(self: "JavaWrapper") -> Any:
return cast(FuncT, wrapped)


def invoke_helper_attr(method: str, *args: Any) -> Any:
from pyspark.ml.wrapper import JavaWrapper

helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
return helper._call_java(method, *args)


def invoke_helper_relation(method: str, *args: Any) -> "ConnectDataFrame":
from pyspark.ml.wrapper import JavaWrapper

Expand Down

0 comments on commit 9e26ddd

Please sign in to comment.