Skip to content

Commit cf38b91

Browse files
jzxucopybara-github
authored andcommitted
Add ability to wrap building detection input images in WarpedVRT to transform them to a unified CRS, resolution, and origin.
PiperOrigin-RevId: 702068546
1 parent 4eda411 commit cf38b91

File tree

4 files changed

+319
-63
lines changed

4 files changed

+319
-63
lines changed

src/detect_buildings_main.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
from skai import detect_buildings
5555
from skai import extract_tiles
5656
from skai import read_raster
57-
from skai import utils
5857

5958
import tensorflow as tf
6059

@@ -138,15 +137,9 @@ def main(args):
138137
gdf = gpd.read_file(f)
139138
aoi = gdf.geometry.values[0]
140139
gdal_env = read_raster.parse_gdal_env(FLAGS.gdal_env)
141-
image_paths = utils.expand_file_patterns(FLAGS.image_paths)
142-
for image_path in image_paths:
143-
if not read_raster.raster_is_tiled(image_path):
144-
raise ValueError(f'Raster "{image_path}" is not tiled.')
145-
146-
vrt_paths = read_raster.build_vrts(
147-
image_paths, os.path.join(temp_dir, 'image'), 0.5, FLAGS.mosaic_images
140+
vrt_paths = read_raster.prepare_input_images(
141+
FLAGS.image_paths, os.path.join(temp_dir, 'vrts'), gdal_env
148142
)
149-
150143
tiles = []
151144
for path in vrt_paths:
152145
tiles.extend(

src/skai/extract_tiles.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,19 @@ def get_tiles_for_aoi(image_path: str,
203203
204204
Yields:
205205
A grid of tiles that covers the AOI.
206+
207+
Raises:
208+
RuntimeError: If the image file does not exist.
206209
"""
210+
if not rasterio.shutil.exists(image_path):
211+
raise RuntimeError(f'File {image_path} does not exist')
212+
207213
with rasterio.Env(**gdal_env):
208214
image = rasterio.open(image_path)
209-
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
210-
yield from get_tiles(
211-
image_path, x_min, y_min, x_max, y_max, tile_size, margin
212-
)
215+
x_min, y_min, x_max, y_max = _get_pixel_bounds_for_aoi(image, aoi)
216+
yield from get_tiles(
217+
image_path, x_min, y_min, x_max, y_max, tile_size, margin
218+
)
213219

214220

215221
class ExtractTilesAsExamplesFn(beam.DoFn):
@@ -227,7 +233,10 @@ def _get_raster(
227233
raster, rgb_bands = self._rasters.get(image_path, (None, None))
228234
if raster is None:
229235
with rasterio.Env(**self._gdal_env):
230-
raster = rasterio.open(image_path)
236+
try:
237+
raster = rasterio.open(image_path)
238+
except rasterio.errors.RasterioIOError as error:
239+
raise ValueError(f'Error opening raster {image_path}') from error
231240
rgb_bands = read_raster.get_rgb_indices(raster)
232241
self._rasters[image_path] = (raster, rgb_bands)
233242
return raster, rgb_bands

src/skai/read_raster.py

Lines changed: 141 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import dataclasses
1717
import functools
1818
import logging
19+
import math
1920
import os
2021
import re
2122
import shutil
@@ -26,10 +27,13 @@
2627

2728
import affine
2829
import apache_beam as beam
30+
import geopandas as gpd
2931
import numpy as np
32+
import pandas as pd
3033
import pyproj
3134
import rasterio
3235
import rasterio.plot
36+
import rasterio.shutil
3337
import rasterio.warp
3438
import rtree
3539
import shapely.geometry
@@ -762,6 +766,31 @@ def _run_gdalbuildvrt(
762766
extents: If not None, sets the extents of the VRT. Should by x_min, x_max,
763767
y_min, y_max.
764768
"""
769+
# First verify that all images have the same projections and number of bands.
770+
# VRTs do not support images with different projections and different numbers
771+
# of bands.
772+
# Input images with different resolutions are supported.
773+
raster = rasterio.open(image_paths[0])
774+
expected_crs = raster.crs
775+
expected_band_count = raster.count
776+
if expected_crs.units_factor[0] not in ('meter', 'metre'):
777+
# Requiring meters may be too strict but is simpler. If other linear units
778+
# such as feet are absolutely required, we can support them as well.
779+
raise ValueError(
780+
'The only supported linear unit is "meter", but found'
781+
f' {expected_crs.units_factor[0]}'
782+
)
783+
for path in image_paths[1:]:
784+
raster = rasterio.open(path)
785+
if raster.crs != expected_crs:
786+
raise ValueError(
787+
f'Expecting CRS {expected_crs}, got {raster.crs}'
788+
)
789+
if raster.count != expected_band_count:
790+
raise ValueError(
791+
f'Expecting {expected_band_count} bands, got {raster.count}'
792+
)
793+
765794
# GDAL doesn't recognize gs:// prefixes. Instead it wants /vsigs/ prefixes.
766795
gdal_image_paths = [
767796
p.replace('gs://', '/vsigs/') if p.startswith('gs://') else p
@@ -796,11 +825,78 @@ def _run_gdalbuildvrt(
796825
shutil.copyfileobj(source, dest)
797826

798827

828+
def _get_unified_warped_vrt_options(
829+
image_paths: list[str], resolution: float
830+
) -> dict[str, Any]:
831+
"""Gets options for a WarpedVRT that projects images into unified space.
832+
833+
Input images can have arbitrary boundaries, CRS, and resolution. The WarpedVRT
834+
will project them to the same boundaries, CRS, and resolution.
835+
836+
Args:
837+
image_paths: Input image paths.
838+
resolution: Desired output resolution.
839+
840+
Returns:
841+
Dictionary of WarpedVRT constructor options.
842+
"""
843+
image_bounds = []
844+
for image_path in image_paths:
845+
r = rasterio.open(image_path)
846+
image_bounds.append(
847+
gpd.GeoDataFrame(
848+
geometry=[shapely.geometry.box(*r.bounds)], crs=r.crs
849+
).to_crs('EPSG:4326')
850+
)
851+
combined = pd.concat(image_bounds)
852+
utm_crs = combined.estimate_utm_crs()
853+
left, bottom, right, top = combined.to_crs(
854+
utm_crs
855+
).geometry.unary_union.bounds
856+
width = int(math.ceil((right - left) / resolution))
857+
height = int(math.ceil((top - bottom) / resolution))
858+
transform = affine.Affine(resolution, 0.0, left, 0.0, -resolution, top)
859+
return {
860+
'resampling': rasterio.enums.Resampling.cubic,
861+
'crs': utm_crs,
862+
'transform': transform,
863+
'width': width,
864+
'height': height,
865+
}
866+
867+
868+
def _build_warped_vrt(
869+
image_path: str,
870+
vrt_path: str,
871+
vrt_options: dict[str, Any],
872+
gdal_env: dict[str, str],
873+
) -> None:
874+
"""Creates a WarpedVRT file from an image.
875+
876+
Args:
877+
image_path: Path to source image.
878+
vrt_path: VRT file output path.
879+
vrt_options: Options for VRT creation.
880+
gdal_env: GDAL environment configuration.
881+
"""
882+
with rasterio.Env(**gdal_env):
883+
raster = rasterio.open(image_path)
884+
with rasterio.vrt.WarpedVRT(raster, **vrt_options) as vrt:
885+
with tempfile.TemporaryDirectory() as temp_dir:
886+
temp_vrt_path = os.path.join(temp_dir, 'temp.vrt')
887+
rasterio.shutil.copy(vrt, temp_vrt_path, driver='VRT')
888+
with open(temp_vrt_path, 'rb') as source, tf.io.gfile.GFile(
889+
vrt_path, 'wb'
890+
) as dest:
891+
shutil.copyfileobj(source, dest)
892+
893+
799894
def build_vrts(
800895
image_paths: list[str],
801896
vrt_prefix: str,
802897
resolution: float,
803898
mosaic_images: bool,
899+
gdal_env: dict[str, str],
804900
) -> list[str]:
805901
"""Builds VRTs from a list of image paths.
806902
@@ -810,48 +906,61 @@ def build_vrts(
810906
resolution: VRT resolution in meters per pixel.
811907
mosaic_images: If true, build a single VRT containing all images. If false,
812908
build an individual VRT per input image.
909+
gdal_env: GDAL environment configuration.
813910
814911
Returns:
815912
A list of paths of the generated VRTs.
816913
"""
817-
# First verify that all images have the same projections and number of bands.
818-
# VRTs do not support images with different projections and different numbers
819-
# of bands.
820-
# Input images with different resolutions are supported.
821-
raster = rasterio.open(image_paths[0])
822-
expected_crs = raster.crs
823-
expected_band_count = raster.count
824-
x_bounds = [raster.bounds.left, raster.bounds.right]
825-
y_bounds = [raster.bounds.bottom, raster.bounds.top]
826-
if expected_crs.units_factor[0] not in ('meter', 'metre'):
827-
# Requiring meters may be too strict but is simpler. If other linear units
828-
# such as feet are absolutely required, we can support them as well.
829-
raise ValueError(
830-
'The only supported linear unit is "meter", but found'
831-
f' {expected_crs.units_factor[0]}'
832-
)
833-
for path in image_paths[1:]:
834-
raster = rasterio.open(path)
835-
if raster.crs != expected_crs:
836-
raise ValueError(
837-
f'Expecting CRS {expected_crs}, got {raster.crs}'
838-
)
839-
if raster.count != expected_band_count:
840-
raise ValueError(
841-
f'Expecting {expected_band_count} bands, got {raster.count}'
842-
)
843-
x_bounds.extend((raster.bounds.left, raster.bounds.right))
844-
y_bounds.extend((raster.bounds.bottom, raster.bounds.top))
845-
846-
extents = [min(x_bounds), min(y_bounds), max(x_bounds), max(y_bounds)]
847914
vrt_paths = []
848915
if mosaic_images:
849-
vrt_path = f'{vrt_prefix}-00000-of-00001.vrt'
916+
vrt_path = f'{vrt_prefix}.vrt'
850917
_run_gdalbuildvrt(image_paths, vrt_path, resolution, None)
851918
vrt_paths.append(vrt_path)
852919
else:
920+
warped_vrt_options = _get_unified_warped_vrt_options(
921+
image_paths, resolution
922+
)
853923
for i, image_path in enumerate(image_paths):
854924
vrt_path = f'{vrt_prefix}-{i:05d}-of-{len(image_paths):05d}.vrt'
855-
_run_gdalbuildvrt([image_path], vrt_path, resolution, extents)
925+
_build_warped_vrt(image_path, vrt_path, warped_vrt_options, gdal_env)
856926
vrt_paths.append(vrt_path)
857927
return vrt_paths
928+
929+
930+
def prepare_input_images(
931+
image_patterns: list[str], vrt_dir: str, gdal_env: dict[str, str]
932+
) -> list[str]:
933+
"""Unify image resolutions and CRS.
934+
935+
Args:
936+
image_patterns: Input image patterns.
937+
vrt_dir: Directory to store VRTs in.
938+
gdal_env: GDAL environment variables.
939+
940+
Returns:
941+
List of VRTs.
942+
943+
Raises:
944+
FileNotFoundError: If any of the image patterns does not match any files.
945+
"""
946+
wrapped_paths = []
947+
for i, pattern in enumerate(image_patterns):
948+
image_paths = utils.expand_file_patterns([pattern])
949+
if not image_paths:
950+
raise FileNotFoundError(f'{pattern} did not match any files.')
951+
for image_path in image_paths:
952+
if not raster_is_tiled(image_path):
953+
raise ValueError(f'Raster "{image_path}" is not tiled.')
954+
if len(image_paths) == 1:
955+
wrapped_paths.append(image_paths[0])
956+
else:
957+
mosaic_dir = os.path.join(vrt_dir, 'mosaics')
958+
if not tf.io.gfile.exists(mosaic_dir):
959+
tf.io.gfile.makedirs(mosaic_dir)
960+
vrt_prefix = os.path.join(vrt_dir, 'mosaics', f'mosaic-{i:05d}')
961+
wrapped_paths.extend(
962+
build_vrts(image_paths, vrt_prefix, 0.5, True, gdal_env)
963+
)
964+
return build_vrts(
965+
wrapped_paths, os.path.join(vrt_dir, 'input'), 0.5, False, gdal_env
966+
)

0 commit comments

Comments
 (0)