Skip to content

Commit

Permalink
[SPARK-51132][ML][BUILD] Upgrade JPMML to 1.7.1
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR aims to upgrade `JPMML` from 1.4.8 to 1.7.1. The main changes from 1.4.8 to 1.7.1 are as follows:
1. Starting from version 1.5.0, `PMML` schema version has been updated from 4.3 to 4.4, the related commit is:
jpmml/jpmml-model@7d8607a
2. Starting from version 1.6.0, `Java XML Binding` has been upgraded to `Jakarta XML Binding`,  the related commit is: jpmml/jpmml-model@d76de1c

After this PR, the exported PMML model schema version has been upgraded from 4.3(https://dmg.org/pmml/v4-3/GeneralStructure.html) to 4.4(https://dmg.org/pmml/v4-4-1/GeneralStructure.html).

### Why are the changes needed?

1. Upgrade the PMML standard to the latest 4.4 version ;
2. Upgrade `Java XML Binding` to `Jakarta XML Binding` by upgrade `JPMML`.

### Does this PR introduce _any_ user-facing change?

Yes, the exported PMML model version has been upgraded from 4.3 to 4.4, details are as follows:
1. `version` has been changed from 4.2 to 4.4;
2. `xmlns` has been changed from `http://www.dmg.org/PMML-4_3` to `http://www.dmg.org/PMML-4_4`;
3. `Application version` has been changed to the  current Spark version.

Before:
```
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable">
    <Header description="k-means clustering">
        <Application name="Apache Spark MLlib" version="3.5.4"/>
        <Timestamp>2025-02-08T10:00:55</Timestamp>
    </Header>
...
```
After:
```
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.4" xmlns="http://www.dmg.org/PMML-4_4" xmlns:data="http://jpmml.org/jpmml-model/InlineTable">
    <Header description="k-means clustering">
        <Application name="Apache Spark MLlib" version="4.1.0-SNAPSHOT"/>
        <Timestamp>2025-02-08T11:34:40</Timestamp>
    </Header>
...
```

**Despite this change, on the one hand, PMML version 4.4 was released as early as November 2019. On the other hand, the upgrade from 4.3 to 4.4 is backward compatible.(Reference: https://dmg.org/pmml/v4-4-1/Changes.html)**

### How was this patch tested?

1. Passed GA;
2. Manually checked the changes after the upgrade.
KMeans:
```
import org.apache.spark.ml.clustering.KMeans
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)
model.write.format("pmml").save("./kmeans")
```

LinearRegression:
```
import org.apache.spark.ml.regression.LinearRegression
val dataset = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt")
val lr = new LinearRegression()
val model = lr.fit(dataset)
model.write.format("pmml").save("./lr")
```
### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49854 from wayneguow/pmml.

Authored-by: Wei Guo <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
  • Loading branch information
wayneguow authored and LuciferYang committed Feb 11, 2025
1 parent 54959ab commit cea79dc
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 52 deletions.
2 changes: 2 additions & 0 deletions LICENSE-binary
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,9 @@ javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2
Eclipse Distribution License (EDL) 1.0
--------------------------------------
com.sun.istack:istack-commons-runtime
jakarta.activation:jakarta.activation-api
jakarta.xml.bind:jakarta.xml.bind-api
org.glassfish.jaxb:jaxb-core
org.glassfish.jaxb:jaxb-runtime

Eclipse Public License (EPL) 2.0
Expand Down
9 changes: 5 additions & 4 deletions dev/deps/spark-deps-hadoop-3-hive-2.3
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ httpclient/4.5.14//httpclient-4.5.14.jar
httpcore/4.4.16//httpcore-4.4.16.jar
icu4j/76.1//icu4j-76.1.jar
ini4j/0.5.4//ini4j-0.5.4.jar
istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar
istack-commons-runtime/4.1.2//istack-commons-runtime-4.1.2.jar
ivy/2.5.3//ivy-2.5.3.jar
j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar
jackson-annotations/2.18.2//jackson-annotations-2.18.2.jar
Expand All @@ -113,21 +113,22 @@ jackson-dataformat-yaml/2.18.2//jackson-dataformat-yaml-2.18.2.jar
jackson-datatype-jsr310/2.18.2//jackson-datatype-jsr310-2.18.2.jar
jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar
jackson-module-scala_2.13/2.18.2//jackson-module-scala_2.13-2.18.2.jar
jakarta.activation-api/2.1.3//jakarta.activation-api-2.1.3.jar
jakarta.annotation-api/2.1.1//jakarta.annotation-api-2.1.1.jar
jakarta.inject-api/2.0.1//jakarta.inject-api-2.0.1.jar
jakarta.servlet-api/5.0.0//jakarta.servlet-api-5.0.0.jar
jakarta.validation-api/3.0.2//jakarta.validation-api-3.0.2.jar
jakarta.ws.rs-api/3.0.0//jakarta.ws.rs-api-3.0.0.jar
jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar
jakarta.xml.bind-api/4.0.2//jakarta.xml.bind-api-4.0.2.jar
janino/3.1.9//janino-3.1.9.jar
java-diff-utils/4.15//java-diff-utils-4.15.jar
java-xmlbuilder/1.2//java-xmlbuilder-1.2.jar
javassist/3.30.2-GA//javassist-3.30.2-GA.jar
javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar
javax.servlet-api/4.0.1//javax.servlet-api-4.0.1.jar
javolution/5.5.1//javolution-5.5.1.jar
jaxb-api/2.2.11//jaxb-api-2.2.11.jar
jaxb-runtime/2.3.2//jaxb-runtime-2.3.2.jar
jaxb-core/4.0.5//jaxb-core-4.0.5.jar
jaxb-runtime/4.0.5//jaxb-runtime-4.0.5.jar
jcl-over-slf4j/2.0.16//jcl-over-slf4j-2.0.16.jar
jdo-api/3.0.1//jdo-api-3.0.1.jar
jdom2/2.0.6//jdom2-2.0.6.jar
Expand Down
19 changes: 19 additions & 0 deletions docs/ml-migration-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ Note that this migration guide describes the items specific to MLlib.
Many items of SQL migration can be applied when migrating MLlib to higher versions for DataFrame-based APIs.
Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.html).

## Upgrading from MLlib 3.5 to 4.0

### Breaking changes
{:.no_toc}

There are no breaking changes.

### Deprecations and changes of behavior
{:.no_toc}

**Deprecations**

There are no deprecations.

**Changes of behavior**

* [SPARK-51132](https://issues.apache.org/jira/browse/SPARK-51132):
The PMML XML schema version of exported PMML format models by [PMML model export](mllib-pmml-model-export.html) has been upgraded from `PMML-4_3` to `PMML-4_4`.

## Upgrading from MLlib 2.4 to 3.0

### Breaking changes
Expand Down
8 changes: 4 additions & 4 deletions mllib/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@
<url>https://spark.apache.org/</url>

<dependencies>
<dependency>
<groupId>javax.xml.bind</groupId>
<artifactId>jaxb-api</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-parser-combinators_${scala.binary.version}</artifactId>
Expand Down Expand Up @@ -144,6 +140,10 @@
<groupId>org.glassfish.jaxb</groupId>
<artifactId>jaxb-runtime</artifactId>
</dependency>
<dependency>
<groupId>jakarta.xml.bind</groupId>
<artifactId>jakarta.xml.bind-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.pmml
import java.io.{File, OutputStream, StringWriter}
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil
import org.jpmml.model.JAXBSerializer

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
Expand All @@ -39,7 +39,8 @@ trait PMMLExportable {
*/
private def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult)
val jaxbSerializer = new JAXBSerializer()
jaxbSerializer.marshalPretty(pmmlModelExport.getPmml(), streamResult)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.mllib.pmml.`export`

import scala.{Array => SArray}

import org.dmg.pmml.{DataDictionary, DataField, DataType, FieldName, MiningField,
MiningFunction, MiningSchema, OpType}
import org.dmg.pmml.{DataDictionary, DataField, DataType, MiningField, MiningFunction,
MiningSchema, OpType}
import org.dmg.pmml.regression.{NumericPredictor, RegressionModel, RegressionTable}

import org.apache.spark.mllib.regression.GeneralizedLinearModel
Expand All @@ -44,7 +44,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val fields = new SArray[String](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1")
Expand All @@ -67,7 +67,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
.addRegressionTables(regressionTableYES, regressionTableNO)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
fields(i) = "field_" + i
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
Expand All @@ -76,7 +76,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
}

// add target field
val targetField = FieldName.create("target")
val targetField = "target"
dataDictionary
.addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
miningSchema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.mllib.pmml.`export`

import scala.{Array => SArray}

import org.dmg.pmml.{DataDictionary, DataField, DataType, FieldName, MiningField,
MiningFunction, MiningSchema, OpType}
import org.dmg.pmml.{DataDictionary, DataField, DataType, MiningField, MiningFunction,
MiningSchema, OpType}
import org.dmg.pmml.regression.{NumericPredictor, RegressionModel, RegressionTable}

import org.apache.spark.mllib.regression.GeneralizedLinearModel
Expand All @@ -42,7 +42,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val fields = new SArray[String](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
Expand All @@ -53,7 +53,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
.addRegressionTables(regressionTable)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
fields(i) = "field_" + i
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
Expand All @@ -62,7 +62,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
}

// for completeness add target field
val targetField = FieldName.create("target")
val targetField = "target"
dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(targetField)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.pmml.`export`
import scala.{Array => SArray}

import org.dmg.pmml.{Array, CompareFunction, ComparisonMeasure, DataDictionary, DataField, DataType,
FieldName, MiningField, MiningFunction, MiningSchema, OpType, SquaredEuclidean}
MiningField, MiningFunction, MiningSchema, OpType, SquaredEuclidean}
import org.dmg.pmml.clustering.{Cluster, ClusteringField, ClusteringModel}

import org.apache.spark.mllib.clustering.KMeansModel
Expand All @@ -40,7 +40,7 @@ private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModel

if (model.clusterCenters.length > 0) {
val clusterCenter = model.clusterCenters(0)
val fields = new SArray[FieldName](clusterCenter.size)
val fields = new SArray[String](clusterCenter.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val comparisonMeasure = new ComparisonMeasure()
Expand All @@ -55,7 +55,7 @@ private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModel
.setNumberOfClusters(model.clusterCenters.length)

for (i <- 0 until clusterCenter.size) {
fields(i) = FieldName.create("field_" + i)
fields(i) = "field_" + i
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.Locale

import scala.beans.BeanProperty

import org.dmg.pmml.{Application, Header, PMML, Timestamp}
import org.dmg.pmml.{Application, Header, PMML, Timestamp, Version}

private[mllib] trait PMMLModelExport {

Expand All @@ -44,6 +44,6 @@ private[mllib] trait PMMLModelExport {
val header = new Header()
.setApplication(app)
.setTimestamp(timestamp)
new PMML("4.2", header, null)
new PMML(Version.PMML_4_4.getVersion(), header, null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
assert(fields(0).getOpType() == OpType.CONTINUOUS)
val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient().doubleValue()).toList
assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
}
Expand Down
5 changes: 3 additions & 2 deletions mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets

import org.dmg.pmml.PMML
import org.jpmml.model.{JAXBUtil, SAXUtil}
import org.jpmml.model.{JAXBSerializer, SAXUtil}
import org.jpmml.model.filters.ImportFilter

/**
Expand All @@ -37,6 +37,7 @@ private[spark] object PMMLUtils {
val transformed = SAXUtil.createFilteredSource(
new ByteArrayInputStream(input.getBytes(StandardCharsets.UTF_8)),
new ImportFilter())
JAXBUtil.unmarshalPMML(transformed)
val jaxbSerializer = new JAXBSerializer()
jaxbSerializer.unmarshal(transformed).asInstanceOf[PMML]
}
}
33 changes: 9 additions & 24 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
<version>1.4.8</version>
<version>1.7.1</version>
<scope>provided</scope>
<exclusions>
<exclusion>
Expand Down Expand Up @@ -599,32 +599,24 @@
<dependency>
<groupId>org.glassfish.jaxb</groupId>
<artifactId>jaxb-runtime</artifactId>
<version>2.3.2</version>
<version>4.0.5</version>
<scope>compile</scope>
<exclusions>
<!-- for now, we only write XML in PMML export, and these can be excluded -->
<exclusion>
<groupId>com.sun.xml.fastinfoset</groupId>
<artifactId>FastInfoset</artifactId>
</exclusion>
<exclusion>
<groupId>org.glassfish.jaxb</groupId>
<artifactId>txw2</artifactId>
</exclusion>
<exclusion>
<groupId>org.jvnet.staxex</groupId>
<artifactId>stax-ex</artifactId>
</exclusion>
<!--
SPARK-27611: Exclude redundant javax.activation implementation, which
conflicts with the existing javax.activation:activation:1.1.1 dependency.
-->
<exclusion>
<groupId>jakarta.activation</groupId>
<artifactId>jakarta.activation-api</artifactId>
<groupId>org.eclipse.angus</groupId>
<artifactId>angus-activation</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>jakarta.xml.bind</groupId>
<artifactId>jakarta.xml.bind-api</artifactId>
<version>4.0.2</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
Expand Down Expand Up @@ -1061,13 +1053,6 @@
<groupId>org.glassfish.jersey.core</groupId>
<artifactId>jersey-server</artifactId>
<version>${jersey.version}</version>
<!-- SPARK-28765 Unused JDK11-specific dependency -->
<exclusions>
<exclusion>
<groupId>jakarta.xml.bind</groupId>
<artifactId>jakarta.xml.bind-api</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.core</groupId>
Expand Down

0 comments on commit cea79dc

Please sign in to comment.