
Commit 6f4d166

#678 Add a method to convert integrals to decimals in schema according to the metadata.
1 parent 8997831 commit 6f4d166

File tree

8 files changed: +322 -16 lines changed


build.sbt

Lines changed: 10 additions & 1 deletion
@@ -92,7 +92,16 @@ lazy val sparkCobol = (project in file("spark-cobol"))
       log.info(s"Building with Spark ${sparkVersion(scalaVersion.value)}, Scala ${scalaVersion.value}")
       sparkVersion(scalaVersion.value)
     },
-    (Compile / compile) := ((Compile / compile) dependsOn printSparkVersion).value,
+    Compile / compile := ((Compile / compile) dependsOn printSparkVersion).value,
+    Compile / unmanagedSourceDirectories += {
+      val sourceDir = (Compile / sourceDirectory).value
+      CrossVersion.partialVersion(scalaVersion.value) match {
+        case Some((2, n)) if n == 11 => sourceDir / "scala_2.11"
+        case Some((2, n)) if n == 12 => sourceDir / "scala_2.12"
+        case Some((2, n)) if n == 13 => sourceDir / "scala_2.13"
+        case _ => throw new RuntimeException("Unsupported Scala version")
+      }
+    },
     libraryDependencies ++= SparkCobolDependencies(scalaVersion.value) :+ getScalaDependency(scalaVersion.value),
     dependencyOverrides ++= SparkCobolDependenciesOverride,
     Test / fork := true, // Spark tests fail randomly otherwise
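
The new Compile / unmanagedSourceDirectories setting adds one extra source directory per Scala binary version, which is how the three HofsWrapper variants added below get selected. A minimal sketch of how this is typically exercised from sbt (the crossScalaVersions values are illustrative assumptions, not taken from this build):

    // build.sbt sketch (assumption, not part of this commit)
    crossScalaVersions := Seq("2.11.12", "2.12.17", "2.13.10")

    // sbt shell:
    //   ++2.12.17 compile   // compiles with Scala 2.12 and picks up src/main/scala_2.12
    //   +test               // runs the build once per crossScalaVersions entry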

spark-cobol/pom.xml

Lines changed: 28 additions & 8 deletions
@@ -59,13 +59,33 @@
         </dependency>
     </dependencies>

-    <build>
-        <resources>
-            <resource>
-                <directory>src/main/resources</directory>
-                <filtering>true</filtering>
-            </resource>
-        </resources>
-    </build>
+    <build>
+        <resources>
+            <resource>
+                <directory>src/main/resources</directory>
+                <filtering>true</filtering>
+            </resource>
+        </resources>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>build-helper-maven-plugin</artifactId>
+                <version>3.0.0</version>
+                <executions>
+                    <execution>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                        <configuration>
+                            <sources>
+                                <source>src/main/scala_${scala.compat.version}</source>
+                            </sources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>

 </project>

spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala

Lines changed: 44 additions & 1 deletion
@@ -19,7 +19,8 @@ package za.co.absa.cobrix.spark.cobol.utils
 import com.fasterxml.jackson.databind.ObjectMapper
 import org.apache.hadoop.fs.FileSystem
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.functions.{concat_ws, expr, max}
+import org.apache.spark.sql.functions.{array, col, expr, max, struct}
+import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import za.co.absa.cobrix.cobol.internal.Logging
@@ -178,6 +179,48 @@ object SparkUtils extends Logging {
     df.select(fields.toSeq: _*)
   }

+  def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
+    def mapField(column: Column, field: StructField): Column = {
+      field.dataType match {
+        case st: StructType =>
+          val columns = st.fields.map(f => mapField(column.getField(field.name), f))
+          struct(columns: _*).as(field.name)
+        case ar: ArrayType =>
+          mapArray(ar, column, field.name).as(field.name)
+        case _ =>
+          f(field, column).as(field.name)
+      }
+    }
+
+    def mapArray(arr: ArrayType, column: Column, columnName: String): Column = {
+      arr.elementType match {
+        case st: StructType =>
+          transform(column, c => {
+            val columns = st.fields.map(f => mapField(c.getField(f.name), f))
+            struct(columns: _*)
+          })
+        case ar: ArrayType =>
+          array(mapArray(ar, column, columnName))
+        case p =>
+          array(f(StructField(columnName, p), column))
+      }
+    }
+
+    val columns = df.schema.fields.map(f => mapField(col(f.name), f))
+    df.select(columns: _*)
+  }
+
+  def covertIntegralToDecimal(df: DataFrame): DataFrame = {
+    mapPrimitives(df) { (field, c) =>
+      val metadata = field.metadata
+      if (metadata.contains("precision") && (field.dataType == LongType || field.dataType == IntegerType || field.dataType == ShortType)) {
+        val precision = metadata.getLong("precision").toInt
+        c.cast(DecimalType(precision, 0)).as(field.name)
+      } else {
+        c
+      }
+    }
+  }

   /**
     * Given an instance of DataFrame returns a dataframe where all primitive fields are converted to String
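
A rough usage sketch for the new SparkUtils.covertIntegralToDecimal helper (the copybook and data paths are hypothetical; the DataFrame is assumed to carry "precision" metadata on integral fields, as the extended metadata option used in the test below provides):

    import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

    // Hypothetical input: any DataFrame whose integral fields carry "precision" metadata.
    val df = spark.read
      .format("cobol")
      .option("copybook", "/path/to/copybook.cpy")   // hypothetical path
      .option("metadata", "extended")
      .load("/path/to/data")                         // hypothetical path

    // Short/Integer/Long fields that have "precision" metadata are re-cast to decimal(precision, 0);
    // everything else, including nested structs and arrays, keeps its original type.
    val converted = SparkUtils.covertIntegralToDecimal(df)
    converted.printSchema()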
spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright 2018 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column

object HofsWrapper {
  /**
    * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function
    * from functional programming.
    *
    * The method is not available in Scala 2.11 and Spark < 3.0
    */
  def transform(
                 array: Column,
                 f: Column => Column): Column = {
    throw new IllegalArgumentException("Array transformation is not available for Scala 2.11 and Spark < 3.0.")
  }
}
spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
/*
 * Copyright 2018 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{transform => sparkTransform}

object HofsWrapper {
  /**
    * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function
    * from functional programming.
    *
    * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
    *
    * @param array A column of arrays
    * @param f     A function transforming individual elements of the array
    * @return A column of arrays with transformed elements
    */
  def transform(
                 array: Column,
                 f: Column => Column): Column = {
    sparkTransform(array, f)
  }
}
spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
/*
 * Copyright 2018 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{transform => sparkTransform}

object HofsWrapper {
  /**
    * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function
    * from functional programming.
    *
    * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
    *
    * @param array A column of arrays
    * @param f     A function transforming individual elements of the array
    * @return A column of arrays with transformed elements
    */
  def transform(
                 array: Column,
                 f: Column => Column): Column = {
    sparkTransform(array, f)
  }
}
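
For reference, a small sketch of what the wrapped higher-order function does on the Scala 2.12/2.13 builds with Spark 3.0+ (df and the "amounts" column are hypothetical; on the Scala 2.11 build the same call throws IllegalArgumentException):

    import org.apache.spark.sql.functions.col
    import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper

    // Assuming df has a column "amounts" of type array<int>:
    // produces a new array column with each element incremented, without exploding the array.
    val result = df.withColumn("amounts_plus_one", HofsWrapper.transform(col("amounts"), x => x + 1))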
spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
/*
 * Copyright 2018 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.source.fixtures

import org.scalatest.{Assertion, Suite}

trait TextComparisonFixture {
  this: Suite =>

  protected def compareText(actual: String, expected: String): Assertion = {
    if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
      fail(renderTextDifference(actual, expected))
    } else {
      succeed
    }
  }

  protected def compareTextVertical(actual: String, expected: String): Unit = {
    if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
      fail(s"ACTUAL:\n$actual\nEXPECTED: \n$expected")
    }
  }

  protected def renderTextDifference(textActual: String, textExpected: String): String = {
    val t1 = textActual.replaceAll("\\r\\n", "\\n").split('\n')
    val t2 = textExpected.replaceAll("\\r\\n", "\\n").split('\n')

    val maxLen = Math.max(getMaxStrLen(t1), getMaxStrLen(t2))
    val header = s" ${rightPad("ACTUAL:", maxLen)} ${rightPad("EXPECTED:", maxLen)}\n"

    val stringBuilder = new StringBuilder
    stringBuilder.append(header)

    val linesCount = Math.max(t1.length, t2.length)
    var i = 0
    while (i < linesCount) {
      val a = if (i < t1.length) t1(i) else ""
      val b = if (i < t2.length) t2(i) else ""

      val marker1 = if (a != b) ">" else " "
      val marker2 = if (a != b) "<" else " "

      val comparisonText = s"$marker1${rightPad(a, maxLen)} ${rightPad(b, maxLen)}$marker2\n"
      stringBuilder.append(comparisonText)

      i += 1
    }

    val footer = s"\nACTUAL:\n$textActual"
    stringBuilder.append(footer)
    stringBuilder.toString()
  }

  def getMaxStrLen(text: Seq[String]): Int = {
    if (text.isEmpty) {
      0
    } else {
      text.maxBy(_.length).length
    }
  }

  def rightPad(s: String, length: Int): String = {
    if (s.length < length) {
      s + " " * (length - s.length)
    } else if (s.length > length) {
      s.take(length)
    } else {
      s
    }
  }
}
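
A minimal sketch of how the fixture is meant to be mixed into a suite (hypothetical test, not part of this commit); compareText ignores line-ending differences, while a real mismatch fails with the side-by-side ACTUAL/EXPECTED rendering built by renderTextDifference:

    import org.scalatest.funsuite.AnyFunSuite
    import za.co.absa.cobrix.spark.cobol.source.fixtures.TextComparisonFixture

    class SchemaTextSuite extends AnyFunSuite with TextComparisonFixture {
      test("schemas match regardless of line endings") {
        val expected = "root\n |-- COUNT: decimal(1,0) (nullable = true)"
        val actual   = "root\r\n |-- COUNT: decimal(1,0) (nullable = true)"

        compareText(actual, expected)   // passes: CR/LF differences are ignored
      }
    }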

spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala

Lines changed: 46 additions & 6 deletions
@@ -16,17 +16,17 @@

 package za.co.absa.cobrix.spark.cobol.utils

-import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.scalatest.funsuite.AnyFunSuite
-import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
 import org.slf4j.LoggerFactory
-import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture
+import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
+import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture}
 import za.co.absa.cobrix.spark.cobol.utils.TestUtils._

 import java.nio.charset.StandardCharsets
-import scala.collection.immutable
+import scala.util.Properties

-class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture {
+class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture with TextComparisonFixture {

   import spark.implicits._

@@ -377,7 +377,7 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     assert(dfFlattened.count() == 0)
   }

-  test("Schema with multiple OCCURS should properly determine array sized") {
+  test("Schema with multiple OCCURS should properly determine array sizes") {
     val copyBook: String =
       """      01 RECORD.
         |        02 COUNT PIC 9(1).
@@ -429,6 +429,46 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     }
   }

+  test("Integral to decimal conversion for complex schema") {
+    val expectedSchema =
+      """|root
+         | |-- COUNT: decimal(1,0) (nullable = true)
+         | |-- GROUP: array (nullable = true)
+         | |    |-- element: struct (containsNull = false)
+         | |    |    |-- INNER_COUNT: decimal(1,0) (nullable = true)
+         | |    |    |-- INNER_GROUP: array (nullable = true)
+         | |    |    |    |-- element: struct (containsNull = false)
+         | |    |    |    |    |-- FIELD: decimal(1,0) (nullable = true)
+         |""".stripMargin
+
+    val copyBook: String =
+      """      01 RECORD.
+        |        02 COUNT PIC 9(1).
+        |        02 GROUP OCCURS 2 TIMES.
+        |           03 INNER-COUNT PIC S9(1).
+        |           03 INNER-GROUP OCCURS 3 TIMES.
+        |              04 FIELD PIC 9.
+        |""".stripMargin
+
+    withTempTextFile("fletten", "test", StandardCharsets.UTF_8, "") { filePath =>
+      val df = spark.read
+        .format("cobol")
+        .option("copybook_contents", copyBook)
+        .option("pedantic", "true")
+        .option("record_format", "D")
+        .option("metadata", "extended")
+        .load(filePath)
+
+      if (!Properties.versionString.startsWith("2.")) {
+        // This method only works with Scala 2.12+ and Spark 3.0+
+        val actualDf = SparkUtils.covertIntegralToDecimal(df)
+        val actualSchema = actualDf.schema.treeString
+
+        compareText(actualSchema, expectedSchema)
+      }
+    }
+  }
+
   private def assertSchema(actualSchema: String, expectedSchema: String): Unit = {
     if (actualSchema != expectedSchema) {
       logger.error(s"EXPECTED:\n$expectedSchema")
