Skip to content

Commit

Permalink
Merge pull request #61 from SnowEx/wet-snow-nans
Browse files Browse the repository at this point in the history
Wet snow nans
  • Loading branch information
ZachHoppinen authored Sep 5, 2023
2 parents d95cae1 + c81e390 commit 2f1dbaa
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 204 deletions.
186 changes: 12 additions & 174 deletions scripts/optimize/param_ridgeplots.ipynb

Large diffs are not rendered by default.

61 changes: 33 additions & 28 deletions spicy_snow/processing/wet_snow.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def flag_wet_snow(dataset: xr.Dataset, inplace: bool = False) -> Union[None, xr.
dataset['wet_snow'].loc[dict(time = ts)] = dataset.sel(time = ts)['wet_snow'] - dataset.sel(time = ts)['freeze_flag']
dataset['wet_snow'].loc[dict(time = ts)] = dataset.sel(time = ts)['wet_snow'].where(dataset.sel(time = ts)['wet_snow'] > 0, 0)

dataset['wet_snow'].loc[dict(time = ts)] = dataset.sel(time = ts)['wet_snow'].where(~dataset['s1'].sel(time = ts, band = 'VV').isnull(), np.nan)
prev_time = ts

# if >50% wet of last 4 cycles after feb 1 then set remainer till
Expand All @@ -178,36 +179,40 @@ def flag_wet_snow(dataset: xr.Dataset, inplace: bool = False) -> Union[None, xr.
melt_orbit = (melt_season & (dataset.relative_orbit == orbit))

# check if there are at least 4 time slices in melt season for this orbit
if len(dataset['perma_wet'].loc[dict(time = melt_orbit)]) > 4:

# first set all perma wet to match the flagged wet times from dB drop
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['wet_flag'].loc[dict(time = melt_orbit)]

# then set times that were made wet by negative snow index to wet as well
dataset['perma_wet'].loc[dict(time = melt_orbit)] = dataset['perma_wet'].loc[dict(time = melt_orbit)] + \
dataset['alt_wet_flag'].loc[dict(time = melt_orbit)]

# now if we are over 1 (ie it was flagged wet by dB and negative SI flags) we should floor those back to 1
if len(dataset['perma_wet'].loc[dict(time = melt_orbit)]) < 4:
continue

# first set all perma wet to match the flagged wet times from dB drop
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['wet_flag'].loc[dict(time = melt_orbit)]

# then set times that were made wet by negative snow index to wet as well
dataset['perma_wet'].loc[dict(time = melt_orbit)] = dataset['perma_wet'].loc[dict(time = melt_orbit)] + \
dataset['alt_wet_flag'].loc[dict(time = melt_orbit)]

# now if we are over 1 (ie it was flagged wet by dB and negative SI flags) we should floor those back to 1
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].where(dataset['perma_wet'].loc[dict(time = melt_orbit)] <= 1 , 1)

# now calculate the rolling mean of the perma wet so we have a % 0-1 of days out of 4 that were flagged
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = 4).mean()

# then propogate forward so that if we get to > 50% in a 4 image window we mask the remained of the melt season
# this will fail if bottleneck is installed due to it lacking the min_periods keyword
# see: https://github.com/pydata/xarray/issues/4922

if 'bottleneck' not in sys.modules:
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].where(dataset['perma_wet'].loc[dict(time = melt_orbit)] <= 1 , 1)
# now calculate the rolling mean of the perma wet so we have a % 0-1 of days out of 4 that were flagged
dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = len(orbit_dataset.time), min_periods = 1).max()
else:
log.info("bottleneck installed. Consider pip uninstalling and re-running if this fails.")
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = 4).mean()

# then propogate forward so that if we get to > 50% in a 4 image window we mask the remained of the melt season
# this will fail if bottleneck is installed due to it lacking the min_periods keyword
# see: https://github.com/pydata/xarray/issues/4922

if 'bottleneck' not in sys.modules:
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = len(orbit_dataset.time), min_periods = 1).max()
else:
log.info("bottleneck installed. Consider pip uninstalling and re-running if this fails.")
dataset['perma_wet'].loc[dict(time = melt_orbit)] = \
dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = len(orbit_dataset.time)).max()

dataset['perma_wet'].loc[dict(time = melt_orbit)].rolling(time = len(orbit_dataset.time)).max()


dataset['perma_wet'].loc[dict(time = melt_orbit)] = dataset.sel(dict(time = melt_orbit))['perma_wet'].where(~dataset['s1'].sel(dict(time = melt_orbit, band = 'VV')).isnull(), np.nan)

# if we have no data just set it to not be flagged perma_wet
dataset['perma_wet'] = dataset['perma_wet'].where(~dataset['perma_wet'].isnull(), 0)

Expand Down
43 changes: 41 additions & 2 deletions tests/test_wetsnow.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_id_wet_one_orbit(self):
deltaVV = np.random.randn(10, 10, 6) * 3
deltaCR = np.random.randn(10, 10, 6) * 3
ims = np.full((10, 10, 6), 4, dtype = int)
s1 = np.random.randn(10, 10, 6, 3)

# times = [np.datetime64(t) for t in ['2020-01-01','2020-01-02', '2020-01-07','2020-01-08', '2020-01-14', '2020-01-15']]
# [24,1,24,1,24,1]
Expand All @@ -203,6 +204,7 @@ def test_id_wet_one_orbit(self):
deltaVV = (["x", "y", "time"], deltaVV),
deltaCR = (["x", "y", "time"], deltaCR),
ims = (["x", "y", "time"], ims),
s1 = (["x", "y", "time", "band"], s1)
),
coords = dict(
lon = (["x", "y"], lon),
Expand Down Expand Up @@ -266,6 +268,23 @@ def test_id_wet_one_orbit(self):
ds['alt_wet_flag'].loc[dict(time = ds.time[2], x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = ds.time[2], x = xi, y = yi)] = 0

# 3,3 is nans all along
xi , yi = 3, 3
ds['wet_flag'].loc[dict(time = ds.time[0], x = xi, y = yi)] = 0
ds['alt_wet_flag'].loc[dict(time = ds.time[0], x = xi, y = yi)] = 1
ds['freeze_flag'].loc[dict(time = ds.time[0], x = xi, y = yi)] = 0

ds['wet_flag'].loc[dict(time = ds.time[1], x = xi, y = yi)] = 0
ds['alt_wet_flag'].loc[dict(time = ds.time[1], x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = ds.time[1], x = xi, y = yi)] = 1

ds['wet_flag'].loc[dict(time = ds.time[2], x = xi, y = yi)] = 1
ds['alt_wet_flag'].loc[dict(time = ds.time[2], x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = ds.time[2], x = xi, y = yi)] = 0

for t in ds.time:
ds['s1'].loc[dict(time = t, x = xi, y = yi, band = 'VV')] = np.nan

ds = flag_wet_snow(ds)

# 0,0 is wet @ t0, dry @ t1, dry @ t 2
Expand All @@ -286,6 +305,12 @@ def test_id_wet_one_orbit(self):
self.assertEqual( ds['wet_snow'].loc[dict(time = ds.time[1], x = xi, y = yi)], 0)
self.assertEqual( ds['wet_snow'].loc[dict(time = ds.time[2], x = xi, y = yi)], 1)

# 3, 3 is nans all along
xi , yi = 3, 3
self.assertTrue(np.isnan(ds['wet_snow'].loc[dict(time = ds.time[0], x = xi, y = yi)].values))
self.assertTrue(np.isnan(ds['wet_snow'].loc[dict(time = ds.time[1], x = xi, y = yi)].values))
self.assertTrue(np.isnan(ds['wet_snow'].loc[dict(time = ds.time[2], x = xi, y = yi)].values))

def test_id_wet_multiple_orbit(self):

fcf = np.random.randn(10, 10)/10 + 0.5
Expand All @@ -294,6 +319,8 @@ def test_id_wet_multiple_orbit(self):
deltaVV = np.random.randn(10, 10, n) + 0.1
deltaCR = np.random.randn(10, 10, n) + 0.1
ims = np.full((10, 10, n), 4, dtype = int)
s1 = np.random.randn(10, 10, n, 3)


x = np.linspace(0, 9, 10)
y = np.linspace(10, 19, 10)
Expand All @@ -306,6 +333,7 @@ def test_id_wet_multiple_orbit(self):
deltaVV = (["x", "y", "time"], deltaVV),
deltaCR = (["x", "y", "time"], deltaCR),
ims = (["x", "y", "time"], ims),
s1 = (["x", "y", "time", "band"], s1)
),
coords = dict(
lon = (["x", "y"], lon),
Expand Down Expand Up @@ -474,6 +502,7 @@ def test_perma_wet_seasons(self):
wet_flag = np.full((10, 10, n), 0.0)
alt_wet_flag = np.full((10, 10, n), 0.0)
freeze_flag = np.full((10, 10, n), 0.0)
s1 = np.random.randn(10, 10, n, 3)

x = np.linspace(0, 9, 10)
y = np.linspace(10, 19, 10)
Expand All @@ -486,11 +515,13 @@ def test_perma_wet_seasons(self):
wet_flag = (["x", "y", "time"], wet_flag),
alt_wet_flag = (["x", "y", "time"], alt_wet_flag),
freeze_flag = (["x", "y", "time"], freeze_flag),
s1 = (["x", "y", "time", "band"], s1)
),
coords = dict(
lon = (["x", "y"], lon),
lat = (["x", "y"], lat),
time = times,
band = ["VV", "VH", "inc"],
relative_orbit = (["time"], ros))
)

Expand All @@ -516,8 +547,6 @@ def test_perma_wet_seasons(self):
ds['alt_wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = t, x = xi, y = yi)] = 1



t = ds.time[11]
ds['wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 0
ds['alt_wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 1
Expand All @@ -533,6 +562,14 @@ def test_perma_wet_seasons(self):
ds['alt_wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = t, x = xi, y = yi)] = 1

# set one to nan and be sure it is nan for perma wet
t = ds.time[15]
xi , yi = 4 ,4
ds['wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 0
ds['alt_wet_flag'].loc[dict(time = t, x = xi, y = yi)] = 0
ds['freeze_flag'].loc[dict(time = t, x = xi, y = yi)] = 1
ds['s1'].loc[dict(time = t, x = xi, y = yi, band = 'VV')] = np.nan

ds = flag_wet_snow(ds)

# check no perma wet before february
Expand Down Expand Up @@ -568,6 +605,8 @@ def test_perma_wet_seasons(self):
self.assertEqual(ds['wet_snow'].sel(time = ds.time[3]).loc[dict(x = 3, y = 3)], 1)
self.assertEqual(ds['wet_snow'].sel(time = ds.time[5]).loc[dict(x = 3, y = 3)], 0)
self.assertEqual(ds['wet_snow'].sel(time = ds.time[7]).loc[dict(x = 3, y = 3)], 0)

self.assertTrue(np.isnan(ds['wet_snow'].loc[dict(time = ds.time[15], x = 4, y = 4)].values))

if __name__ == '__main__':
unittest.main()

0 comments on commit 2f1dbaa

Please sign in to comment.