Skip to content

Commit 95c483e

Browse files
dsamaeydsamaeypomadchin
authored
Add predictor 2 (integer) and predictor 3 (float) support for writing compressed GTiff files (#3588)
* feat: add predictor support for GTiff writing * feat: add predictor support for GTiff writing # #3587 * feat: add predictor support for GTiff writing # #3587 * Minor code cleanup --------- Co-authored-by: dsamaey <[email protected]> Co-authored-by: Grigory Pomadchin <[email protected]>
1 parent 6165638 commit 95c483e

File tree

10 files changed

+406
-60
lines changed

10 files changed

+406
-60
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
### Added
1010
- Add ZStd compression support for GTiff [#3580](https://github.com/locationtech/geotrellis/pull/3580)
1111
- Do not depend on private Spark API, avoids sealing violation [#3586](https://github.com/locationtech/geotrellis/pull/3586)
12+
- Add predictor 2 (integer) and predictor 3 (float) support for writing compressed GTiff files [#3588](https://github.com/locationtech/geotrellis/pull/3588)
1213

1314
## [3.8.0] - 2025-04-23
1415

raster/src/main/scala/geotrellis/raster/io/geotiff/compression/Compression.scala

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,30 @@ package geotrellis.raster.io.geotiff.compression
1818

1919
import io.circe._
2020
import io.circe.syntax._
21-
import cats.syntax.either._
2221

2322
trait Compression extends Serializable {
2423
def createCompressor(segmentCount: Int): Compressor
24+
25+
def withPredictor(predictor: Predictor): Compression =
26+
(segmentCount: Int) => createCompressor(segmentCount).withPredictorEncoding(predictor)
2527
}
2628

2729
object Compression {
2830
implicit val compressionDecoder: Decoder[Compression] =
29-
new Decoder[Compression] {
30-
final def apply(c: HCursor): Decoder.Result[Compression] = {
31-
c.downField("compressionType").as[String].map {
32-
case "NoCompression" => NoCompression
33-
case _ =>
34-
c.downField("level").as[Int] match {
35-
case Left(_) => DeflateCompression()
36-
case Right(i) => DeflateCompression(i)
37-
}
38-
}
31+
(c: HCursor) =>
32+
c.downField("compressionType").as[String].map {
33+
case "NoCompression" => NoCompression
34+
case _ =>
35+
c.downField("level").as[Int] match {
36+
case Left(_) => DeflateCompression()
37+
case Right(i) => DeflateCompression(i)
38+
}
3939
}
40-
}
4140

42-
implicit val compressionEncoder: Encoder[Compression] =
43-
new Encoder[Compression] {
44-
final def apply(a: Compression): Json = a match {
45-
case NoCompression =>
46-
Json.obj(("compressionType", "NoCompression".asJson))
47-
case d: DeflateCompression =>
48-
Json.obj(("compressionType", "Deflate".asJson), ("level", d.level.asJson))
49-
}
50-
}
41+
implicit val compressionEncoder: Encoder[Compression] = {
42+
case NoCompression =>
43+
Json.obj(("compressionType", "NoCompression".asJson))
44+
case d: DeflateCompression =>
45+
Json.obj(("compressionType", "Deflate".asJson), ("level", d.level.asJson))
46+
}
5147
}

raster/src/main/scala/geotrellis/raster/io/geotiff/compression/Compressor.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,23 @@
1616

1717
package geotrellis.raster.io.geotiff.compression
1818

19+
import geotrellis.raster.io.geotiff.tags.TiffTags
20+
1921
trait Compressor extends Serializable {
2022
def compress(bytes: Array[Byte], segmentIndex: Int): Array[Byte]
2123

2224
/** Returns the decompressor that can decompress the segments compressed by this compressor */
2325
def createDecompressor(): Decompressor
26+
27+
def withPredictorEncoding(predictor: Predictor): Compressor =
28+
new Compressor {
29+
def wrapped: Compressor = Compressor.this
30+
31+
def compress(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
32+
wrapped.compress(predictor.encode(bytes, segmentIndex), segmentIndex = segmentIndex)
33+
34+
/** Returns the decompressor that can decompress the segments compressed by this compressor */
35+
def createDecompressor(): Decompressor = wrapped.createDecompressor().withPredictorDecoding(predictor)
36+
}
37+
2438
}

raster/src/main/scala/geotrellis/raster/io/geotiff/compression/Decompressor.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@ trait Decompressor extends Serializable {
3434
*/
3535
def flipEndian(bytesPerFlip: Int): Decompressor =
3636
new Decompressor {
37-
def code = Decompressor.this.code
38-
override def predictorCode = Decompressor.this.predictorCode
37+
def code: Int = Decompressor.this.code
38+
override def predictorCode: Int = Decompressor.this.predictorCode
3939

4040
override
41-
def byteOrder = ByteOrder.LITTLE_ENDIAN // Since we have to flip, image data is in Little Endian
41+
def byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN // Since we have to flip, image data is in Little Endian
4242

4343
def decompress(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
4444
flip(Decompressor.this.decompress(bytes, segmentIndex))
4545

4646
def flip(bytes: Array[Byte]): Array[Byte] = {
4747
val arr = bytes.clone
48-
val size = arr.size
48+
val size = arr.length
4949

5050
var i = 0
5151
while (i < size) {
@@ -62,14 +62,14 @@ trait Decompressor extends Serializable {
6262
}
6363
}
6464

65-
def withPredictor(predictor: Predictor): Decompressor =
65+
def withPredictorDecoding(predictor: Predictor): Decompressor =
6666
new Decompressor {
67-
def code = Decompressor.this.code
68-
override def predictorCode = predictor.code
69-
override def byteOrder = Decompressor.this.byteOrder
67+
def code: Int = Decompressor.this.code
68+
override def predictorCode: Int = predictor.code
69+
override def byteOrder: ByteOrder = Decompressor.this.byteOrder
7070

7171
def decompress(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
72-
predictor(Decompressor.this.decompress(bytes, segmentIndex), segmentIndex)
72+
predictor.decode(Decompressor.this.decompress(bytes, segmentIndex), segmentIndex)
7373
}
7474
}
7575

@@ -88,9 +88,9 @@ object Decompressor {
8888
def checkPredictor(d: Decompressor): Decompressor = {
8989
val predictor = Predictor(tiffTags)
9090
if(predictor.checkEndian)
91-
checkEndian(d).withPredictor(predictor)
91+
checkEndian(d).withPredictorDecoding(predictor)
9292
else
93-
d.withPredictor(predictor)
93+
d.withPredictorDecoding(predictor)
9494
}
9595

9696
val segmentCount = tiffTags.segmentCount
@@ -108,7 +108,7 @@ object Decompressor {
108108

109109
tiffTags.compression match {
110110
case Uncompressed =>
111-
checkEndian(NoCompression)
111+
checkEndian(NoCompressor)
112112
case LZWCoded =>
113113
checkPredictor(LZWDecompressor(segmentSizes))
114114
case ZLibCoded | PkZipCoded =>

raster/src/main/scala/geotrellis/raster/io/geotiff/compression/FloatingPointPredictor.scala

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,20 @@ import spire.syntax.cfor._
2222

2323
/** See TIFF Technical Note 3 */
2424
object FloatingPointPredictor {
25+
26+
def apply(imageData: GeoTiffImageData): Predictor = {
27+
val colsPerRow = Predictor.colsPerRow(imageData)
28+
val rowsInSegment = Predictor.rowsInSegment(imageData)
29+
30+
if (imageData.segmentLayout.hasPixelInterleave)
31+
new FloatingPointPredictor(colsPerRow, rowsInSegment, imageData.bandType, imageData.bandCount)
32+
else
33+
new FloatingPointPredictor(colsPerRow, rowsInSegment, imageData.bandType, 1)
34+
}
35+
2536
def apply(tiffTags: TiffTags): Predictor = {
2637
val colsPerRow = tiffTags.rowSize
27-
val rowsInSegment: (Int => Int) = { i => tiffTags.rowsInSegment(i) }
38+
val rowsInSegment: Int => Int = { i => tiffTags.rowsInSegment(i) }
2839

2940
val bandType = tiffTags.bandType
3041

@@ -36,10 +47,45 @@ object FloatingPointPredictor {
3647
}
3748

3849
private class FloatingPointPredictor(colsPerRow: Int, rowsInSegment: Int => Int, bandType: BandType, bandCount: Int) extends Predictor {
39-
val code = Predictor.PREDICTOR_FLOATINGPOINT
50+
val code: Int = Predictor.PREDICTOR_FLOATINGPOINT
4051
val checkEndian = false
4152

42-
def apply(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
53+
private def encodeDeltaBytes(bytes: Array[Byte], rows:Int): Array[Byte] = {
54+
val bytesPerSample = bandType.bytesPerSample
55+
val colValuesPerRow = colsPerRow * bandCount
56+
val the_cols = colsPerRow * bytesPerSample
57+
val bytesPerRow = colValuesPerRow * bytesPerSample
58+
59+
cfor(0)(_ < rows, _ + 1) { row =>
60+
val rowOffset = row * bytesPerRow
61+
cfor(the_cols-1)(_ > 0, _ - 1) { col =>
62+
cfor(0)(_ < bandCount, _ + 1) { band =>
63+
bytes(rowOffset + col * bandCount + band) = (bytes(rowOffset + col * bandCount + band) - bytes(rowOffset + (col - 1) * bandCount + band)).toByte
64+
}
65+
}
66+
}
67+
bytes
68+
}
69+
70+
def encode(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
71+
val rows = rowsInSegment(segmentIndex)
72+
val bytesPerSample = bandType.bytesPerSample
73+
val bytesPerRow = colsPerRow * bandCount * bytesPerSample
74+
75+
val outputBytes = new Array[Byte](bytes.length)
76+
val rowIncrement = colsPerRow * bandCount
77+
cfor(0)(_ < rows, _ + 1) { row =>
78+
val rowOffset = bytesPerRow * row
79+
cfor(0)(_ < rowIncrement, _ + 1) { col =>
80+
cfor(0)(_ < bytesPerSample, _ + 1) { b =>
81+
outputBytes(rowOffset + b * rowIncrement + col) = bytes(rowOffset + bytesPerSample * col + b)
82+
}
83+
}
84+
}
85+
encodeDeltaBytes(outputBytes, rows)
86+
}
87+
88+
override def decode(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
4389
val rows = rowsInSegment(segmentIndex)
4490
val stride = bandCount
4591
val bytesPerSample = bandType.bytesPerSample

raster/src/main/scala/geotrellis/raster/io/geotiff/compression/HorizontalPredictor.scala

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,31 @@ import java.nio.ByteBuffer
2424
import spire.syntax.cfor._
2525

2626
object HorizontalPredictor {
27+
28+
def apply(imageData: GeoTiffImageData): Predictor = {
29+
val colsPerRow = Predictor.colsPerRow(imageData)
30+
val rowsInSegment = Predictor.rowsInSegment(imageData)
31+
32+
val bandType = imageData.bandType
33+
34+
val predictor =
35+
if (imageData.segmentLayout.hasPixelInterleave)
36+
new HorizontalPredictor(colsPerRow, rowsInSegment, imageData.bandCount)
37+
else
38+
new HorizontalPredictor(colsPerRow, rowsInSegment, 1)
39+
40+
predictor.forBandType(bandType)
41+
}
42+
43+
2744
def apply(tiffTags: TiffTags): Predictor = {
2845
val colsPerRow = tiffTags.rowSize
29-
val rowsInSegment: (Int => Int) = { i => tiffTags.rowsInSegment(i) }
46+
val rowsInSegment: Int => Int = { i => tiffTags.rowsInSegment(i) }
3047

3148
val bandType = tiffTags.bandType
3249

3350
val predictor =
34-
if(tiffTags.hasPixelInterleave) {
51+
if (tiffTags.hasPixelInterleave) {
3552
new HorizontalPredictor(colsPerRow, rowsInSegment, tiffTags.bandCount)
3653
} else {
3754
new HorizontalPredictor(colsPerRow, rowsInSegment, 1)
@@ -42,24 +59,55 @@ object HorizontalPredictor {
4259

4360
private class HorizontalPredictor(cols: Int, rowsInSegment: Int => Int, bandCount: Int) {
4461
def forBandType(bandType: BandType): Predictor = {
45-
val applyFunc: (Array[Byte], Int) => Array[Byte] =
62+
val encodeFunc: (Array[Byte], Int) => Array[Byte] =
4663
bandType.bitsPerSample match {
47-
case 8 => apply8 _
48-
case 16 => apply16 _
49-
case 32 => apply32 _
64+
case 8 => encode8
65+
case 16 => encode16
66+
case 32 => encode32
5067
case _ =>
5168
throw new MalformedGeoTiffException(s"""Horizontal differencing "Predictor" not supported with ${bandType.bitsPerSample} bits per sample""")
5269
}
5370

71+
val decodeFunc: (Array[Byte], Int) => Array[Byte] =
72+
bandType.bitsPerSample match {
73+
case 8 => decode8
74+
case 16 => decode16
75+
case 32 => decode32
76+
case _ =>
77+
throw new MalformedGeoTiffException(s"""Horizontal differencing "Predictor" not supported with ${bandType.bitsPerSample} bits per sample""")
78+
}
79+
80+
5481
new Predictor {
55-
val code = Predictor.PREDICTOR_HORIZONTAL
82+
val code: Int = Predictor.PREDICTOR_HORIZONTAL
5683
val checkEndian = true
57-
def apply(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
58-
applyFunc(bytes, segmentIndex)
84+
85+
def encode(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
86+
encodeFunc(bytes, segmentIndex)
87+
88+
def decode(bytes: Array[Byte], segmentIndex: Int): Array[Byte] =
89+
decodeFunc(bytes, segmentIndex)
90+
}
91+
}
92+
93+
def encode8(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
94+
val encodedBytes = new Array[Byte](bytes.length)
95+
val rows = rowsInSegment(segmentIndex)
96+
val n = bytes.length / rows
97+
98+
cfor(0)(_ < rows, _ + 1) { row =>
99+
cfor(n - 1)({ k => k >= 0 }, _ - 1) { col =>
100+
val index = row * n + col
101+
if (col < bandCount)
102+
encodedBytes(index) = bytes(index)
103+
else
104+
encodedBytes(index) = (bytes(index) - bytes(index - bandCount)).toByte
105+
}
59106
}
107+
encodedBytes
60108
}
61109

62-
def apply8(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
110+
def decode8(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
63111
val rows = rowsInSegment(segmentIndex)
64112

65113
cfor(0)(_ < rows, _ + 1) { row =>
@@ -69,11 +117,30 @@ object HorizontalPredictor {
69117
count += 1
70118
}
71119
}
72-
73120
bytes
74121
}
75122

76-
def apply16(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
123+
def encode16(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
124+
val encodedBytes = new Array[Byte](bytes.length)
125+
val encodedBuffer = ByteBuffer.wrap(encodedBytes).asShortBuffer
126+
val buffer = ByteBuffer.wrap(bytes).asShortBuffer
127+
128+
val rows = rowsInSegment(segmentIndex)
129+
val n = bytes.length / (rows * 2)
130+
131+
cfor(0)(_ < rows, _ + 1) { row =>
132+
cfor(n - 1)({ k => k >= 0 }, _ - 1) { col =>
133+
val index = row * n + col
134+
if (col < bandCount)
135+
encodedBuffer.put(index, buffer.get(index))
136+
else
137+
encodedBuffer.put(index, (buffer.get(index) - buffer.get(index - bandCount)).toShort)
138+
}
139+
}
140+
encodedBytes
141+
}
142+
143+
def decode16(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
77144
val buffer = ByteBuffer.wrap(bytes).asShortBuffer
78145
val rows = rowsInSegment(segmentIndex)
79146

@@ -84,11 +151,30 @@ object HorizontalPredictor {
84151
count += 1
85152
}
86153
}
87-
88154
bytes
89155
}
90156

91-
def apply32(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
157+
def encode32(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
158+
val encodedBytes = new Array[Byte](bytes.length)
159+
val encodedBuffer = ByteBuffer.wrap(encodedBytes).asIntBuffer
160+
val buffer = ByteBuffer.wrap(bytes).asIntBuffer
161+
162+
val rows = rowsInSegment(segmentIndex)
163+
val n = bytes.length / (rows * 4)
164+
165+
cfor(0)(_ < rows, _ + 1) { row =>
166+
cfor(n - 1)({ k => k >= 0 }, _ - 1) { col =>
167+
val index = row * n + col
168+
if (col < bandCount)
169+
encodedBuffer.put(index, buffer.get(index))
170+
else
171+
encodedBuffer.put(index, buffer.get(index) - buffer.get(index - bandCount))
172+
}
173+
}
174+
encodedBytes
175+
}
176+
177+
def decode32(bytes: Array[Byte], segmentIndex: Int): Array[Byte] = {
92178
val buffer = ByteBuffer.wrap(bytes).asIntBuffer
93179
val rows = rowsInSegment(segmentIndex)
94180

@@ -99,7 +185,6 @@ object HorizontalPredictor {
99185
count += 1
100186
}
101187
}
102-
103188
bytes
104189
}
105190
}

0 commit comments

Comments
 (0)