|
41 | 41 | 'FMRegressor', 'FMRegressionModel'] |
42 | 42 |
|
43 | 43 |
|
| 44 | +class JavaRegressor(JavaPredictor, _JavaPredictorParams): |
| 45 | + """ |
| 46 | + Java Regressor for regression tasks. |
| 47 | +
|
| 48 | + .. versionadded:: 3.0.0 |
| 49 | + """ |
| 50 | + pass |
| 51 | + |
| 52 | + |
| 53 | +class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams): |
| 54 | + """ |
| 55 | + Java Model produced by a ``_JavaRegressor``. |
| 56 | + To be mixed in with class:`pyspark.ml.JavaModel` |
| 57 | +
|
| 58 | + .. versionadded:: 3.0.0 |
| 59 | + """ |
| 60 | + pass |
| 61 | + |
| 62 | + |
44 | 63 | class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter, |
45 | 64 | HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver, |
46 | 65 | HasAggregationDepth, HasLoss): |
@@ -69,7 +88,7 @@ def getEpsilon(self): |
69 | 88 |
|
70 | 89 |
|
71 | 90 | @inherit_doc |
72 | | -class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable): |
| 91 | +class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable): |
73 | 92 | """ |
74 | 93 | Linear regression. |
75 | 94 |
|
@@ -251,7 +270,7 @@ def setLoss(self, value): |
251 | 270 | return self._set(lossType=value) |
252 | 271 |
|
253 | 272 |
|
254 | | -class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable, |
| 273 | +class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable, |
255 | 274 | JavaMLReadable, HasTrainingSummary): |
256 | 275 | """ |
257 | 276 | Model fitted by :class:`LinearRegression`. |
@@ -758,7 +777,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha |
758 | 777 |
|
759 | 778 |
|
760 | 779 | @inherit_doc |
761 | | -class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable, |
| 780 | +class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable, |
762 | 781 | JavaMLReadable): |
763 | 782 | """ |
764 | 783 | `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_ |
@@ -953,8 +972,10 @@ def setVarianceCol(self, value): |
953 | 972 |
|
954 | 973 |
|
955 | 974 | @inherit_doc |
956 | | -class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams, |
957 | | - JavaMLWritable, JavaMLReadable): |
| 975 | +class DecisionTreeRegressionModel( |
| 976 | + JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams, |
| 977 | + JavaMLWritable, JavaMLReadable |
| 978 | +): |
958 | 979 | """ |
959 | 980 | Model fitted by :class:`DecisionTreeRegressor`. |
960 | 981 |
|
@@ -1000,7 +1021,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): |
1000 | 1021 |
|
1001 | 1022 |
|
1002 | 1023 | @inherit_doc |
1003 | | -class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable, |
| 1024 | +class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable, |
1004 | 1025 | JavaMLReadable): |
1005 | 1026 | """ |
1006 | 1027 | `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_ |
@@ -1198,8 +1219,10 @@ def setMinWeightFractionPerNode(self, value): |
1198 | 1219 | return self._set(minWeightFractionPerNode=value) |
1199 | 1220 |
|
1200 | 1221 |
|
1201 | | -class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams, |
1202 | | - JavaMLWritable, JavaMLReadable): |
| 1222 | +class RandomForestRegressionModel( |
| 1223 | + JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams, |
| 1224 | + JavaMLWritable, JavaMLReadable |
| 1225 | +): |
1203 | 1226 | """ |
1204 | 1227 | Model fitted by :class:`RandomForestRegressor`. |
1205 | 1228 |
|
@@ -1251,7 +1274,7 @@ def getLossType(self): |
1251 | 1274 |
|
1252 | 1275 |
|
1253 | 1276 | @inherit_doc |
1254 | | -class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): |
| 1277 | +class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): |
1255 | 1278 | """ |
1256 | 1279 | `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_ |
1257 | 1280 | learning algorithm for regression. |
@@ -1492,7 +1515,10 @@ def setMinWeightFractionPerNode(self, value): |
1492 | 1515 | return self._set(minWeightFractionPerNode=value) |
1493 | 1516 |
|
1494 | 1517 |
|
1495 | | -class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): |
| 1518 | +class GBTRegressionModel( |
| 1519 | + JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams, |
| 1520 | + JavaMLWritable, JavaMLReadable |
| 1521 | +): |
1496 | 1522 | """ |
1497 | 1523 | Model fitted by :class:`GBTRegressor`. |
1498 | 1524 |
|
@@ -1582,7 +1608,7 @@ def getQuantilesCol(self): |
1582 | 1608 |
|
1583 | 1609 |
|
1584 | 1610 | @inherit_doc |
1585 | | -class AFTSurvivalRegression(JavaPredictor, _AFTSurvivalRegressionParams, |
| 1611 | +class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams, |
1586 | 1612 | JavaMLWritable, JavaMLReadable): |
1587 | 1613 | """ |
1588 | 1614 | Accelerated Failure Time (AFT) Model Survival Regression |
@@ -1723,7 +1749,7 @@ def setAggregationDepth(self, value): |
1723 | 1749 | return self._set(aggregationDepth=value) |
1724 | 1750 |
|
1725 | 1751 |
|
1726 | | -class AFTSurvivalRegressionModel(JavaPredictionModel, _AFTSurvivalRegressionParams, |
| 1752 | +class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams, |
1727 | 1753 | JavaMLWritable, JavaMLReadable): |
1728 | 1754 | """ |
1729 | 1755 | Model fitted by :class:`AFTSurvivalRegression`. |
@@ -1855,7 +1881,7 @@ def getOffsetCol(self): |
1855 | 1881 |
|
1856 | 1882 |
|
1857 | 1883 | @inherit_doc |
1858 | | -class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionParams, |
| 1884 | +class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams, |
1859 | 1885 | JavaMLWritable, JavaMLReadable): |
1860 | 1886 | """ |
1861 | 1887 | Generalized Linear Regression. |
@@ -2060,7 +2086,7 @@ def setAggregationDepth(self, value): |
2060 | 2086 | return self._set(aggregationDepth=value) |
2061 | 2087 |
|
2062 | 2088 |
|
2063 | | -class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams, |
| 2089 | +class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams, |
2064 | 2090 | JavaMLWritable, JavaMLReadable, HasTrainingSummary): |
2065 | 2091 | """ |
2066 | 2092 | Model fitted by :class:`GeneralizedLinearRegression`. |
@@ -2348,7 +2374,7 @@ def getInitStd(self): |
2348 | 2374 |
|
2349 | 2375 |
|
2350 | 2376 | @inherit_doc |
2351 | | -class FMRegressor(JavaPredictor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): |
| 2377 | +class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable): |
2352 | 2378 | """ |
2353 | 2379 | Factorization Machines learning algorithm for regression. |
2354 | 2380 |
|
@@ -2512,7 +2538,7 @@ def setRegParam(self, value): |
2512 | 2538 | return self._set(regParam=value) |
2513 | 2539 |
|
2514 | 2540 |
|
2515 | | -class FMRegressionModel(JavaPredictionModel, _FactorizationMachinesParams, JavaMLWritable, |
| 2541 | +class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable, |
2516 | 2542 | JavaMLReadable): |
2517 | 2543 | """ |
2518 | 2544 | Model fitted by :class:`FMRegressor`. |
|
0 commit comments