-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocessing.py
197 lines (173 loc) · 7.95 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from xgcm import Grid
import pop_tools
import gcsfs
import fsspec as fs
import numpy as np
import xesmf as xe
import xarray as xr
import random
import matplotlib.pyplot as plt
import warnings
from xgcm import Grid
def load_and_combine_cm26(
    filesystem: gcsfs.GCSFileSystem,
    inline_array: bool = False,
    regridder_weights_path: str = "gs://leap-persistent/jbusecke/scale-aware-air-sea/regridding_weights/CM26_atmos2ocean.zarr",  # noqa: E501
) -> xr.Dataset:
    """Load, combine, and preprocess CM2.6 ocean and atmosphere surface data.

    Steps:
    - Open the ocean surface, daily atmosphere, and ocean grid zarr stores
    - Cut ocean and atmosphere to their common time steps
    - Interpolate ocean velocities onto ocean tracer points (with xgcm)
    - Regrid atmospheric variables to the ocean tracer grid (with xesmf),
      using precalculated weights
    - Merge the datasets and adjust units for aerobulk input
    - Mask atmospheric fields (and the tracer cell area) where SST is NaN
    - Calculate relative wind components

    Parameters
    ----------
    filesystem
        Object providing ``get_mapper(path)``; used to open every store.
    inline_array
        Forwarded to ``xr.open_dataset`` for every store.
    regridder_weights_path
        Location of the precalculated atmosphere-to-ocean regridding weights
        (zarr). Defaults to the path that was previously hard-coded.

    Returns
    -------
    xr.Dataset
        Merged dataset holding the regridded ``slp``, ``v_ref``, ``u_ref``,
        ``t_ref``, ``q_ref`` and ``wind``, the ocean ``surface_temp``,
        ``u_ocean`` and ``v_ocean``, the derived ``u_relative`` and
        ``v_relative``, and a masked ``area_t`` coordinate.
    """
    kwargs = dict(
        consolidated=True, use_cftime=True, inline_array=inline_array, engine="zarr"
    )

    print("Load Data")
    mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/control/surface")
    ds_ocean = xr.open_dataset(mapper, chunks={"time": 3}, **kwargs)
    # NOTE: the ocean_boundary store used to be opened here as well, but the
    # result was never assigned or used, so that remote read has been dropped.

    # Open with large chunks first and rechunk afterwards instead of using
    # xr.open_zarr with small chunks directly ("xarray says not to do this").
    mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr")
    ds_atmos = xr.open_dataset(mapper, chunks={"time": 120}, **kwargs).chunk(
        {"time": 3}
    )
    mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/grid")
    ds_oc_grid = xr.open_dataset(mapper, chunks={}, **kwargs)

    print("Align in time")
    # Cut both datasets to the same time steps. Every other dimension is
    # excluded from the alignment. The names are materialized into a list so
    # the argument can safely be iterated more than once.
    all_dims = set(ds_ocean.dims) | set(ds_atmos.dims)
    non_time_dims = [dim for dim in all_dims if dim != "time"]
    ds_ocean, ds_atmos = xr.align(
        ds_ocean,
        ds_atmos,
        join="inner",
        exclude=non_time_dims,
    )

    print("Interpolating ocean velocities")
    # Interpolate ocean velocities onto the tracer points using xgcm.
    # Add the xgcm comodo attrs so that Grid can infer the axes and shifts.
    ds_ocean["xu_ocean"].attrs["axis"] = "X"
    ds_ocean["xt_ocean"].attrs["axis"] = "X"
    ds_ocean["xu_ocean"].attrs["c_grid_axis_shift"] = 0.5
    ds_ocean["xt_ocean"].attrs["c_grid_axis_shift"] = 0.0
    ds_ocean["yu_ocean"].attrs["axis"] = "Y"
    ds_ocean["yt_ocean"].attrs["axis"] = "Y"
    ds_ocean["yu_ocean"].attrs["c_grid_axis_shift"] = 0.5
    ds_ocean["yt_ocean"].attrs["c_grid_axis_shift"] = 0.0
    grid = Grid(ds_ocean)

    # Fill missing velocities with 0, interpolate, then mask the result again
    # wherever SST is NaN.
    sst_wet_mask = ~np.isnan(ds_ocean["surface_temp"])
    # TODO: Maybe stencil out the nans from SST? This is done again in aerobulk-python
    ds_ocean["u_ocean"] = grid.interp_like(
        ds_ocean["usurf"].fillna(0), ds_ocean["surface_temp"]
    ).where(sst_wet_mask)
    ds_ocean["v_ocean"] = grid.interp_like(
        ds_ocean["vsurf"].fillna(0), ds_ocean["surface_temp"]
    ).where(sst_wet_mask)
    # FIXME: All of these additional operations (filling/masking) add lot more tasks...

    print("Regrid Atmospheric Data")
    # Regrid the atmosphere onto the ocean grid with precalculated weights.
    mapper = filesystem.get_mapper(regridder_weights_path)
    ds_regridder = xr.open_zarr(mapper).load()
    # The regridder is built from single-time-step "dummy" datasets with all
    # non-index coordinates dropped.
    regridder = xe.Regridder(
        ds_atmos.olr.to_dataset(name="dummy").isel(time=0).reset_coords(drop=True),
        ds_ocean.surface_temp.to_dataset(name="dummy")
        .isel(time=0)
        .reset_coords(drop=True),
        "bilinear",
        weights=ds_regridder,
        periodic=True,
    )
    # We are only doing noskin for now, so 'swdn_sfc' and 'lwdn_sfc' are left out.
    ds_atmos_regridded = regridder(
        ds_atmos[["slp", "v_ref", "u_ref", "t_ref", "q_ref", "wind"]]
    )

    # Combine into merged dataset
    ds_merged = xr.merge(
        [
            ds_atmos_regridded,
            ds_ocean[["surface_temp", "u_ocean", "v_ocean"]],
        ]
    )

    print("Modify units")
    # Fix units for aerobulk.
    # NOTE(review): presumably degC -> K and hPa -> Pa; the original author
    # flagged the slp factor with "check this" -- confirm against the data.
    ds_merged["surface_temp"] = ds_merged["surface_temp"] + 273.15
    ds_merged["slp"] = ds_merged["slp"] * 100  # check this

    print("Mask nans")
    # Atmos missing values are filled with 0s, which causes issues with the
    # filtering. Ideally this should be masked before the regridding, but
    # xesmf fills with 0 again...
    mask = ~np.isnan(ds_merged["surface_temp"].isel(time=0).reset_coords(drop=True))
    for mask_var in ["slp", "t_ref", "q_ref", "v_ref", "u_ref", "wind"]:
        ds_merged[mask_var] = ds_merged[mask_var].where(mask)

    # Also apply this mask to certain coordinates from the grid dataset (for
    # now only the tracer area), filling masked cells with 0.
    # The casting to float64 is needed to avoid that weird bug where the manual
    # global weighted average is not close to the xarray weighted mean (the
    # original author was not able to reproduce this with an example).
    for mask_coord in ["area_t"]:
        ds_merged.coords[mask_coord] = (
            ds_oc_grid[mask_coord].where(mask, 0.0).astype(np.float64)
        )

    # Second pass: mask these three fields again with the SST NaN mask of every
    # time step (the pass above only used the first time step).
    mask = ~np.isnan(ds_merged["surface_temp"])
    for mask_var in ["slp", "t_ref", "q_ref"]:
        ds_merged[mask_var] = ds_merged[mask_var].where(mask)

    print("Calculate relative wind")
    ds_merged["u_relative"] = ds_merged["u_ref"] - ds_merged["u_ocean"]
    ds_merged["v_relative"] = ds_merged["v_ref"] - ds_merged["v_ocean"]
    return ds_merged
def preprocess_data():
    """Return the merged CM2.6 dataset produced by ``load_and_combine_cm26``.

    Side effect: installs a process-wide ``"ignore"`` warnings filter, which
    stays active after this function returns.

    NOTE(review): this function previously also opened the ocean surface,
    ocean boundary, grid and atmosphere stores itself, interpolated the ocean
    velocities, renamed the atmosphere coordinates to ``lon``/``lat`` and
    converted units. None of those objects were passed on or returned; the
    return value came solely from ``load_and_combine_cm26``. That discarded
    work (and its duplicate progress prints) has been removed. If the
    ``lon``/``lat`` variant was meant to feed the merge, it needs to be wired
    into ``load_and_combine_cm26`` explicitly.

    Returns
    -------
    xr.Dataset
        The result of ``load_and_combine_cm26(fs, inline_array=True)``.
    """
    # Kept global (not a catch_warnings context) so that warnings stay
    # suppressed after the call returns, exactly as before.
    warnings.filterwarnings("ignore")

    # ``fs`` is the ``fsspec`` module itself; it is passed as the filesystem
    # because it provides the ``get_mapper(path)`` call that
    # ``load_and_combine_cm26`` relies on.
    ds_merged = load_and_combine_cm26(fs, inline_array=True)
    return ds_merged