Skip to content

Commit

Permalink
compute_quantiles added to .../weatherbench2/scripts/
Browse files Browse the repository at this point in the history
This rechunks to `WORKING_CHUNKS`, and computes quantiles in memory within these. An `xarray_beam.Quantiles` reducer would be more efficient, but it doesn't exist.

PiperOrigin-RevId: 684639846
  • Loading branch information
langmore authored and Weatherbench2 authors committed Oct 18, 2024
1 parent b720fea commit e148c6b
Show file tree
Hide file tree
Showing 2 changed files with 390 additions and 0 deletions.
253 changes: 253 additions & 0 deletions scripts/compute_quantiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
r"""Computes quantiles in a Dataset.

Example of getting quantiles of temperature by latitude, longitude, and level.
So we reduce over all other dims (in this case, "time").

Flags before the bare `--` are parsed by this script; flags after it are
passed through to the Beam runner.

```
export BUCKET=my-bucket
export PROJECT=my-project
python scripts/compute_quantiles.py \
  --input_path=gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_with_poles_conservative.zarr \
  --output_path=gs://$BUCKET/datasets/era5/$USER/temperature-quantiles.zarr \
  --quantiles=0.1,0.5,0.9 \
  --dim=time \
  --variables=temperature \
  --time_start="2000-01-01" \
  --time_stop="2000-12-31" \
  --working_chunks="latitude=4,longitude=4,level=1" \
  --runner=DataflowRunner \
  -- \
  --project=$PROJECT \
  --temp_location=gs://$BUCKET/tmp/ \
  --setup_file=./setup.py \
  --requirements_file=./scripts/dataflow-requirements.txt \
  --job_name=compute-quantiles-$USER
```
"""

import typing as t

from absl import app
from absl import flags
import apache_beam as beam
from weatherbench2 import flag_utils
import xarray as xr
import xarray_beam as xbeam


# Command-line flags. Each module-level name is an absl flag holder; read the
# parsed value through its `.value` attribute.
INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr input path')
OUTPUT_PATH = flags.DEFINE_string(
    'output_path',
    None,
    help='Path to output zarr',
)
QUANTILES = flags.DEFINE_list(
    'quantiles',
    None,
    help='Comma delimited list of quantiles, 0 <= q <= 1.',
)
# The default is None (not an empty list) so that marking this flag as
# required is actually enforced: absl only rejects a missing required flag
# when its default is None.
DIM = flags.DEFINE_list(
    name='dim',
    default=None,
    help='Comma delimited list of dimensions to reduce over.',
)
SKIPNA = flags.DEFINE_boolean(
    'skipna',
    False,
    help='Whether to skip NaN data points when computing quantiles.',
)

LEVELS = flags.DEFINE_list(
    'levels',
    None,
    help=(
        'Comma delimited list of pressure levels to compute quantiles on. If'
        ' empty, compute on all levels of --input_path'
    ),
)
TIME_DIM = flags.DEFINE_string(
    'time_dim',
    'time',
    help=(
        'Name for the time dimension to slice data on, if TIME_START or'
        ' TIME_STOP is provided.'
    ),
)
TIME_START = flags.DEFINE_string(
    'time_start',
    '2020-01-01',
    help='ISO 8601 timestamp (inclusive) of the first time to include',
)
TIME_STOP = flags.DEFINE_string(
    'time_stop',
    '2020-12-31',
    help='ISO 8601 timestamp (inclusive) of the last time to include',
)
VARIABLES = flags.DEFINE_list(
    'variables',
    None,
    help=(
        'Comma delimited list of data variables to include in output. '
        'If empty, compute on all data_vars of --input_path'
    ),
)

WORKING_CHUNKS = flag_utils.DEFINE_chunks(
    'working_chunks',
    '',
    help=(
        'If provided, rechunk to this when reducing. E.g. "time=1,timedelta=5".'
        ' Keys must be a subset of dimensions not being reduced over'
        ' (preserved dims). The in process memory size is the working chunk'
        ' size, and dims not preserved cannot be working chunks. So set this'
        ' carefully. For that reason, the default value for all preserved dims'
        ' is 1.'
    ),
)
OUTPUT_CHUNKS = flag_utils.DEFINE_chunks(
    'output_chunks',
    '',
    help=(
        'If provided, rechunk output to this after reducing. E.g.'
        ' "time=1,timedelta=1". By default, re-use the input dataset chunk'
        ' sizes.'
    ),
)
NUM_THREADS = flags.DEFINE_integer(
    'num_threads',
    None,
    help='Number of chunks to read/write in parallel per worker.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


def _get_preserve_dims(ds: xr.Dataset) -> set[t.Hashable]:
  """Returns the dims of `ds` that are not listed in the --dim flag.

  Args:
    ds: Dataset whose dimension names are inspected.

  Returns:
    The set of dimension names of `ds` that are not being reduced over.
  """
  reduce_dims = DIM.value
  return {name for name in ds.dims if name not in reduce_dims}


def _impose_data_selection(
    ds: xr.Dataset,
    chunks: t.Mapping[str, int],
) -> tuple[xr.Dataset, dict[str, int]]:
  """Selects the subset of data and chunks requested by FLAGS.

  Args:
    ds: Dataset to select from.
    chunks: Chunk sizes keyed by dimension name.

  Returns:
    A tuple `(ds, chunks)`: the dataset restricted to --variables, to the
    [--time_start, --time_stop] slice along --time_dim, and to --levels along
    'level'; and `chunks` restricted to the dims still present in that dataset.
  """
  # An unset *or empty* --variables means "all variables", as the flag help
  # states. Indexing with an empty list would instead produce a dataset with
  # no data variables at all.
  if VARIABLES.value:
    ds = ds[VARIABLES.value]
  selection = {
      TIME_DIM.value: slice(TIME_START.value, TIME_STOP.value),
  }
  if LEVELS.value:
    selection['level'] = [float(level) for level in LEVELS.value]
  # Only apply selections for dims the dataset actually has, so that e.g. a
  # dataset without 'level' is not rejected.
  ds = ds.sel({k: v for k, v in selection.items() if k in ds.dims})
  chunks = {k: v for k, v in chunks.items() if k in ds.dims}
  return ds, chunks


def evaluate_chunk(
    key: xbeam.Key, chunk: xr.Dataset
) -> tuple[xbeam.Key, xr.Dataset]:
  """Computes quantiles for one keyed chunk.

  Args:
    key: Key locating `chunk` within the full dataset.
    chunk: Data to compute quantiles on.

  Returns:
    A tuple `(new_key, new_chunk)`. `new_chunk` is the quantile result, and
    `new_key` is `key` with the offsets of every dim that is absent from
    `new_chunk` set to None.
  """
  reduced = _evaluate_chunk_core(chunk)
  # Offsets along dims that were reduced away no longer apply to the result.
  dropped_offsets = {}
  for dim_name in key.offsets:
    if dim_name not in reduced.dims:
      dropped_offsets[dim_name] = None
  return key.with_offsets(**dropped_offsets), reduced


def _evaluate_chunk_core(chunk: xr.Dataset) -> xr.Dataset:
  """Implementation of evaluate_chunk that doesn't use a key.

  Args:
    chunk: Data to compute quantiles on.

  Returns:
    `chunk.quantile(...)` evaluated at the --quantiles values, reduced over the
    --dim dimensions, honoring --skipna.

  Raises:
    ValueError: If a --dim entry is not a dimension of `chunk`, or if any
      --quantiles value lies outside [0, 1].
  """
  # Validate the reduce dims up front so that a typo in --dim produces an
  # error that names the flag and the dims that are actually available.
  missing_dims = [d for d in DIM.value if d not in chunk.dims]
  if missing_dims:
    raise ValueError(
        f'User specified {DIM.value=}, but {missing_dims} are not dimensions'
        f' of the data, which has dims {list(chunk.dims)}'
    )

  quantiles = [float(q) for q in QUANTILES.value]
  if any(q < 0 or q > 1 for q in quantiles):
    raise ValueError(
        f'Expected all quantiles to be in [0, 1]. Found {quantiles=}'
    )
  return chunk.quantile(quantiles, dim=DIM.value, skipna=SKIPNA.value)


def main(argv: list[str]) -> None:
  """Builds and runs the Beam pipeline that writes quantiles to zarr.

  Reads --input_path, applies the data selection flags, rechunks to the
  working chunks, computes quantiles within each working chunk, rechunks to
  the output chunks and writes the result to --output_path.

  Args:
    argv: Command-line arguments forwarded to `beam.Pipeline`.

  Raises:
    flags.IllegalFlagValueError: If --working_chunks names a dimension that is
      being reduced over.
  """
  source_ds, source_chunks = _impose_data_selection(
      *xbeam.open_zarr(INPUT_PATH.value)
  )

  preserve_dims = _get_preserve_dims(source_ds)

  if not set(WORKING_CHUNKS.value).issubset(preserve_dims):
    raise flags.IllegalFlagValueError(
        f'{WORKING_CHUNKS.value.keys()=} was not a subset of preserved dims'
        f' {preserve_dims}'
    )

  # Fill in working chunks the user did not specify: preserved dims default to
  # chunks of size 1, and reduced dims get -1 because the quantile needs every
  # value along them in a single chunk.
  working_chunks = WORKING_CHUNKS.value.copy()
  for k in set(source_chunks).difference(working_chunks):
    if k in preserve_dims:
      working_chunks[k] = 1
    else:
      working_chunks[k] = -1
  # Output chunks default to the input chunk sizes for every preserved dim.
  output_chunks = {
      k: OUTPUT_CHUNKS.value.get(k, source_chunks[k])
      for k in preserve_dims.intersection(source_chunks)
  }
  # 'quantile' is a new dim created by the reduction; keep it in one chunk
  # unless the user asked otherwise.
  output_chunks.setdefault('quantile', -1)

  # Make the template by evaluation (which reduces to produce a dataset with
  # correct output dims).
  template = _evaluate_chunk_core(xbeam.make_template(source_ds))

  output_chunks = {
      # The template may be smaller than output_chunks.
      k: min(output_chunks[k], template.sizes[k])
      for k in output_chunks
  }

  itemsize = max(var.dtype.itemsize for var in template.values())

  with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
    _ = (
        root
        | xbeam.DatasetToChunks(
            source_ds,
            source_chunks,
            split_vars=True,
            num_threads=NUM_THREADS.value,
        )
        # TODO(langmore) Write a xarray_beam quantile reducer to avoid this
        # rechunking.
        | 'RechunkToWorkingChunks'
        >> xbeam.Rechunk(  # pytype: disable=wrong-arg-types
            source_ds.sizes,
            source_chunks,
            working_chunks,
            itemsize=itemsize,
        )
        | 'ComputeQuantiles' >> beam.MapTuple(evaluate_chunk)
        | 'RechunkToOutputChunks'
        >> xbeam.Rechunk(  # pytype: disable=wrong-arg-types
            template.sizes,
            # Want to inject -1 for new dims
            {k: working_chunks.get(k, -1) for k in output_chunks},
            output_chunks,
            itemsize=itemsize,
        )
        | xbeam.ChunksToZarr(
            OUTPUT_PATH.value,
            template,
            output_chunks,
            num_threads=NUM_THREADS.value,
        )
    )


if __name__ == '__main__':
  # The script cannot run without these flags, so register each as required
  # before handing control to absl.
  # NOTE(review): absl only enforces "required" for flags whose default is
  # None - confirm every flag listed here is defined with a None default.
  for _required_flag in ('input_path', 'output_path', 'dim', 'quantiles'):
    flags.mark_flag_as_required(_required_flag)
  app.run(main)
137 changes: 137 additions & 0 deletions scripts/compute_quantiles_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import pandas as pd
import xarray as xr
import xarray_beam

from . import compute_quantiles


class ComputeQuantileTest(parameterized.TestCase):
  """End-to-end test of compute_quantiles.main on a small zarr dataset."""

  def _dict_to_str(self, a_dict):
    """Formats a dict as the 'k1=v1,k2=v2' string passed to the chunk flags."""
    pairs = [f'{key}={value}' for key, value in a_dict.items()]
    return ','.join(pairs)

  @parameterized.named_parameters(
      dict(
          testcase_name='NoChunks',
      ),
      dict(
          testcase_name='SpecifyInputWorkingChunks1',
          input_chunks={'time': 2},
          working_chunks={'time': 1, 'timedelta': 1},
      ),
      dict(
          testcase_name='SpecifyInputWorkingChunks2',
          input_chunks={'time': 2},
          working_chunks={'time': 1, 'timedelta': -1},
      ),
      dict(
          testcase_name='SpecifyInputOutputAndWorkingChunks1',
          input_chunks={'time': 2},
          output_chunks={'time': 2, 'timedelta': 3},
          working_chunks={'time': -1},
      ),
      dict(
          testcase_name='SpecifyInputOutputAndWorkingChunks2',
          input_chunks={'timedelta': 2},
          output_chunks={'time': 2},
          working_chunks={'timedelta': 1},
      ),
      dict(
          testcase_name='SpecifyInputOutputChunks1',
          input_chunks={'timedelta': 2},
          output_chunks={'time': 2},
      ),
      dict(
          testcase_name='SpecifyInputOutputChunks2',
          input_chunks={'timedelta': 2},
          output_chunks={'time': -1},
      ),
  )
  def test_basic_dataset(
      self,
      input_chunks=None,
      output_chunks=None,
      working_chunks=None,
  ):
    """Runs main() with 'lat' reduced and checks output values and chunks."""
    if not input_chunks:
      input_chunks = {}
    if not output_chunks:
      output_chunks = {}
    if not working_chunks:
      working_chunks = {}

    times = pd.DatetimeIndex(
        ['2023-01-01', '2023-01-02', '2023-01-03', '2023-01-04']
    )
    lats = np.arange(50)
    timedeltas = np.arange(6)

    # Vary the seed with the parameterization so that cases see different data.
    seed = 802701 + len(input_chunks) + len(output_chunks) + len(working_chunks)
    precip = np.random.RandomState(seed).rand(4, 50, 6)

    quantiles = [0.2, 0.8]

    dim_names = ['time', 'lat', 'timedelta']
    coord_values = [times, lats, timedeltas]
    input_ds = xr.Dataset(
        {
            'precip': xr.DataArray(
                precip, coords=coord_values, dims=dim_names
            ),
            # Not listed in --variables, so it must not appear in the output.
            'should_drop': xr.DataArray(
                precip * 2, coords=coord_values, dims=dim_names
            ),
        }
    )

    input_path = self.create_tempdir('source').full_path
    input_ds.chunk(input_chunks).to_zarr(input_path)

    # Get modified output
    output_path = self.create_tempdir('output').full_path
    quantiles_flag = ','.join(str(q) for q in quantiles)
    with flagsaver.as_parsed(
        input_path=input_path,
        output_path=output_path,
        working_chunks=self._dict_to_str(working_chunks),
        output_chunks=self._dict_to_str(output_chunks),
        dim='lat',
        variables='precip',
        time_start='2023-01-01',
        time_stop='2023-01-03',
        quantiles=quantiles_flag,
        runner='DirectRunner',
    ):
      compute_quantiles.main([])
    output, actual_output_chunks = xarray_beam.open_zarr(output_path)

    # Output only has the "preserved dims" + quantile
    self.assertCountEqual(output.dims, ['time', 'timedelta', 'quantile'])

    expected_output_chunks = {'quantile': -1}
    for dim_name in output.dims:
      dim_size = output.sizes[dim_name]
      if dim_name not in output_chunks:
        # No explicit request: the input chunk size, capped at the dim size.
        expected = min(input_chunks.get(dim_name, np.inf), dim_size)
      elif output_chunks[dim_name] == -1:
        # -1 requests a single chunk spanning the whole dim.
        expected = dim_size
      else:
        expected = output_chunks[dim_name]
      expected_output_chunks[dim_name] = expected
    self.assertEqual(expected_output_chunks, actual_output_chunks)

    expected_ds = (
        input_ds[['precip']]
        .isel(time=slice(3))
        .quantile(quantiles, dim='lat')
    )
    xr.testing.assert_equal(expected_ds, output)


if __name__ == '__main__':
  # Run the tests in this module when the file is executed as a script.
  absltest.main()

0 comments on commit e148c6b

Please sign in to comment.