Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to wrap building detection input images in WarpedVRT to transform them to a unified CRS, resolution, and origin. #321

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions src/detect_buildings_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
from skai import detect_buildings
from skai import extract_tiles
from skai import read_raster
from skai import utils

import tensorflow as tf

Expand Down Expand Up @@ -138,15 +137,9 @@ def main(args):
gdf = gpd.read_file(f)
aoi = gdf.geometry.values[0]
gdal_env = read_raster.parse_gdal_env(FLAGS.gdal_env)
image_paths = utils.expand_file_patterns(FLAGS.image_paths)
for image_path in image_paths:
if not read_raster.raster_is_tiled(image_path):
raise ValueError(f'Raster "{image_path}" is not tiled.')

vrt_paths = read_raster.build_vrts(
image_paths, os.path.join(temp_dir, 'image'), 0.5, FLAGS.mosaic_images
vrt_paths = read_raster.prepare_building_detection_input_images(
FLAGS.image_paths, os.path.join(temp_dir, 'vrts'), gdal_env
)

tiles = []
for path in vrt_paths:
tiles.extend(
Expand Down
19 changes: 14 additions & 5 deletions src/skai/extract_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,19 @@ def get_tiles_for_aoi(image_path: str,

Yields:
A grid of tiles that covers the AOI.

Raises:
RuntimeError: If the image file does not exist.
"""
if not rasterio.shutil.exists(image_path):
raise RuntimeError(f'File {image_path} does not exist')

with rasterio.Env(**gdal_env):
image = rasterio.open(image_path)
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
yield from get_tiles(
image_path, x_min, y_min, x_max, y_max, tile_size, margin
)
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
yield from get_tiles(
image_path, x_min, y_min, x_max, y_max, tile_size, margin
)


class ExtractTilesAsExamplesFn(beam.DoFn):
Expand All @@ -227,7 +233,10 @@ def _get_raster(
raster, rgb_bands = self._rasters.get(image_path, (None, None))
if raster is None:
with rasterio.Env(**self._gdal_env):
raster = rasterio.open(image_path)
try:
raster = rasterio.open(image_path)
except rasterio.errors.RasterioIOError as error:
raise ValueError(f'Error opening raster {image_path}') from error
rgb_bands = read_raster.get_rgb_indices(raster)
self._rasters[image_path] = (raster, rgb_bands)
return raster, rgb_bands
Expand Down
180 changes: 148 additions & 32 deletions src/skai/read_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
import functools
import logging
import math
import os
import re
import shutil
Expand All @@ -26,10 +27,13 @@

import affine
import apache_beam as beam
import geopandas as gpd
import numpy as np
import pandas as pd
import pyproj
import rasterio
import rasterio.plot
import rasterio.shutil
import rasterio.warp
import rtree
import shapely.geometry
Expand Down Expand Up @@ -762,6 +766,31 @@ def _run_gdalbuildvrt(
extents: If not None, sets the extents of the VRT. Should by x_min, x_max,
y_min, y_max.
"""
# First verify that all images have the same projections and number of bands.
# VRTs do not support images with different projections and different numbers
# of bands.
# Input images with different resolutions are supported.
raster = rasterio.open(image_paths[0])
expected_crs = raster.crs
expected_band_count = raster.count
if expected_crs.units_factor[0] not in ('meter', 'metre'):
# Requiring meters may be too strict but is simpler. If other linear units
# such as feet are absolutely required, we can support them as well.
raise ValueError(
'The only supported linear unit is "meter", but found'
f' {expected_crs.units_factor[0]}'
)
for path in image_paths[1:]:
raster = rasterio.open(path)
if raster.crs != expected_crs:
raise ValueError(
f'Expecting CRS {expected_crs}, got {raster.crs}'
)
if raster.count != expected_band_count:
raise ValueError(
f'Expecting {expected_band_count} bands, got {raster.count}'
)

# GDAL doesn't recognize gs:// prefixes. Instead it wants /vsigs/ prefixes.
gdal_image_paths = [
p.replace('gs://', '/vsigs/') if p.startswith('gs://') else p
Expand Down Expand Up @@ -796,11 +825,78 @@ def _run_gdalbuildvrt(
shutil.copyfileobj(source, dest)


def _get_unified_warped_vrt_options(
image_paths: list[str], resolution: float
) -> dict[str, Any]:
"""Gets options for a WarpedVRT that projects images into unified space.

Input images can have arbitrary boundaries, CRS, and resolution. The WarpedVRT
will project them to the same boundaries, CRS, and resolution.

Args:
image_paths: Input image paths.
resolution: Desired output resolution.

Returns:
Dictionary of WarpedVRT constructor options.
"""
image_bounds = []
for image_path in image_paths:
r = rasterio.open(image_path)
image_bounds.append(
gpd.GeoDataFrame(
geometry=[shapely.geometry.box(*r.bounds)], crs=r.crs
).to_crs('EPSG:4326')
)
combined = pd.concat(image_bounds)
utm_crs = combined.estimate_utm_crs()
left, bottom, right, top = combined.to_crs(
utm_crs
).geometry.unary_union.bounds
width = int(math.ceil((right - left) / resolution))
height = int(math.ceil((top - bottom) / resolution))
transform = affine.Affine(resolution, 0.0, left, 0.0, -resolution, top)
return {
'resampling': rasterio.enums.Resampling.cubic,
'crs': utm_crs,
'transform': transform,
'width': width,
'height': height,
}


def _build_warped_vrt(
image_path: str,
vrt_path: str,
vrt_options: dict[str, Any],
gdal_env: dict[str, str],
) -> None:
"""Creates a WarpedVRT file from an image.

Args:
image_path: Path to source image.
vrt_path: VRT file output path.
vrt_options: Options for VRT creation.
gdal_env: GDAL environment configuration.
"""
with rasterio.Env(**gdal_env):
raster = rasterio.open(image_path)
with rasterio.vrt.WarpedVRT(raster, **vrt_options) as vrt:
with tempfile.TemporaryDirectory() as temp_dir:
temp_vrt_path = os.path.join(temp_dir, 'temp.vrt')
rasterio.shutil.copy(vrt, temp_vrt_path, driver='VRT')
with open(temp_vrt_path, 'rb') as source, tf.io.gfile.GFile(
vrt_path, 'wb'
) as dest:
shutil.copyfileobj(source, dest)


def build_vrts(
image_paths: list[str],
vrt_prefix: str,
resolution: float,
mosaic_images: bool,
gdal_env: dict[str, str],
) -> list[str]:
"""Builds VRTs from a list of image paths.

Expand All @@ -810,48 +906,68 @@ def build_vrts(
resolution: VRT resolution in meters per pixel.
mosaic_images: If true, build a single VRT containing all images. If false,
build an individual VRT per input image.
gdal_env: GDAL environment configuration.

Returns:
A list of paths of the generated VRTs.
"""
# First verify that all images have the same projections and number of bands.
# VRTs do not support images with different projections and different numbers
# of bands.
# Input images with different resolutions are supported.
raster = rasterio.open(image_paths[0])
expected_crs = raster.crs
expected_band_count = raster.count
x_bounds = [raster.bounds.left, raster.bounds.right]
y_bounds = [raster.bounds.bottom, raster.bounds.top]
if expected_crs.units_factor[0] not in ('meter', 'metre'):
# Requiring meters may be too strict but is simpler. If other linear units
# such as feet are absolutely required, we can support them as well.
raise ValueError(
'The only supported linear unit is "meter", but found'
f' {expected_crs.units_factor[0]}'
)
for path in image_paths[1:]:
raster = rasterio.open(path)
if raster.crs != expected_crs:
raise ValueError(
f'Expecting CRS {expected_crs}, got {raster.crs}'
)
if raster.count != expected_band_count:
raise ValueError(
f'Expecting {expected_band_count} bands, got {raster.count}'
)
x_bounds.extend((raster.bounds.left, raster.bounds.right))
y_bounds.extend((raster.bounds.bottom, raster.bounds.top))

extents = [min(x_bounds), min(y_bounds), max(x_bounds), max(y_bounds)]
vrt_paths = []
if mosaic_images:
vrt_path = f'{vrt_prefix}-00000-of-00001.vrt'
vrt_path = f'{vrt_prefix}.vrt'
_run_gdalbuildvrt(image_paths, vrt_path, resolution, None)
vrt_paths.append(vrt_path)
else:
warped_vrt_options = _get_unified_warped_vrt_options(
image_paths, resolution
)
for i, image_path in enumerate(image_paths):
vrt_path = f'{vrt_prefix}-{i:05d}-of-{len(image_paths):05d}.vrt'
_run_gdalbuildvrt([image_path], vrt_path, resolution, extents)
_build_warped_vrt(image_path, vrt_path, warped_vrt_options, gdal_env)
vrt_paths.append(vrt_path)
return vrt_paths


def prepare_building_detection_input_images(
image_patterns: list[str], vrt_dir: str, gdal_env: dict[str, str]
) -> list[str]:
"""Prepares input images for the building detection pipeline.

This function performs two operations:
1. For each image pattern that matches multiple files, the files are mosaic'ed
together by wrapping them in a regular VRT.
2. For all input images, including mosaic'ed images, this function wraps a
WarpedVRT around it to transform the image into the correct CRS and
resolution (0.5 meter).

Args:
image_patterns: Input image patterns.
vrt_dir: Directory to store VRTs in.
gdal_env: GDAL environment variables.

Returns:
List of VRTs.

Raises:
FileNotFoundError: If any of the image patterns does not match any files.
"""
wrapped_paths = []
for i, pattern in enumerate(image_patterns):
image_paths = utils.expand_file_patterns([pattern])
if not image_paths:
raise FileNotFoundError(f'{pattern} did not match any files.')
for image_path in image_paths:
if not raster_is_tiled(image_path):
raise ValueError(f'Raster "{image_path}" is not tiled.')
if len(image_paths) == 1:
wrapped_paths.append(image_paths[0])
else:
mosaic_dir = os.path.join(vrt_dir, 'mosaics')
if not tf.io.gfile.exists(mosaic_dir):
tf.io.gfile.makedirs(mosaic_dir)
vrt_prefix = os.path.join(vrt_dir, 'mosaics', f'mosaic-{i:05d}')
wrapped_paths.extend(
build_vrts(image_paths, vrt_prefix, 0.5, True, gdal_env)
)
return build_vrts(
wrapped_paths, os.path.join(vrt_dir, 'input'), 0.5, False, gdal_env
)
Loading
Loading