-
Notifications
You must be signed in to change notification settings - Fork 0
/
gcm_filtering.py
90 lines (82 loc) · 2.75 KB
/
gcm_filtering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from xgcm import Grid
import pop_tools
import gcsfs
import fsspec as fs
import numpy as np
import xesmf as xe
import xarray as xr
import random
import matplotlib.pyplot as plt
import warnings
from xgcm import Grid
import gcm_filters
def filter_inputs(
da: xr.DataArray,
wet_mask: xr.DataArray,
dims: list,
filter_scale: int,
filter_type: str = "taper",
) -> xr.DataArray:
"""filters input using gcm-filters"""
if filter_type == "gaussian":
input_filter = gcm_filters.Filter(
filter_scale=filter_scale,
dx_min=1,
filter_shape=gcm_filters.FilterShape.GAUSSIAN,
grid_type=gcm_filters.GridType.REGULAR_WITH_LAND,
grid_vars={"wet_mask": wet_mask},
)
elif filter_type == "taper":
# transition_width = np.pi/2
transition_width = np.pi
input_filter = gcm_filters.Filter(
filter_scale=filter_scale,
transition_width=transition_width,
dx_min=1,
filter_shape=gcm_filters.FilterShape.TAPER,
grid_type=gcm_filters.GridType.REGULAR_WITH_LAND,
grid_vars={"wet_mask": wet_mask},
)
elif filter_type == "tripolar_pop":
input_filter = gcm_filters.Filter(
filter_scale=filter_scale,
dx_min=1,
filter_shape=gcm_filters.FilterShape.GAUSSIAN,
grid_type=gcm_filters.GridType.TRIPOLAR_REGULAR_WITH_LAND_AREA_WEIGHTED,
grid_vars={"area":da.TAREA,"wet_mask": wet_mask},
)
else:
raise ValueError(
f"`filter_type` needs to be `gaussian` or `taper', got {filter_type}"
)
out = input_filter.apply(da, dims=dims)
out.attrs["filter_scale"] = filter_scale
out.attrs["filter_type"] = filter_type
return out
def filter_inputs_dataset(
ds: xr.Dataset,
dims: list,
filter_scale: int,
timedim: str = "time",
filter_type: str = "gaussian", #taper
) -> xr.Dataset:
"""Wrapper that filters a whole dataset, generating a wet_mask from
the nanmask of the first timestep (if time is present)."""
ds_out = xr.Dataset()
# create a wet mask that only allows values which are 'wet' in all variables
wet_masks = []
for var in ds.data_vars:
da = ds[var]
if timedim in da.dims:
mask_da = da.isel({timedim: 0})
else:
mask_da = da
wet_masks.append((~np.isnan(mask_da)))
combined_wet_mask = xr.concat(wet_masks, dim="var").all("var").astype(int)
for var in ds.data_vars:
ds_out[var] = filter_inputs(
ds[var], combined_wet_mask, dims, filter_scale, filter_type=filter_type
)
ds_out.attrs["filter_scale"] = filter_scale
ds_out.attrs["filter_type"] = filter_type
return ds_out