|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import pathlib |
| 7 | +from typing import Optional, Tuple, Union |
| 8 | +from osgeo import gdal |
| 9 | + |
| 10 | +from _exactextract import RasterSource |
| 11 | + |
| 12 | + |
| 13 | +class GDALRasterSource(RasterSource): |
| 14 | + def __init__(self, ds, band_idx: int = 1): |
| 15 | + super().__init__() |
| 16 | + self.ds = ds |
| 17 | + |
| 18 | + # Sanity check inputs |
| 19 | + if band_idx is not None and band_idx <= 0: |
| 20 | + raise ValueError("Raster band index starts from 1!") |
| 21 | + |
| 22 | + self.band = self.ds.GetRasterBand(band_idx) |
| 23 | + |
| 24 | + def res(self): |
| 25 | + gt = self.ds.GetGeoTransform() |
| 26 | + return gt[1], abs(gt[5]) |
| 27 | + |
| 28 | + def extent(self): |
| 29 | + gt = self.ds.GetGeoTransform() |
| 30 | + |
| 31 | + dx, dy = self.res() |
| 32 | + |
| 33 | + left = gt[0] |
| 34 | + right = left + dx * self.ds.RasterXSize |
| 35 | + top = gt[3] |
| 36 | + bottom = gt[3] - dy * self.ds.RasterYSize |
| 37 | + |
| 38 | + return (left, bottom, right, top) |
| 39 | + |
| 40 | + def read_window(self, x0, y0, nx, ny): |
| 41 | + return self.band.ReadAsArray(xoff=x0, yoff=y0, win_xsize=nx, win_ysize=ny) |
| 42 | + |
| 43 | + |
| 44 | +class NumPyRasterSource(RasterSource): |
| 45 | + def __init__(self, mat, xmin, ymin, xmax, ymax): |
| 46 | + super().__init__() |
| 47 | + self.mat = mat |
| 48 | + self.ext = (xmin, ymin, xmax, ymax) |
| 49 | + |
| 50 | + def res(self): |
| 51 | + ny, nx = self.mat.shape |
| 52 | + dy = (self.ext[3] - self.ext[1]) / ny |
| 53 | + dx = (self.ext[2] - self.ext[0]) / nx |
| 54 | + |
| 55 | + return (dx, dy) |
| 56 | + |
| 57 | + def extent(self): |
| 58 | + return self.ext |
| 59 | + |
| 60 | + def read_window(self, x0, y0, nx, ny): |
| 61 | + return self.mat[y0 : y0 + ny, x0 : x0 + ny] |
0 commit comments