Skip to content

[SPARKNLP-1161] Adding features to PDF Reader #14596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
773 changes: 613 additions & 160 deletions examples/python/reader/SparkNLP_PDF_Reader_Demo.ipynb

Large diffs are not rendered by default.

35 changes: 34 additions & 1 deletion python/sparknlp/reader/pdf_to_text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright 2017-2025 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pyspark import keyword_only
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
Expand Down Expand Up @@ -89,6 +102,14 @@ class PdfToText(JavaTransformer, HasInputCol, HasOutputCol,
"Force to extract only number of pages",
typeConverter=TypeConverters.toBoolean)

# Boolean Param: when enabled, the transformer also emits the per-character
# text coordinates (positions) extracted from the PDF alongside the text.
extractCoordinates = Param(Params._dummy(), "extractCoordinates",
"Force extract coordinates of text.",
typeConverter=TypeConverters.toBoolean)

# Boolean Param: when enabled, ligature glyphs in the extracted text are
# expanded into their constituent characters (e.g. the single 'fl' glyph
# becomes 'f' followed by 'l').
normalizeLigatures = Param(Params._dummy(), "normalizeLigatures",
"Whether to convert ligature chars such as 'fl' into its corresponding chars (e.g., {'f', 'l'}).",
typeConverter=TypeConverters.toBoolean)

@keyword_only
def __init__(self):
"""
Expand Down Expand Up @@ -154,4 +175,16 @@ def setSort(self, value):
"""
Sets the value of :py:attr:`sort`.
"""
return self._set(sort=value)
return self._set(sort=value)

def setExtractCoordinates(self, value):
    """Enable or disable extraction of text coordinates.

    Parameters
    ----------
    value : bool
        New value for :py:attr:`extractCoordinates`.
    """
    kwargs = {"extractCoordinates": value}
    return self._set(**kwargs)

def setNormalizeLigatures(self, value):
    """Enable or disable ligature normalization in extracted text.

    Parameters
    ----------
    value : bool
        New value for :py:attr:`normalizeLigatures`.
    """
    kwargs = {"normalizeLigatures": value}
    return self._set(**kwargs)
146 changes: 119 additions & 27 deletions src/main/scala/com/johnsnowlabs/reader/PdfToText.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@ package com.johnsnowlabs.reader
import com.johnsnowlabs.nlp.IAnnotation
import com.johnsnowlabs.reader.util.HasPdfProperties
import com.johnsnowlabs.reader.util.pdf._
import com.johnsnowlabs.reader.util.pdf.schema.{MappingMatrix, PageMatrix}
import org.apache.pdfbox.pdmodel.PDDocument
import org.apache.pdfbox.text.PDFTextStripper
import org.apache.spark.internal.Logging
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, posexplode_outer, udf}
import org.apache.spark.sql.functions.{col, lit, posexplode_outer, udf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}

import java.io.ByteArrayOutputStream
import java.io.{ByteArrayOutputStream, PrintWriter, StringWriter}
import scala.util.{Failure, Success, Try}

/** Extract text from PDF document to a single string or to several strings per each page. Input
Expand Down Expand Up @@ -104,6 +105,7 @@ class PdfToText(override val uid: String)

protected def outputDataType: StructType = new StructType()
.add($(outputCol), StringType)
.add("positions", PageMatrix.dataType)
.add("height_dimension", IntegerType)
.add("width_dimension", IntegerType)
.add($(inputCol), BinaryType)
Expand All @@ -128,25 +130,34 @@ class PdfToText(override val uid: String)
setDefault(inputCol -> "content", outputCol -> "text")

private def transformUDF: UserDefinedFunction = udf(
(path: String, content: Array[Byte]) => {
doProcess(content)
(path: String, content: Array[Byte], exception: String) => {
doProcess(content, exception)
},
ArrayType(outputDataType))

private def doProcess(
content: Array[Byte]): Seq[(String, Int, Int, Array[Byte], String, Int)] = {
content: Array[Byte],
exception: String): Seq[(String, Seq[PageMatrix], Int, Int, Array[Byte], String, Int)] = {
val pagesTry = Try(
pdfToText(
content,
$(onlyPageNum),
$(splitPage),
$(storeSplittedPdf),
$(sort),
$(textStripper)))
$(textStripper),
$(extractCoordinates),
$(normalizeLigatures)))

pagesTry match {
case Failure(_) =>
Seq()
case Failure(e) =>
log.error("Pdf load error during text extraction")
val sw = new StringWriter
e.printStackTrace(new PrintWriter(sw))
log.error(sw.toString)
log.error(pagesTry.toString)
val errMessage = e.toString + " " + e.getMessage
Seq(("", Seq(), -1, -1, Array(), exception.concatException(s"PdfToText: $errMessage"), 0))
case Success(content) =>
content
}
Expand All @@ -157,7 +168,15 @@ class PdfToText(override val uid: String)

val selCols1 = df.columns
.filterNot(_ == $(inputCol))
.map(col) :+ posexplode_outer(transformUDF(df.col($(originCol)), df.col($(inputCol))))
.map(col) :+ posexplode_outer(
transformUDF(
df.col($(originCol)),
df.col($(inputCol)),
if (df.columns.contains("exception")) {
col("exception")
} else {
lit(null)
}))
.as(Seq("tmp_num", "tmp_result"))
val selCols = df.columns
.filterNot(_ == $(inputCol))
Expand All @@ -179,16 +198,25 @@ class PdfToText(override val uid: String)
getOrDefault(inputCol),
throw new RuntimeException(s"Column not found ${getOrDefault(inputCol)}"))

pdfs flatMap { case BinaryFile(bytes, path) =>
doProcess(bytes).zipWithIndex.map { case ((text, _, _, content, exception, _), pageNum) =>
val metadata =
Map("exception" -> exception, "sourcePath" -> path, "pageNum" -> pageNum.toString)
pdfs flatMap {
case BinaryFile(bytes, path) =>
doProcess(bytes, path).zipWithIndex.map {
case ((text, pageMatrix, _, _, content, exception, _), pageNum) =>
val metadata =
Map("exception" -> exception, "sourcePath" -> path, "pageNum" -> pageNum.toString)

val result = lightRecord ++ Map(
getOutputCol -> Seq(OcrText(text, metadata, content)),
getOrDefault(pageNumCol) -> Seq(PageNum(pageNum)))

if ($(extractCoordinates))
result ++ Map("positions" -> pageMatrix.map(pm => PositionsOutput(pm.mapping)))
else
result

val result = lightRecord ++ Map(
getOutputCol -> Seq(OcrText(text, metadata, content)),
getOrDefault(pageNumCol) -> Seq(PageNum(pageNum)))
result
}
case _ => lightRecord.chainExceptions(s"Wrong Input in $uid")
}
case _ => Seq(lightRecord.chainExceptions(s"Wrong Input in $uid"))
}
}
}
Expand Down Expand Up @@ -224,11 +252,14 @@ trait PdfToTextTrait extends Logging with PdfUtils {
splitPage: Boolean,
storeSplittedPdf: Boolean,
sort: Boolean,
textStripper: String): Seq[(String, Int, Int, Array[Byte], String, Int)] = {
textStripper: String,
extractCoordinates: Boolean,
normalizeLigatures: Boolean = false)
: Seq[(String, Seq[PageMatrix], Int, Int, Array[Byte], String, Int)] = {
val validPdf = checkAndFixPdf(content)
val pdfDoc = PDDocument.load(validPdf)
val numPages = pdfDoc.getNumberOfPages
log.info(s"Number of pages ${numPages}")
log.info(s"Number of pages $numPages")
require(numPages >= 1, "pdf input stream cannot be empty")
val result = if (!onlyPageNum) {
pdfboxMethod(
Expand All @@ -239,9 +270,11 @@ trait PdfToTextTrait extends Logging with PdfUtils {
splitPage,
storeSplittedPdf,
sort,
textStripper)
textStripper,
extractCoordinates,
normalizeLigatures = normalizeLigatures)
} else {
Range(1, numPages + 1).map(pageNum => ("", 1, 1, null, null, pageNum))
Range(1, numPages + 1).map(pageNum => ("", null, 1, 1, null, null, pageNum))
}
pdfDoc.close()
log.info("Close pdf")
Expand All @@ -256,9 +289,13 @@ trait PdfToTextTrait extends Logging with PdfUtils {
splitPage: Boolean,
storeSplittedPdf: Boolean,
sort: Boolean,
textStripper: String): Seq[(String, Int, Int, Array[Byte], String, Int)] = {
textStripper: String,
extractCoordinates: Boolean,
normalizeCoordinates: Boolean = true,
normalizeLigatures: Boolean = false)
: Seq[(String, Seq[PageMatrix], Int, Int, Array[Byte], String, Int)] = {
lazy val out: ByteArrayOutputStream = new ByteArrayOutputStream()
if (splitPage)
if (splitPage) {
Range(startPage, endPage + 1).flatMap(pagenum =>
extractText(pdfDoc, pagenum, pagenum, sort, textStripper)
.map { text =>
Expand All @@ -271,23 +308,78 @@ trait PdfToTextTrait extends Logging with PdfUtils {
outputDocument.close()
out.toByteArray
} else null
val coordinates =
if (extractCoordinates)
getCoordinates(pdfDoc, pagenum, pagenum, normalizeCoordinates, normalizeLigatures)
else null
(
text,
coordinates,
page.getMediaBox.getHeight.toInt,
page.getMediaBox.getWidth.toInt,
splittedPdf,
null,
pagenum)
})
else {
} else {
val text = extractText(pdfDoc, startPage, endPage, sort, textStripper).mkString(
System.lineSeparator())
val heightDimension = pdfDoc.getPage(startPage).getMediaBox.getHeight.toInt
val widthDimension = pdfDoc.getPage(startPage).getMediaBox.getWidth.toInt
val coordinates =
if (extractCoordinates)
getCoordinates(pdfDoc, startPage, endPage, normalizeCoordinates, normalizeLigatures)
else null
Seq(
(text, heightDimension, widthDimension, if (storeSplittedPdf) content else null, null, 0))
(
text,
coordinates,
heightDimension,
widthDimension,
if (storeSplittedPdf) content else null,
null,
0))
}
}

/** Extracts per-character text positions for the pages in [startPage, endPage]
  * (0-based, inclusive), one PageMatrix per page.
  *
  * @param doc                by-name PDF document; re-evaluated when the stripper reads it
  * @param startPage          first page index, 0-based inclusive
  * @param endPage            last page index, 0-based inclusive
  * @param normalizeOutput    when true, the Y coordinate is flipped so positions are
  *                           measured from the top of the page instead of PDF's
  *                           native bottom-left origin
  * @param normalizeLigatures when true, ligature glyphs in the collected mappings are
  *                           expanded via UnicodeUtils.normalizeLigatures
  * @return one PageMatrix of character MappingMatrix entries per requested page
  */
private def getCoordinates(
doc: => PDDocument,
startPage: Int,
endPage: Int,
normalizeOutput: Boolean = true,
normalizeLigatures: Boolean = true): Seq[PageMatrix] = {
import scala.collection.JavaConverters._
val unicodeUtils = new UnicodeUtils
Range(startPage, endPage + 1).map(pagenum => {
// Page height is needed to flip the Y axis when normalizeOutput is set.
val (_, pHeight) = getPageDims(pagenum, doc)
// PDFTextStripper page numbers are 1-based, hence the +1 on this 0-based index;
// a fresh stripper is used per page so only that page's positions are collected.
val stripper = new CustomStripper
stripper.setStartPage(pagenum + 1)
stripper.setEndPage(pagenum + 1)
stripper.getText(doc)
// Flatten the stripper's per-line TextPosition lists into a single sequence.
val line = stripper.lines.asScala.flatMap(_.textPositions.asScala)

val mappings = line.toArray.map(p => {
MappingMatrix(
p.toString,
p.getTextMatrix.getTranslateX,
// Flip Y from bottom-left origin to top-left when normalizing; the glyph
// height (getHeightDir) is subtracted so Y refers to the glyph's top edge.
if (normalizeOutput) pHeight - p.getTextMatrix.getTranslateY - p.getHeightDir
else p.getTextMatrix.getTranslateY,
p.getWidth,
p.getHeightDir,
0,
"pdf")
})

val coordinates =
if (normalizeLigatures) unicodeUtils.normalizeLigatures(mappings) else mappings
PageMatrix(coordinates)
})
}

/** Media-box dimensions of page `numPage` as a (width, height) pair. */
private def getPageDims(numPage: Int, document: PDDocument) = {
  val mediaBox = document.getPage(numPage).getMediaBox
  (mediaBox.getWidth, mediaBox.getHeight)
}
}

object PdfToText extends DefaultParamsReadable[PdfToText] {
Expand Down
51 changes: 23 additions & 28 deletions src/main/scala/com/johnsnowlabs/reader/SparkNLPReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ class SparkNLPReader(
.setOnlyPageNum(getOnlyPageNum)
.setTextStripper(getTextStripper)
.setSort(getSort)
.setExtractCoordinates(getExtractCoordinates)
.setNormalizeLigatures(getNormalizeLigatures)
val binaryPdfDF = spark.read.format("binaryFile").load(pdfPath)
val pipelineModel = new Pipeline()
.setStages(Array(pdfToText))
Expand All @@ -326,43 +328,36 @@ class SparkNLPReader(
}

private def getSplitPage: Boolean = {
val splitPage =
try {
params.asScala.getOrElse("splitPage", "true").toBoolean
} catch {
case _: IllegalArgumentException => true
}
splitPage
getDefaultBoolean(params.asScala.toMap, Seq("splitPage", "split_page"), default = true)
}

private def getOnlyPageNum: Boolean = {
val splitPage =
try {
params.asScala.getOrElse("onlyPageNum", "false").toBoolean
} catch {
case _: IllegalArgumentException => false
}
splitPage
getDefaultBoolean(params.asScala.toMap, Seq("onlyPageNum", "only_page_num"), default = false)
}

private def getTextStripper: String = {
val textStripper =
try {
params.asScala.getOrElse("textStripper", TextStripperType.PDF_TEXT_STRIPPER)
} catch {
case _: IllegalArgumentException => TextStripperType.PDF_TEXT_STRIPPER
}
textStripper
getDefaultString(
params.asScala.toMap,
Seq("textStripper", "text_stripper"),
default = TextStripperType.PDF_TEXT_STRIPPER)
}

private def getSort: Boolean = {
val sort =
try {
params.asScala.getOrElse("sort", "false").toBoolean
} catch {
case _: IllegalArgumentException => false
}
sort
getDefaultBoolean(params.asScala.toMap, Seq("sort"), default = false)
}

/** Reads the extract-coordinates flag from the reader params.
  * Accepts both camelCase and snake_case keys; defaults to false when absent.
  */
private def getExtractCoordinates: Boolean = {
  val paramMap = params.asScala.toMap
  getDefaultBoolean(paramMap, Seq("extractCoordinates", "extract_coordinates"), default = false)
}

/** Reads the normalize-ligatures flag from the reader params.
  * Accepts both camelCase and snake_case keys; defaults to true when absent.
  */
private def getNormalizeLigatures: Boolean = {
  val paramMap = params.asScala.toMap
  getDefaultBoolean(paramMap, Seq("normalizeLigatures", "normalize_ligatures"), default = true)
}

/** Instantiates class to read Excel files.
Expand Down
Loading