
Commit f925dba

Simplify and speed up open_mfdataset
Xarray's open_mfdataset is very slow for BOUT++ datasets. Other datasets also have issues (see e.g. pydata/xarray#1385), though the cause may not be the same. Opening the datasets individually and then concatenating them significantly speeds up this process.
1 parent 2d11a6f commit f925dba

File tree: 2 files changed, +97 -77 lines changed

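To make the change concrete, below is a minimal sketch of the pattern the commit message describes: open each file separately, then concatenate once, rather than relying on a single xr.open_mfdataset call. This is not xbout code (the file names are invented for illustration); the actual implementation added by this commit is in xbout/mfdataset.py further down.

import xarray as xr

# Hypothetical dump-file names, for illustration only
paths = ["BOUT.dmp.0.nc", "BOUT.dmp.1.nc", "BOUT.dmp.2.nc"]

# The slow route this commit moves away from:
# ds = xr.open_mfdataset(paths, combine="nested", concat_dim="x")

# The faster pattern: open each file, then concatenate in one call
datasets = [xr.open_dataset(p, chunks={}) for p in paths]
ds = xr.concat(
    datasets,
    "x",
    data_vars="minimal",  # only concatenate variables that contain "x"
    coords="minimal",
    compat="override",    # take everything else from the first dataset
    join="exact",
)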

xbout/load.py

Lines changed: 13 additions & 77 deletions
@@ -19,22 +19,6 @@
     _is_dir,
 )
 
-
-_BOUT_PER_PROC_VARIABLES = [
-    "wall_time",
-    "wtime",
-    "wtime_rhs",
-    "wtime_invert",
-    "wtime_comms",
-    "wtime_io",
-    "wtime_per_rhs",
-    "wtime_per_rhs_e",
-    "wtime_per_rhs_i",
-    "PE_XIND",
-    "PE_YIND",
-    "MYPE",
-]
-_BOUT_TIME_DEPENDENT_META_VARS = ["iteration", "hist_hi", "tt"]
 _BOUT_GEOMETRY_VARS = [
     "ixseps1",
     "ixseps2",
@@ -69,9 +53,6 @@
 )
 
 
-# TODO somehow check that we have access to the latest version of auto_combine
-
-
 def open_boutdataset(
     datapath="./BOUT.dmp.*.nc",
     inputfilepath=None,
@@ -295,15 +276,6 @@ def attrs_remove_section(obj, section):
     else:
         raise ValueError(f"internal error: unexpected input_type={input_type}")
 
-    if not is_restart:
-        for var in _BOUT_TIME_DEPENDENT_META_VARS:
-            if var in ds:
-                # Assume different processors in x & y have same iteration etc.
-                latest_top_left = {dim: 0 for dim in ds[var].dims}
-                if "t" in ds[var].dims:
-                    latest_top_left["t"] = -1
-                ds[var] = ds[var].isel(latest_top_left).squeeze(drop=True)
-
     ds, metadata = _separate_metadata(ds)
     # Store as ints because netCDF doesn't support bools, so we can't save
     # bool attributes
@@ -616,11 +588,6 @@ def _auto_open_mfboutdataset(
     if chunks is None:
         chunks = {}
 
-    if is_restart:
-        data_vars = "minimal"
-    else:
-        data_vars = _BOUT_TIME_DEPENDENT_META_VARS
-
     if _is_path(datapath):
         filepaths, filetype = _expand_filepaths(datapath)
 
@@ -640,6 +607,9 @@
         else:
             remove_yboundaries = False
 
+        # Create a partial application of _trim
+        # Calls to _preprocess will call _trim to trim guard / boundary cells
+        # from datasets before merging.
         _preprocess = partial(
             _trim,
             guards={"x": mxg, "y": myg},
@@ -651,40 +621,11 @@
 
         paths_grid, concat_dims = _arrange_for_concatenation(filepaths, nxpe, nype)
 
-        try:
-            ds = xr.open_mfdataset(
-                paths_grid,
-                concat_dim=concat_dims,
-                combine="nested",
-                data_vars=data_vars,
-                preprocess=_preprocess,
-                engine=filetype,
-                chunks=chunks,
-                join="exact",
-                **kwargs,
-            )
-        except ValueError as e:
-            message_to_catch = (
-                "some variables in data_vars are not data variables on the first "
-                "dataset:"
-            )
-            if str(e)[: len(message_to_catch)] == message_to_catch:
-                # Open concatenating any variables that are different in
-                # different files as a work around to support opening older
-                # data.
-                ds = xr.open_mfdataset(
-                    paths_grid,
-                    concat_dim=concat_dims,
-                    combine="nested",
-                    data_vars="different",
-                    preprocess=_preprocess,
-                    engine=filetype,
-                    chunks=chunks,
-                    join="exact",
-                    **kwargs,
-                )
-            else:
-                raise
+        # Call custom implementation of open_mfdataset
+        # avoiding some of the performance issues.
+        from .mfdataset import mfdataset
+
+        ds = mfdataset(paths_grid, concat_dim=concat_dims, preprocess=_preprocess)
     else:
         # datapath was nested list of Datasets
 
@@ -731,11 +672,6 @@ def _auto_open_mfboutdataset(
             combine_attrs="no_conflicts",
         )
 
-    if not is_restart:
-        # Remove any duplicate time values from concatenation
-        _, unique_indices = unique(ds["t_array"], return_index=True)
-        ds = ds.isel(t=unique_indices)
-
     return ds, remove_yboundaries
 
 
@@ -933,8 +869,10 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart):
     """
     Trims all guard (and optionally boundary) cells off a single dataset read from a
     single BOUT dump file, to prepare for concatenation.
-    Also drops some variables that store timing information, which are different for each
-    process and so cannot be concatenated.
+
+    Variables that store timing information, which are different for each
+    process, are not trimmed but are taken from the first processor during
+    concatenation.
 
     Parameters
     ----------
@@ -973,9 +911,7 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart):
     ):
         trimmed_ds = trimmed_ds.drop_vars(name)
 
-    to_drop = _BOUT_PER_PROC_VARIABLES
-
-    return trimmed_ds.drop_vars(to_drop, errors="ignore")
+    return trimmed_ds
 
 
 def _infer_contains_boundaries(ds, nxpe, nype):
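The docstring change above reflects the key design choice of this commit: instead of dropping per-processor timing variables in _trim, the new concatenation options take them from the first dataset. A small self-contained sketch (synthetic data, not xbout code) of how xr.concat behaves with the options used here:

import numpy as np
import xarray as xr

# Two per-processor datasets: "n" spans the concatenation dimension "x",
# while "wall_time" differs between processors and has no "x" dimension.
ds0 = xr.Dataset({"n": ("x", np.arange(3)), "wall_time": 1.0})
ds1 = xr.Dataset({"n": ("x", np.arange(3, 6)), "wall_time": 2.0})

combined = xr.concat(
    [ds0, ds1],
    "x",
    data_vars="minimal",  # only variables containing "x" are concatenated
    coords="minimal",
    compat="override",    # variables without "x" are taken from the first dataset
    join="exact",
)

print(combined["n"].values)          # [0 1 2 3 4 5]
print(float(combined["wall_time"]))  # 1.0, taken from ds0 rather than concatenated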

xbout/mfdataset.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# Custom implementation of xarray.open_mfdataset()
+
+import xarray as xr
+
+
+def concat_outer(dss, operation=None):
+    """
+    Concatenate nested lists along their outer dimension
+
+    # Example
+
+    >>> m = [[1,2,3],
+             [2,3,3],
+             [5,4,3]]
+
+    >>> concat_outer(m)
+
+    [[1, 2, 5], [2, 3, 4], [3, 3, 3]]
+
+    >>> concat_outer(m, operation=sum)
+
+    [8, 9, 9]
+
+    """
+    if not isinstance(dss[0], list):
+        # Input is a 1D list
+        if operation is not None:
+            return operation(dss)
+        return dss
+
+    # Two or more dimensions
+    # Swap first and second indices then concatenate inner
+    if len(dss[0]) == 1:
+        return concat_outer([dss[j][0] for j in range(len(dss))], operation=operation)
+
+    return [
+        concat_outer([dss[j][i] for j in range(len(dss))], operation=operation)
+        for i in range(len(dss[0]))
+    ]
+
+
+def mfdataset(paths, chunks=None, concat_dim=None, preprocess=None):
+    if chunks is None:
+        chunks = {}
+
+    if not isinstance(concat_dim, list):
+        concat_dim = [concat_dim]
+
+    if isinstance(paths, list):
+        # Read nested dataset
+
+        dss = [
+            mfdataset(
+                path, chunks=chunks, concat_dim=concat_dim[1:], preprocess=preprocess
+            )
+            for path in paths
+        ]
+
+        # The dimension to concatenate along
+        if concat_dim[0] is None:
+            # Not concatenating
+            if len(dss) == 1:
+                return dss[0]
+            return dss
+
+        # Concatenating along the top-level dimension
+        return concat_outer(
+            dss,
+            operation=lambda ds: xr.concat(
+                ds,
+                concat_dim[0],
+                data_vars="minimal",  # Only data variables in which the dimension already appears are concatenated
+                coords="minimal",  # Only coordinates in which the dimension already appears are concatenated
+                compat="override",  # Duplicate data taken from first dataset
+                join="exact",  # Don't align. Raise ValueError when indexes to be aligned are not equal
+                combine_attrs="override",  # Duplicate attributes taken from first dataset
+                create_index_for_new_dim=False,
+            ),
+        )
+    # A single path
+    ds = xr.open_dataset(paths, chunks=chunks)
+    if preprocess is not None:
+        ds = preprocess(ds)
+    return ds
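A hypothetical usage sketch of the new helper (file names and dimension names are invented; in xbout the nested path list and the concatenation dimensions come from _arrange_for_concatenation). The outer list level is concatenated along the first entry of concat_dim, the inner level along the second.

from xbout.mfdataset import mfdataset

# Invented 2x2 grid of dump files: outer list varies along "y", inner along "x"
paths_grid = [
    ["BOUT.dmp.0.nc", "BOUT.dmp.1.nc"],
    ["BOUT.dmp.2.nc", "BOUT.dmp.3.nc"],
]

# Each file is opened with xr.open_dataset and optionally preprocessed, then the
# rows are concatenated along "x" and the row results along "y".
ds = mfdataset(paths_grid, concat_dim=["y", "x"], preprocess=None)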
