Skip to content

Commit

Permalink
Merge pull request #246 from wri/gtc-2824a
Browse files Browse the repository at this point in the history
 GTC-2824 Use raster GADM layers for pro dashboard for small # of features
  • Loading branch information
danscales authored Aug 20, 2024
2 parents fc99e94 + d4d46f3 commit e791f15
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,72 @@ object GfwProDashboardAnalysis extends SummaryAnalysis {

val name = "gfwpro_dashboard"

/** Run the GFWPro dashboard analysis.
*
* If doGadmIntersect is true, read in the entire gadm feature dataset and
* intersect with the user feature list to determine the relevant gadm areas. If
* doGadmIntersect is false (usually for a small number of user features), then
* determine the relevant gadm areas by using the raster gadm datasets GadmAdm0,
* GadmAdm1, and GadmAdm2.
*/
def apply(
featureRDD: RDD[ValidatedLocation[Geometry]],
featureType: String,
contextualFeatureType: String,
contextualFeatureUrl: NonEmptyList[String],
doGadmIntersect: Boolean,
gadmFeatureUrl: NonEmptyList[String],
fireAlertRDD: SpatialRDD[Geometry],
spark: SparkSession,
kwargs: Map[String, Any]
): ValidatedWorkflow[Location[JobError],(FeatureId, GfwProDashboardData)] = {
featureRDD.persist(StorageLevel.MEMORY_AND_DISK)

val summaryRDD = ValidatedWorkflow(featureRDD).flatMap { rdd =>
val spatialContextualDF = SpatialFeatureDF(contextualFeatureUrl, contextualFeatureType, FeatureFilter.empty, "geom", spark)
val spatialContextualRDD = Adapter.toSpatialRdd(spatialContextualDF, "polyshape")
val spatialFeatureRDD = RDDAdapter.toSpatialRDDfromLocationRdd(rdd, spark)
val enrichedRDD = if (doGadmIntersect) {
println("Doing intersect with vector gadm")
val spatialContextualDF = SpatialFeatureDF(gadmFeatureUrl, "gadm", FeatureFilter.empty, "geom", spark)
val spatialContextualRDD = Adapter.toSpatialRdd(spatialContextualDF, "polyshape")
val spatialFeatureRDD = RDDAdapter.toSpatialRDDfromLocationRdd(rdd, spark)

/* Enrich the feature RDD by intersecting it with contextual features
* The resulting FeatureId carries combined identity of source feature and contextual geometry
*/
val enrichedRDD =
/* Enrich the feature RDD by intersecting it with contextual features
* The resulting FeatureId carries combined identity of source feature and contextual geometry
*/
SpatialJoinRDD
.flatSpatialJoin(spatialContextualRDD, spatialFeatureRDD, considerBoundaryIntersection = true, usingIndex = true)
.rdd
.flatMap { case (feature, context) =>
refineContextualIntersection(feature, context, contextualFeatureType)
refineContextualIntersection(feature, context, "gadm")
}
} else {
println("Using raster gadm")
rdd.map {
case Location(CombinedFeatureId(id@GfwProFeatureId(listId, locationId), featureCentroid: PointFeatureId), geom) => {
if (locationId != -1) {
// For a non-dissolved location, determine the GadmFeatureId for the
// centroid of the location's geometry, and add that to the feature id.
// This can be expensive, since the tile reads are not cached. So, we
// only use this raster GADM approach for user inputs with a small
// number of locations (e.g. <50). In that case, we get significant
// performance improvement by not having to read in the entire vector
// GADM file, but instead only reading the GADM raster tiles for the
// relevant areas.
val pt = featureCentroid.pt
val windowLayout = GfwProDashboardGrid.blockTileGrid
val key = windowLayout.mapTransform.keysForGeometry(pt).toList.head
val rasterSource = GfwProDashboardRDD.getSources(key, windowLayout, kwargs).getOrElse(null)
val raster = rasterSource.readWindow(key, windowLayout).getOrElse(null)
val re = raster.rasterExtent
val col = re.mapXToGrid(pt.getX())
val row = re.mapYToGrid(pt.getY())
Validated.valid[Location[JobError], Location[Geometry]](Location(CombinedFeatureId(id, GadmFeatureId(raster.tile.gadm0.getData(col, row),
raster.tile.gadm1.getData(col, row),
raster.tile.gadm2.getData(col, row))), geom))
} else {
// For a dissolved location, add a dummy GadmFeatureId to the feature id.
Validated.valid[Location[JobError], Location[Geometry]](Location(CombinedFeatureId(id, GadmFeatureId("X", 0, 0)), geom))
}
}
}
}

ValidatedWorkflow(enrichedRDD)
.mapValidToValidated { rdd =>
Expand All @@ -64,11 +104,32 @@ object GfwProDashboardAnalysis extends SummaryAnalysis {
.flatMap { enrichedRDD =>
val fireStatsRDD = fireStats(enrichedRDD, fireAlertRDD, spark)
val tmp = enrichedRDD.map { case Location(id, geom) => Feature(geom, id) }
val validatedSummaryStatsRdd = GfwProDashboardRDD(tmp, GfwProDashboardGrid.blockTileGrid, kwargs)
// This is where the main analysis happens, including calling
// GfwProDashboardSummary.getGridVisitor.visit on each pixel.
val validatedSummaryStatsRdd = GfwProDashboardRDD(tmp,
GfwProDashboardGrid.blockTileGrid,
kwargs + ("getRasterGadm" -> !doGadmIntersect))
ValidatedWorkflow(validatedSummaryStatsRdd).mapValid { summaryStatsRDD =>
// fold in fireStatsRDD after polygonal summary and accumulate the errors
summaryStatsRDD
.mapValues(_.toGfwProDashboardData())
.flatMap { case (CombinedFeatureId(fid@GfwProFeatureId(listId, locationId), gadmId), summary) =>
// For non-dissolved locations or vector gadm intersection, merge all
// summaries (ignoring any differing group_gadm_id), and move the
// gadmId from the featureId into the group_gadm_id. For dissolved
// locations for raster gadm, merge summaries into multiple rows
// based on the per-pixel group_gadm_id.
val ignoreRasterGadm = locationId != -1 || doGadmIntersect
summary.toGfwProDashboardData(ignoreRasterGadm).map( x => {
val newx = if (ignoreRasterGadm) {
x.copy(group_gadm_id = gadmId.toString)
} else {
x
}
Location(fid, newx)
}
)
case _ => throw new NotImplementedError("Missing case")
}
// fold in fireStatsRDD after polygonal summary and accumulate the errors
.leftOuterJoin(fireStatsRDD)
.mapValues { case (data, fire) =>
data.copy(viirs_alerts_daily = fire.getOrElse(GfwProDashboardDataDateCount.empty))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ import org.locationtech.jts.geom.Geometry

object GfwProDashboardCommand extends SummaryCommand {

val contextualFeatureUrlOpt: Opts[NonEmptyList[String]] = Opts
val gadmFeatureUrl: Opts[NonEmptyList[String]] = Opts
.options[String](
"contextual_feature_url",
help = "URI of contextual features in TSV format"
"gadm_feature_url",
help = "URI of GADM features in TSV format"
)

val contextualFeatureTypeOpt: Opts[String] = Opts
.option[String](
"contextual_feature_type",
help = "Type of contextual features"
)
/** Feature-count threshold that selects how GADM areas are determined.
  *
  * The command intersects with the vector GADM features only when the number of
  * input features is strictly greater than this value; at or below it, the
  * raster GADM layers are used instead. Defaults to 50 when the option is absent.
  */
val gadmIntersectThreshold: Opts[Int] = Opts
  .option[Int](
    "gadm_intersect_threshold",
    help = "Number of input features above which to intersect with vector GADM features; at or below it, raster GADM layers are used"
  ).withDefault(50)

val gfwProDashboardCommand: Opts[Unit] = Opts.subcommand(
name = GfwProDashboardAnalysis.name,
Expand All @@ -32,10 +32,10 @@ object GfwProDashboardCommand extends SummaryCommand {
defaultOptions,
optionalFireAlertOptions,
featureFilterOptions,
contextualFeatureUrlOpt,
contextualFeatureTypeOpt,
gadmFeatureUrl,
gadmIntersectThreshold,
pinnedVersionsOpts
).mapN { (default, fireAlert, filterOptions, contextualFeatureUrl, contextualFeatureType, pinned) =>
).mapN { (default, fireAlert, filterOptions, gadmFeatureUrl, gadmIntersectThreshold, pinned) =>
val kwargs = Map(
"outputUrl" -> default.outputUrl,
"noOutputPathSuffix" -> default.noOutputPathSuffix,
Expand All @@ -58,11 +58,19 @@ object GfwProDashboardCommand extends SummaryCommand {
spatialRDD
}

val featureCount = featureRDD.count()
val doGadmIntersect = featureCount > gadmIntersectThreshold
if (doGadmIntersect) {
println(s"Intersecting vector gadm for feature count $featureCount")
} else {
println(s"Using raster gadm for feature count $featureCount")
}

val dashRDD = GfwProDashboardAnalysis(
featureRDD,
default.featureType,
contextualFeatureType = contextualFeatureType,
contextualFeatureUrl = contextualFeatureUrl,
doGadmIntersect,
gadmFeatureUrl,
fireAlertRDD,
spark,
kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@ package org.globalforestwatch.summarystats.gfwpro_dashboard

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.globalforestwatch.features.{CombinedFeatureId, FeatureId, GadmFeatureId, GfwProFeatureId}
import org.globalforestwatch.features.{FeatureId, GfwProFeatureId}
import org.globalforestwatch.summarystats._
import cats.data.Validated.{Valid, Invalid}
import org.apache.spark.sql.functions.expr
import org.globalforestwatch.summarystats.SummaryDF.RowId

object GfwProDashboardDF extends SummaryDF {
case class RowGadmId(list_id: String, location_id: String, gadm_id: String)

def getFeatureDataFrameFromVerifiedRdd(
dataRDD: RDD[ValidatedLocation[GfwProDashboardData]],
spark: SparkSession
): DataFrame = {
import spark.implicits._

val rowId: FeatureId => RowGadmId = {
case CombinedFeatureId(proId: GfwProFeatureId, gadmId: GadmFeatureId) =>
RowGadmId(proId.listId, proId.locationId.toString, gadmId.toString())
val rowId: FeatureId => RowId = {
case proId: GfwProFeatureId =>
RowGadmId(proId.listId, proId.locationId.toString, "none")
RowId(proId.listId, proId.locationId.toString)
case _ =>
throw new IllegalArgumentException("Not a CombinedFeatureId[GfwProFeatureId, GadmFeatureId]")
}
Expand All @@ -30,28 +29,8 @@ object GfwProDashboardDF extends SummaryDF {
(rowId(id), SummaryDF.RowError.fromJobError(err), GfwProDashboardData.empty)
}
.toDF("id", "error", "data")
.select($"id.*", $"error.*", $"data.*")
}

def getFeatureDataFrame(
dataRDD: RDD[(FeatureId, ValidatedRow[GfwProDashboardData])],
spark: SparkSession
): DataFrame = {
import spark.implicits._

dataRDD.mapValues {
case Valid(data) =>
(SummaryDF.RowError.empty, data)
case Invalid(err) =>
(SummaryDF.RowError.fromJobError(err), GfwProDashboardData.empty)
}.map {
case (CombinedFeatureId(proId: GfwProFeatureId, gadmId: GadmFeatureId), (error, data)) =>
val rowId = RowGadmId(proId.listId, proId.locationId.toString, gadmId.toString())
(rowId, error, data)
case _ =>
throw new IllegalArgumentException("Not a CombinedFeatureId[GfwProFeatureId, GadmFeatureId]")
}
.toDF("id", "error", "data")
.select($"id.*", $"error.*", $"data.*")
// Put data.group_gadm_id right after list/location and rename to gadm_id
.select($"id.*", expr("data.group_gadm_id as gadm_id"), $"error.*", $"data.*")
.drop($"group_gadm_id")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* Note: This case class contains mutable values
*/
case class GfwProDashboardData(
// Relevant for dissolved locations (locationId == -1)
group_gadm_id: String,

/* NOTE: We are temporarily leaving the existing integrated alerts fields named as
* glad_alerts_*, in order to reduce the number of moving pieces as we move from
* Glad alerts to integrated alerts in GFWPro. */
Expand Down Expand Up @@ -48,6 +51,7 @@ case class GfwProDashboardData(

def merge(other: GfwProDashboardData): GfwProDashboardData = {
GfwProDashboardData(
if (group_gadm_id != "") group_gadm_id else other.group_gadm_id,
glad_alerts_coverage || other.glad_alerts_coverage,
integrated_alerts_coverage || other.integrated_alerts_coverage,
total_ha.merge(other.total_ha),
Expand All @@ -73,6 +77,7 @@ object GfwProDashboardData {

def empty: GfwProDashboardData =
GfwProDashboardData(
group_gadm_id = "",
glad_alerts_coverage = false,
integrated_alerts_coverage = false,
total_ha = ForestChangeDiagnosticDataDouble.empty,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.globalforestwatch.summarystats.gfwpro_dashboard

import cats.implicits._
import geotrellis.layer.{LayoutDefinition, SpatialKey}
import geotrellis.raster.Raster
import org.globalforestwatch.grids.{GridSources, GridTile}
Expand All @@ -14,30 +13,19 @@ case class GfwProDashboardGridSources(gridTile: GridTile, kwargs: Map[String, An
val treeCoverDensity2000 = TreeCoverDensityPercent2000(gridTile, kwargs)
val sbtnNaturalForest: SBTNNaturalForests = SBTNNaturalForests(gridTile, kwargs)
val jrcForestCover: JRCForestCover = JRCForestCover(gridTile, kwargs)
val gadmAdm0: GadmAdm0 = GadmAdm0(gridTile, kwargs)
val gadmAdm1: GadmAdm1 = GadmAdm1(gridTile, kwargs)
val gadmAdm2: GadmAdm2 = GadmAdm2(gridTile, kwargs)

def readWindow(
windowKey: SpatialKey,
windowLayout: LayoutDefinition
): Either[Throwable, Raster[GfwProDashboardTile]] = {

for {
// Integrated alerts are Optional Tiles, but we keep it this way to avoid signature changes
integratedAlertsTile <- Either
.catchNonFatal(integratedAlerts.fetchWindow(windowKey, windowLayout))
.right
tcd2000Tile <- Either
.catchNonFatal(treeCoverDensity2000.fetchWindow(windowKey, windowLayout))
.right
sbtnNaturalForestTile <- Either
.catchNonFatal(sbtnNaturalForest.fetchWindow(windowKey, windowLayout))
.right
jrcForestCoverTile <- Either
.catchNonFatal(jrcForestCover.fetchWindow(windowKey, windowLayout))
.right
} yield {
val tile = GfwProDashboardTile(integratedAlertsTile, tcd2000Tile, sbtnNaturalForestTile, jrcForestCoverTile)
Raster(tile, windowKey.extent(windowLayout))
}
val tile = GfwProDashboardTile(
windowKey, windowLayout, this
)
Right(Raster(tile, windowKey.extent(windowLayout)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import java.time.LocalDate


case class GfwProDashboardRawDataGroup(
groupGadmId: String,
alertDateAndConf: Option[(LocalDate, Int)],
integratedAlertsCoverage: Boolean,
isNaturalForest: Boolean,
Expand All @@ -20,6 +21,7 @@ case class GfwProDashboardRawDataGroup(
}

GfwProDashboardData(
group_gadm_id = groupGadmId,
glad_alerts_coverage = integratedAlertsCoverage,
integrated_alerts_coverage = integratedAlertsCoverage,
glad_alerts_daily = GfwProDashboardDataDateCount.fillDaily(alertDate, true, alertCount),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,23 @@ case class GfwProDashboardSummary(
def isEmpty = stats.isEmpty

/** Pivot raw data to GfwProDashboardData and aggregate across alert dates. */
def toGfwProDashboardData(): GfwProDashboardData = {
stats
.map { case (group, data) => group.
toGfwProDashboardData(data.alertCount, data.treeCoverExtentArea) }
.foldLeft(GfwProDashboardData.empty)( _ merge _)
/** Pivot the raw per-group stats into GfwProDashboardData rows.
  *
  * @param ignoreGadm when true, every group is merged into one row regardless of
  *                   its groupGadmId; when false, groups are merged separately
  *                   per groupGadmId.
  * @return a single-element list when ignoreGadm is true, otherwise one element
  *         for each distinct groupGadmId found in stats.
  */
def toGfwProDashboardData(ignoreGadm: Boolean): List[GfwProDashboardData] = {
  // Fold already-pivoted rows together, starting from the empty value.
  def mergeAll(rows: Iterable[GfwProDashboardData]): GfwProDashboardData =
    rows.foldLeft(GfwProDashboardData.empty)((acc, row) => acc.merge(row))

  if (!ignoreGadm) {
    // Keep groups apart by groupGadmId and emit one merged row per id.
    stats
      .groupBy { case (rawGroup, _) => rawGroup.groupGadmId }
      .map { case (_, perGadmStats) =>
        mergeAll(perGadmStats.map { case (rawGroup, rawData) =>
          rawGroup.toGfwProDashboardData(rawData.alertCount, rawData.treeCoverExtentArea)
        })
      }
      .toList
  } else {
    // Differing groupGadmIds are irrelevant here: collapse everything into one row.
    val pivoted = stats.map { case (rawGroup, rawData) =>
      rawGroup.toGfwProDashboardData(rawData.alertCount, rawData.treeCoverExtentArea)
    }
    List(mergeAll(pivoted))
  }
}
}

Expand All @@ -51,7 +63,21 @@ object GfwProDashboardSummary {
val naturalForestCategory: String = raster.tile.sbtnNaturalForest.getData(col, row)
val jrcForestCover: Boolean = raster.tile.jrcForestCover.getData(col, row)

val groupKey = GfwProDashboardRawDataGroup(integratedAlertDateAndConf,
val gadmId: String = if (kwargs("getRasterGadm") == true) {
val gadmAdm0: String = raster.tile.gadm0.getData(col, row)
// Skip processing this pixel if gadmAdm0 is empty
if (gadmAdm0 == "") {
return
}
val gadmAdm1: Integer = raster.tile.gadm1.getData(col, row)
val gadmAdm2: Integer = raster.tile.gadm2.getData(col, row)
s"$gadmAdm0.$gadmAdm1.$gadmAdm2"
} else {
""
}


val groupKey = GfwProDashboardRawDataGroup(gadmId, integratedAlertDateAndConf,
integratedAlertCoverage,
naturalForestCategory == "Natural Forest",
jrcForestCover,
Expand Down
Loading

0 comments on commit e791f15

Please sign in to comment.