Skip to content

Commit 3228732

Browse files
zero323srowen
authored andcommitted
[SPARK-30533][ML][PYSPARK] Add classes to represent Java Regressors and RegressionModels
### What changes were proposed in this pull request? This PR adds: - `pyspark.ml.regression.JavaRegressor` - `pyspark.ml.regression.JavaRegressionModel` classes and replaces `JavaPredictor` and `JavaPredictionModel` in - `LinearRegression` / `LinearRegressionModel` - `DecisionTreeRegressor` / `DecisionTreeRegressionModel` (just addition as `JavaPredictionModel` hasn't been used) - `RandomForestRegressor` / `RandomForestRegressionModel` (just addition as `JavaPredictionModel` hasn't been used) - `GBTRegressor` / `GBTRegressionModel` (just addition as `JavaPredictionModel` hasn't been used) - `AFTSurvivalRegression` / `AFTSurvivalRegressionModel` - `GeneralizedLinearRegression` / `GeneralizedLinearRegressionModel` - `FMRegressor` / `FMRegressionModel` ### Why are the changes needed? - Internal PySpark consistency. - Feature parity with Scala. - Intermediate step towards implementing [SPARK-29212](https://issues.apache.org/jira/browse/SPARK-29212) ### Does this PR introduce any user-facing change? It adds new base classes, so it will affect `mro`. Otherwise interfaces should stay intact. ### How was this patch tested? Existing tests. Closes #27241 from zero323/SPARK-30533. Authored-by: zero323 <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 505693c commit 3228732

File tree

1 file changed

+42
-16
lines changed

1 file changed

+42
-16
lines changed

python/pyspark/ml/regression.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@
4141
'FMRegressor', 'FMRegressionModel']
4242

4343

44+
class JavaRegressor(JavaPredictor, _JavaPredictorParams):
45+
"""
46+
Java Regressor for regression tasks.
47+
48+
.. versionadded:: 3.0.0
49+
"""
50+
pass
51+
52+
53+
class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams):
54+
"""
55+
Java Model produced by a ``_JavaRegressor``.
56+
To be mixed in with class:`pyspark.ml.JavaModel`
57+
58+
.. versionadded:: 3.0.0
59+
"""
60+
pass
61+
62+
4463
class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
4564
HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
4665
HasAggregationDepth, HasLoss):
@@ -69,7 +88,7 @@ def getEpsilon(self):
6988

7089

7190
@inherit_doc
72-
class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
91+
class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
7392
"""
7493
Linear regression.
7594
@@ -251,7 +270,7 @@ def setLoss(self, value):
251270
return self._set(lossType=value)
252271

253272

254-
class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable,
273+
class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
255274
JavaMLReadable, HasTrainingSummary):
256275
"""
257276
Model fitted by :class:`LinearRegression`.
@@ -758,7 +777,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha
758777

759778

760779
@inherit_doc
761-
class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable,
780+
class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
762781
JavaMLReadable):
763782
"""
764783
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
@@ -953,8 +972,10 @@ def setVarianceCol(self, value):
953972

954973

955974
@inherit_doc
956-
class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
957-
JavaMLWritable, JavaMLReadable):
975+
class DecisionTreeRegressionModel(
976+
JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
977+
JavaMLWritable, JavaMLReadable
978+
):
958979
"""
959980
Model fitted by :class:`DecisionTreeRegressor`.
960981
@@ -1000,7 +1021,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
10001021

10011022

10021023
@inherit_doc
1003-
class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable,
1024+
class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
10041025
JavaMLReadable):
10051026
"""
10061027
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
@@ -1198,8 +1219,10 @@ def setMinWeightFractionPerNode(self, value):
11981219
return self._set(minWeightFractionPerNode=value)
11991220

12001221

1201-
class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
1202-
JavaMLWritable, JavaMLReadable):
1222+
class RandomForestRegressionModel(
1223+
JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
1224+
JavaMLWritable, JavaMLReadable
1225+
):
12031226
"""
12041227
Model fitted by :class:`RandomForestRegressor`.
12051228
@@ -1251,7 +1274,7 @@ def getLossType(self):
12511274

12521275

12531276
@inherit_doc
1254-
class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
1277+
class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
12551278
"""
12561279
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
12571280
learning algorithm for regression.
@@ -1492,7 +1515,10 @@ def setMinWeightFractionPerNode(self, value):
14921515
return self._set(minWeightFractionPerNode=value)
14931516

14941517

1495-
class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
1518+
class GBTRegressionModel(
1519+
JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
1520+
JavaMLWritable, JavaMLReadable
1521+
):
14961522
"""
14971523
Model fitted by :class:`GBTRegressor`.
14981524
@@ -1582,7 +1608,7 @@ def getQuantilesCol(self):
15821608

15831609

15841610
@inherit_doc
1585-
class AFTSurvivalRegression(JavaPredictor, _AFTSurvivalRegressionParams,
1611+
class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
15861612
JavaMLWritable, JavaMLReadable):
15871613
"""
15881614
Accelerated Failure Time (AFT) Model Survival Regression
@@ -1723,7 +1749,7 @@ def setAggregationDepth(self, value):
17231749
return self._set(aggregationDepth=value)
17241750

17251751

1726-
class AFTSurvivalRegressionModel(JavaPredictionModel, _AFTSurvivalRegressionParams,
1752+
class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams,
17271753
JavaMLWritable, JavaMLReadable):
17281754
"""
17291755
Model fitted by :class:`AFTSurvivalRegression`.
@@ -1855,7 +1881,7 @@ def getOffsetCol(self):
18551881

18561882

18571883
@inherit_doc
1858-
class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionParams,
1884+
class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams,
18591885
JavaMLWritable, JavaMLReadable):
18601886
"""
18611887
Generalized Linear Regression.
@@ -2060,7 +2086,7 @@ def setAggregationDepth(self, value):
20602086
return self._set(aggregationDepth=value)
20612087

20622088

2063-
class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams,
2089+
class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams,
20642090
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
20652091
"""
20662092
Model fitted by :class:`GeneralizedLinearRegression`.
@@ -2348,7 +2374,7 @@ def getInitStd(self):
23482374

23492375

23502376
@inherit_doc
2351-
class FMRegressor(JavaPredictor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
2377+
class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
23522378
"""
23532379
Factorization Machines learning algorithm for regression.
23542380
@@ -2512,7 +2538,7 @@ def setRegParam(self, value):
25122538
return self._set(regParam=value)
25132539

25142540

2515-
class FMRegressionModel(JavaPredictionModel, _FactorizationMachinesParams, JavaMLWritable,
2541+
class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
25162542
JavaMLReadable):
25172543
"""
25182544
Model fitted by :class:`FMRegressor`.

0 commit comments

Comments
 (0)