forked from pangeo-data/cog-best-practices
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththread_pool_executor.py
77 lines (56 loc) · 2.55 KB
/
thread_pool_executor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""thread_pool_executor.py
Operate on a raster dataset window-by-window using a ThreadPoolExecutor.
Simulates a CPU-bound thread situation where multiple threads can improve
performance.
With -j 4, the program finishes in about 1/4 of the time it takes with -j 1.
"""
import concurrent.futures
import multiprocessing
import rasterio
from rasterio._example import compute
def main(infile, outfile, num_workers=4):
    """Copy a raster to a new tiled file, processing it window by window.

    ``infile`` is opened for reading and ``outfile`` is created from the
    source profile, overridden to use 128x128 internal tiles. Each block
    window of the destination is read from the source, passed through
    ``compute()`` on a pool of ``num_workers`` threads, and the result is
    written to the same window of the destination.

    NOTE(review): the original notes say the output equals the input with
    band order reversed; that behavior lives in ``compute()``, which is
    imported from elsewhere -- confirm there.
    """
    with rasterio.Env(), rasterio.open(infile) as src:
        # Start from the source's creation options and force a tiled
        # layout so the destination exposes many block windows that can
        # be processed concurrently.
        out_profile = src.profile
        out_profile.update(blockxsize=128, blockysize=128, tiled=True)

        with rasterio.open(outfile, "w", **out_profile) as dst:
            # Keep the windows in a list: they are needed both to drive
            # the reads and to pair each result with its destination.
            block_windows = [win for _, win in dst.block_windows()]

            # Lazy source of per-window arrays; src.read() runs in
            # whichever thread advances this generator, and
            # Executor.map() consumes its input in the calling thread.
            blocks = (src.read(window=win) for win in block_windows)

            with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_workers
            ) as pool:
                # map() yields results in input order, so zipping with
                # the windows list gives matching (window, result) pairs.
                results = pool.map(compute, blocks)
                for win, processed in zip(block_windows, results):
                    dst.write(processed, window=win)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: two positional file names plus an
    # optional -j job count that defaults to the machine's CPU count.
    cli = argparse.ArgumentParser(description="Concurrent raster processing demo")
    cli.add_argument("input", metavar="INPUT", help="Input file name")
    cli.add_argument("output", metavar="OUTPUT", help="Output file name")
    cli.add_argument(
        "-j",
        metavar="NUM_JOBS",
        type=int,
        default=multiprocessing.cpu_count(),
        help="Number of concurrent jobs",
    )

    options = cli.parse_args()
    main(infile=options.input, outfile=options.output, num_workers=options.j)