-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo.py
430 lines (357 loc) · 26.5 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# This code was written as a brief demonstrator script of the rest of repository
# to explore forecasting capacities for storm surges at the Mediterranean sea
# in the context of the 3rd MedCyclones Workshop & Training School 2024.
# See https://nikal.eventsair.com/medcyclone-workshop-2024/.
# The script is not stand-alone and should be run in the context of the repository.
# For further references, be referred to https://github.com/PatrickESA/StormSurgeCastNet
# and the associated scientific publication by Ebel et al (2024).
import os
import sys
import glob
import argparse
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import shapely
import datetime
import rioxarray
import pandas as pd
# silence Weights & Biases logging for this demo run
os.environ['WANDB_DISABLED'] = 'true'
os.environ['WANDB_SILENT']="true"
# pin the cuBLAS workspace so CUDA matrix ops behave deterministically
os.environ['CUBLAS_WORKSPACE_CONFIG'] = '4096:8'
from utide import solve, reconstruct
from scipy.ndimage import uniform_filter1d
import torch
# share tensors between worker processes via the file system
# (avoids "too many open files" with the default descriptor-based strategy)
sharing_strategy = "file_system"
torch.multiprocessing.set_sharing_strategy(sharing_strategy)
import dask
# single-threaded dask scheduler: slower but reproducible and easy to debug
dask.config.set(scheduler='synchronous')
# make repository-local modules importable regardless of the working directory
dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(dirname)
from util import utils
from parse_args import create_parser
from train import seed_packages
from util.model_utils import get_model, load_checkpoint
from util.dataLoader import get_gtsm_database, get_lsm, get_era5, rasterize_gauges, rasterize_dense, roi_around, resize_box, gtsm_in_bounds
def parse_even_more_args(argv):
    """Parse the demo-specific command-line options.

    Parameters
    ----------
    argv : list of str
        Raw argument vector, typically ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Holds ``event``, ``model``, ``plot``, ``end_h`` and ``lead_h``.
    """
    demo_parser = argparse.ArgumentParser("Download global GTSM data")
    demo_parser.add_argument("--event", type=str, default='Zorbas', help="e.g. 'Zorbas' or 'Ianos'")
    demo_parser.add_argument("--model", type=str, default='utae', help="e.g. 'metnet3' or 'utae'")
    demo_parser.add_argument("--plot", dest="plot", action="store_true", help="whether to create some plots.")
    demo_parser.add_argument("--end_h", type=int, default=12, help="Input sequence length in hours.")
    demo_parser.add_argument("--lead_h", type=int, default=6, help="Lead time in hours.")
    return demo_parser.parse_args(argv)
# Build the repository's full training configuration, then overlay the
# demo-specific options on top so downstream code sees one merged namespace.
parser = create_parser(mode='train')
config = utils.str2list(parser.parse_args(), list_args=["encoder_widths", "decoder_widths", "out_conv"])
even_more_config = parse_even_more_args(sys.argv[1:])
# demo flags win over training defaults on name clashes
for key, val in vars(even_more_config).items(): vars(config)[key] = val
# seed all RNGs (numpy/torch/etc.) for reproducibility
seed_packages(config.rdm_seed)
# fetch tidal gauge data, see:
# map: https://www.ioc-sealevelmonitoring.org/map.php?code=kala
# Sealevel Station Catalogue: http://www.ioc-sealevelmonitoring.org/ssc/
# define the tidal gauges of interest for storm surge forecasting in the context of medicane Zorbas
# see: https://en.wikipedia.org/wiki/Cyclone_Zorbas
# "Collected data indicated that the surge generated by Zorbas reached a maximum value between about 0.8 m and 1.2 m above mean sea level (msl)
# along the coast of south-eastern Sicily", https://www.sciencedirect.com/science/article/abs/pii/S0025322721001389
# Reggio Calabria https://www.ioc-sealevelmonitoring.org/station.php?code=RC09
# rad: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=RC09&output=tab&period=7&endtime=2018-10-02
# Messina https://www.ioc-sealevelmonitoring.org/station.php?code=ME13
# rad: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=ME13&output=tab&period=7&endtime=2018-10-02
# Gokceada http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-gokc
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=gokc&output=tab&period=7&endtime=2018-10-02
# Kalamata http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kala
# pr1: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=kala&output=tab&period=7&endtime=2018-10-02
# Katakolo http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kata
# pr1: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=kata&output=tab&period=7&endtime=2018-10-02
# Otranto http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-otro
# rad: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=OT15&output=tab&period=7&endtime=2018-10-02
# Peiraias http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-peir
# pr1: https://www.ioc-sealevelmonitoring.org/bgraph.php?code=peir&output=tab&period=7&endtime=2018-10-02
# ---
# Catania http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-catn
# NO DATA --- data gap specifically for September 2018
# Zakynthos https://www.ioc-sealevelmonitoring.org/station.php?code=zkth
# NO DATA
# Koroni https://www.ioc-sealevelmonitoring.org/station.php?code=koro
# NO DATA
# Kapsali https://www.ioc-sealevelmonitoring.org/station.php?code=kaps
# NO DATA
# Kasteli https://www.ioc-sealevelmonitoring.org/station.php?code=kast
# NO DATA
# Paleochora https://www.ioc-sealevelmonitoring.org/station.php?code=pale
# NO DATA
# Kyparissia http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kypa
# NO DATA
# Thessaloniki http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-thes
# NO DATA
# define the tidal gauges of interest for storm surge forecasting in the context of medicane Ianos
# see: https://en.wikipedia.org/wiki/Cyclone_Ianos
# Catania http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-catn
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=CT03&output=tab&period=7&endtime=2020-09-21
# Gokceada http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-gokc
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=gokc&output=tab&period=7&endtime=2020-09-21
# Kalamata http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kala
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=kala&output=tab&period=7&endtime=2020-09-21
# Katakolo http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kata
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=kata&output=tab&period=7&endtime=2020-09-21
# Kyparissia http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-kypa
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=kypa&output=tab&period=7&endtime=2020-09-21
# Otranto http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-otro
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=OT15&output=tab&period=7&endtime=2020-09-21
# Peiraias http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-peir
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=peir&output=tab&period=7&endtime=2020-09-21
# Thessaloniki http://www.ioc-sealevelmonitoring.org/ssc/stationdetails.php?id=SSC-thes
# https://www.ioc-sealevelmonitoring.org/bgraph.php?code=thes&output=tab&period=7&endtime=2020-09-21
# ---
# Zakynthos https://www.ioc-sealevelmonitoring.org/station.php?code=zkth
# NO DATA
# Koroni https://www.ioc-sealevelmonitoring.org/station.php?code=koro
# NO DATA
# Kapsali https://www.ioc-sealevelmonitoring.org/station.php?code=kaps
# NO DATA
# Kasteli https://www.ioc-sealevelmonitoring.org/station.php?code=kast
# NO DATA
# Paleochora https://www.ioc-sealevelmonitoring.org/station.php?code=pale
# NO DATA
def fetch_data(event):
    """Return the tidal gauges of interest for a given medicane event.

    Parameters
    ----------
    event : str
        Name of the event, either 'Zorbas' or 'Ianos'.

    Returns
    -------
    dict
        Maps station name -> {'lon', 'lat', 'url'}, where 'url' points at the
        IOC sea level monitoring service for the event's 7-day time window.

    Raises
    ------
    KeyError
        If an unknown event name is requested.
    """
    # station name, longitude, latitude, IOC station code per event
    if event == 'Zorbas':
        endtime = '2018-10-02'
        stations = [
            ('Reggio Calabria', 15.648916666667, 38.121719444444, 'RC09'),
            ('Messina', 15.5635, 38.1963, 'ME13'),
            ('Gokceada', 25.893889, 40.2325, 'gokc'),
            ('Kalamata', 22.109833, 37.021533, 'kala'),
            ('Katakolo', 21.319233, 37.64045, 'kata'),
            ('Otranto', 18.4969, 40.1464, 'OT15'),
            ('Peiraias', 23.621217, 37.934733, 'peir'),
        ]
    elif event == 'Ianos':
        endtime = '2020-09-21'
        stations = [
            ('Catania', 15.0938, 37.498, 'CT03'),
            ('Gokceada', 25.893889, 40.2325, 'gokc'),
            ('Kalamata', 22.109833, 37.021533, 'kala'),
            ('Katakolo', 21.319233, 37.64045, 'kata'),
            ('Kyparissia', 21.66, 37.26, 'kypa'),
            ('Otranto', 18.4969, 40.1464, 'OT15'),
            ('Peiraias', 23.621217, 37.934733, 'peir'),
            ('Thessaloniki', 22.934933, 40.632542, 'thes'),
        ]
    else:
        raise KeyError('Unknown event requested')
    # all gauge record URLs follow the same IOC bgraph query pattern
    url_tpl = 'https://www.ioc-sealevelmonitoring.org/bgraph.php?code={}&output=tab&period=7&endtime={}'
    return {name: {'lon': lon, 'lat': lat, 'url': url_tpl.format(code, endtime)}
            for name, lon, lat, code in stations}
def process_gauges(out_dir, gauges):
    """Fetch IOC tide gauge records, remove the tidal signal, export as netCDF.

    Parameters
    ----------
    out_dir : str
        Directory to write one netCDF file per station into. If it already
        exists, all processing is skipped (data assumed analysis-ready).
    gauges : dict
        Mapping station name -> {'lon', 'lat', 'url'} as produced by
        fetch_data(). Mutated in place: 'dates' and 'height' entries are
        added, and 'height' is later overwritten with the non-tidal residual.
    """
    if os.path.exists(out_dir):
        print('Analysis-ready data already exists. Skipping data processing.')
    else:
        print('Fetching and pre-processing data.')
        # create directory to export pre-processed data to
        os.makedirs(out_dir, exist_ok=True)
        # fetch dates and sea level heights per tidal gauge
        for site, subdict in gauges.items():
            # first HTML table on the IOC page holds the record; row 0 is the header, hence [1:]
            df = pd.read_html(subdict['url'])[0]
            t = pd.to_datetime(df[0][1:]).values
            slh = df[1][1:].values.astype(float)
            gauges[site]['dates'] = t
            gauges[site]['height'] = slh
        # preprocess IOC gauges so they match GTSM dimension
        # - due to specifics & limitation of the data, this pipeline may slightly deviate from the usual one
        for site, subdict in gauges.items():
            # boolean mask of time steps with an actual measurement
            station_valid = ~ np.isnan(subdict['height'])
            # 1. detrend mean sea level,
            # see: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.detrend.html
            print('Detrending')
            # note: full pipeline detrends, but we only demean over short time periods
            detrend_sl = subdict['height'][station_valid] - np.mean(subdict['height'][station_valid])
            # 2. decompose time series into tide & non-tidal components
            # see: https://github.com/wesleybowman/UTide/blob/master/utide/_solve.py,
            print('Solving')
            coef = solve(
                subdict['dates'][station_valid], # Time in days since `epoch`, time coordinates
                detrend_sl, # Sea-surface height, velocity component, etc. (also see arg time_series_v), put gauge values here
                lat=subdict['lat'], # Latitude in degrees (GESLA latitudes are given in [-90, 90])
                nodal=False, # True (default) to include nodal/satellite corrections
                trend=True, # True (default) to include a linear trend in the model
                method="ols", # solvers: {'ols', 'robust'}
                conf_int="linear", # confidence intervaL calculation technique, required for reconst
                Rayleigh_min=1.0, # Minimum conventional Rayleigh criterion for automatic constituent selection; default is 1.
            )
            # see: https://github.com/wesleybowman/UTide/blob/master/utide/_reconstruct.py
            print('Reconstructing')
            tide = reconstruct(
                subdict['dates'][station_valid], # Time in days since `epoch`, e.g.: obs.index
                coef, # Data structure returned by `utide.solve`
                epoch=None,
                verbose=True,
                constit=None,
                min_SNR=2,
                min_PE=0,
            )
            # 2. remove tide component
            # Scalar time series is returned as `tide.h`
            tide_height = tide['h']
            non_tide = detrend_sl - tide_height
            # 3. filter the non-tidal residual
            # write the residual back into 'height'; impute gaps with the mean residual
            subdict['height'][station_valid] = non_tide
            subdict['height'][~station_valid] = non_tide.mean()
            # NOTE(review): `surge` (the smoothed residual) is computed but the
            # Dataset below exports the unsmoothed `non_tide` — apparently a
            # deliberate toggle, see the inline comment on `sea_level`.
            surge = uniform_filter1d(subdict['height'], 5, mode='reflect')[station_valid]
            xd = xr.Dataset(data_vars=dict(sea_level=(['station', 'date_time'], non_tide[None,:]), # use non_tide[None,:] or surge[None,:]
                                           longitude=(['station'], [subdict['lon']]),
                                           latitude =(['station'], [subdict['lat']])),
                            coords=dict(date_time=subdict['dates'][station_valid], station=[site]),
                            attrs=dict(description="Filtered and re-processed IOC data.")
                            )
            # downsample to hourly temporal resolution
            xd = xd.resample(date_time='H').mean()
            # export preprocessed data as netCDF files
            xd.to_netcdf(os.path.join(out_dir, f'{site}.nc'))
            print(f'Exported {site} as netCDF file.')
def get_sample(config, nIOC, site, sdx, coords, start_h, end_h, lead_h, ncGTSM, ncERA5, stats_data, mask_layer):
    """Assemble one model-ready sample (inputs + target) around a gauge site.

    Parameters
    ----------
    config : namespace
        Must provide at least `res`, `context`, `input_t` and `hyperlocal`.
    nIOC : xr.Dataset
        Preprocessed IOC gauge data with `station` and `date_time` dims.
    site : str
        Name of the target station.
    sdx : int
        Index of `site` within `nIOC.station`.
    coords : sequence of shapely points
        One point per station in `nIOC`, same ordering.
    start_h, end_h, lead_h :
        Either all ints (positional indices into `date_time`) or all
        np.datetime64 timestamps (label-based selection); mixing raises
        ValueError. In the datetime branch, `lead_h` is converted to the
        offset (end -> lead) before being returned in the sample.
    ncGTSM, ncERA5 :
        Coarse GTSM simulation / ERA5 weather datasets.
    stats_data : dict
        'mean'/'std' statistics for 'GESLA', 'GTSM' and ERA5 channels.
    mask_layer :
        rioxarray land-water raster used to derive the land-sea mask.

    Returns
    -------
    dict with 'input' and 'target' sub-dicts of z-standardized float32 arrays.
    """
    # get in-situ tide gauge data
    if all([isinstance(var, int) for var in [start_h, end_h, lead_h]]):
        data = nIOC.sel(station=site).isel(date_time=slice(start_h, end_h + lead_h))
        data_in = data.isel(date_time=slice(start_h, end_h))
    elif all([isinstance(var, np.datetime64) for var in [start_h, end_h, lead_h]]) :
        data = nIOC.sel(station=site).sel(date_time=slice(start_h, lead_h))
        data_in = data.sel(date_time=slice(start_h, end_h))
        lead_h = np.array((lead_h - end_h).astype('float32')) # assumed to be int now
    else: raise ValueError
    # the last time step of the full window is the prediction target
    data_out = data.isel(date_time=-1)
    start_time, end_time, target_time = data_in.date_time.values[0], data_in.date_time.values[-1], data_out.date_time.values
    # get polygon of ROI, centered around tidal gauge
    roi = roi_around([(data_out.longitude.values, data_out.latitude.values)], config.res, config.context, incl_center=True)[0]
    # get gauges
    # neighboring stations inside the ROI; the target station itself only
    # feeds the input in 'hyperlocal' mode
    in_idx = [idx for idx in range(len(nIOC.station.values)) if roi.contains(coords[idx]) and (idx!=sdx or config.hyperlocal)]
    subset_gauges_in = nIOC.isel(station=in_idx).sel(date_time=slice(start_time, end_time))
    subset_gauges_target = nIOC.sel(station=site).sel(date_time=target_time)
    # NOTE(review): hard-coded to a single True entry, so the "no target
    # gauges" branch below is currently unreachable — confirm intent
    target_bound = np.ones(1, bool)
    # get input gauges
    if len(in_idx) == 0: # input time series contains no gauges (might be due to dropout)
        print("Sample contains no GESLA tide gauges at input time series") # no worries, might be due to data dropout
        gauges_in = np.full((config.input_t, 2*config.context, 2*config.context), np.nan)
    else:
        # rasterize the sparse gauge readings onto the ROI grid, one frame per time step
        gauges_in = np.array([rasterize_gauges(roi, 2*config.context, subset_gauges_in.sel(date_time=gin), fill_empty=np.nan) for gin in subset_gauges_in.date_time], np.float32)
    # get target gauges
    if target_bound.sum() == 0:
        print("Sample contains no GESLA tide gauges at target time point")
        gauges_out = np.full((1, 2*config.context, 2*config.context), np.nan)
    else:
        gauges_out = np.array([rasterize_gauges(roi, 2*config.context, subset_gauges_target, fill_empty=np.nan)], np.float32)
    # get GTSM data
    coarse_in = ncGTSM.sel(time=slice(start_time, end_time))
    coarse_target = ncGTSM.sel(time=target_time)
    # enlarge the box before clipping so edge points survive rasterization
    lg_box = resize_box(roi, config.res, width=2*config.context)
    subset_coarse_in = gtsm_in_bounds(coarse_in.compute(), lg_box.bounds)
    subset_coarse_target = gtsm_in_bounds(coarse_target.compute(), lg_box.bounds)
    gtsm_in = rasterize_dense(subset_coarse_in, roi, 2*config.context)
    gtsm_out = rasterize_dense(subset_coarse_target, roi, 2*config.context)
    # get ERA5
    weather_in = ncERA5.sel(time=slice(start_time, end_time))
    era5_in = get_era5(weather_in, roi, 2*config.context).astype(np.float32)
    # translate lead times to temporal encoding
    # hours since 1979-01-01, matching the ERA5 archive epoch used here
    start_date = datetime.datetime(year=1979, month=1, day=1)
    td_in = np.array((pd.to_datetime(data_in.date_time)-start_date).total_seconds() // 3600, np.float32)
    # z-standardize ERA5 time series
    era5_mean = np.stack([stats_data['mean']['ERA5']['msl'], stats_data['mean']['ERA5']['u10'], stats_data['mean']['ERA5']['v10']])[None,:,None,None]
    era5_std = np.stack([stats_data['std']['ERA5']['msl'], stats_data['std']['ERA5']['u10'], stats_data['std']['ERA5']['v10']])[None,:,None,None]
    # standardize input gauges with GESLA statistics, compute mask of valid pixels, then impute un-filled entries with zeros
    gauges_in = (gauges_in-stats_data['mean']['GESLA'])/stats_data['std']['GESLA']
    nan_mask = np.isnan(gauges_in)
    gauges_in[nan_mask] = 0
    # apply land-sea mask to GTSM targets, but ensure that valid gauge sites remain unmasked
    # - LSM are needed for training but not at inference time
    mask = get_lsm(mask_layer, roi, 2*config.context + 1, pixel_size=config.res)
    mask = mask[:, :nan_mask.shape[-2], :nan_mask.shape[-1]]
    # nan_mask.min over time is True only where a pixel is NaN at every step,
    # so ~min flags pixels with at least one gauge observation
    mask[~nan_mask.min(axis=0,keepdims=True)] = False # set gauged pixels to sea within land-sea mask
    # keep an unmasked copy for plotting before NaN-ing out land pixels
    gtsm_unmasked = gtsm_out.copy()
    gtsm_out[mask] = np.nan
    return {'input': {'era5': ((era5_in-era5_mean)/era5_std).astype('float32'),
                      'sparse': gauges_in[:,None].astype('float32'),
                      'gtsm': ((gtsm_in-stats_data['mean']['GTSM'])/stats_data['std']['GTSM'])[:,None,...].astype('float32'),
                      'ls_mask': mask.astype('float32'),
                      'valid_mask': (1-nan_mask).astype('float32')[:,None,...].astype('float32'), # flag pixels with NaNs at any time point, flip mask
                      'td': td_in,
                      'td_lead': np.array(lead_h).astype('float32'),
                      'lon': roi.centroid.x,
                      'lat': roi.centroid.y
                      },
            'target': {'sparse': ((gauges_out-stats_data['mean']['GESLA'])/stats_data['std']['GESLA']).astype('float32'),
                       'gtsm': ((gtsm_out-stats_data['mean']['GTSM'])/stats_data['std']['GTSM']).astype('float32'),
                       'gtsm_unmasked': gtsm_unmasked,
                       'id': data_out.station.values,
                       'lon_gauge': data_out.longitude.values, # longitude of only target gauge
                       'lat_gauge': data_out.latitude.values # latitude of only target gauge
                       },
            }
def main(config):
    """Run the storm surge forecasting demo for one medicane event.

    Fetches and preprocesses IOC gauge data, loads ERA5/GTSM inputs and a
    pretrained model checkpoint, predicts surge levels at each gauge for
    event-specific target times, and reports MAE/MSE over all sites.

    Parameters
    ----------
    config : argparse.Namespace
        Merged training + demo configuration assembled at module level; must
        provide root, model, device, event, plot, end_h, lead_h, context,
        res, weight_folder, hyperlocal and input_t among others.

    Raises
    ------
    NotImplementedError : if config.hyperlocal is set.
    KeyError : for events lacking ERA5/GTSM simulations or time information.
    FileNotFoundError : if the statistics file or checkpoint cannot be loaded.
    """
    if config.model == 'metnet3':
        config.end_h = 12 # MetNet can't vary sequence length,
        config.input_t = 12 # fix to 12 h
    if config.hyperlocal: raise NotImplementedError
    # fetch & preprocess the event's tidal gauge data (cached on disk)
    gauges = fetch_data(config.event)
    out_dir = os.path.join(config.root, f'IOC_medsea_{config.event}')
    process_gauges(out_dir, gauges)
    if config.event == 'Ianos': raise KeyError('No ERA5 & GTSM simulations for this event.')
    # load ERA5 atmospheric forcing
    era5_path = os.path.join(config.root, 'ERA5', 'stormSurge_hourly_79_18', '')
    ncERA5 = xr.open_mfdataset(glob.glob(os.path.join(era5_path, '*.nc')))
    # get land-sea mask
    mask_path = os.path.join(config.root, 'aux', 'landWater2020_1000m.tif')
    mask_layer = rioxarray.open_rasterio(mask_path)
    # get mean and std statistics
    stats_file = os.path.join(config.root, 'stats.npy')
    stats_data = None if not os.path.isfile(stats_file) else np.load(stats_file, allow_pickle='TRUE').item()
    if stats_data is None: raise FileNotFoundError(stats_file)
    # get model and re-instantiate checkpoints
    device = torch.device(config.device)
    model = get_model(config)
    model = model.to(device)
    model.eval()
    chkp_path = os.path.join(os.path.expanduser('~'), 'Models', 'release_surgecastnet', f'{config.model}.pth.tar')
    try:
        load_checkpoint(config, config.weight_folder, model, "", chkp_path)
    except Exception as err:
        # chain the original failure instead of silently swallowing it
        raise FileNotFoundError(chkp_path) from err
    # do inference
    ioc_path = os.path.join(config.root, f'IOC_medsea_{config.event}', '')
    nIOC = xr.open_mfdataset(glob.glob(os.path.join(ioc_path, '*.nc')))
    coords = shapely.points(coords=[(nIOC.isel(date_time=0).longitude.values[idx], nIOC.isel(date_time=0).latitude.values[idx]) for idx in range(len(nIOC.station.values))])
    ncGTSM,_ = get_gtsm_database(os.path.expanduser(config.root))
    pred, targ = [], []
    for sdx, site in enumerate(nIOC.station.values): # ... for each roi
        # define indices of input times and lead time
        if config.event == 'Zorbas':
            # per-site target time points during Zorbas
            target = [np.datetime64('2018-09-26T13'), np.datetime64('2018-09-26T13'), np.datetime64('2018-09-26T13'),
                      np.datetime64('2018-09-29T07'), np.datetime64('2018-09-29T13'), np.datetime64('2018-09-29T18'),
                      np.datetime64('2018-10-01T11')]
            site_order = ['Reggio Calabria', 'Messina', 'Otranto', 'Katakolo', 'Kalamata', 'Peiraias', 'Gokceada']
            # input window [start_h, end_h] of length end_h ends lead_h hours before the target
            times = {name: {'start_h': t - np.timedelta64(config.lead_h) - np.timedelta64(config.end_h) + np.timedelta64(1),
                            'end_h': t - np.timedelta64(config.lead_h),
                            'lead_h': t}
                     for name, t in zip(site_order, target)}
        else: raise KeyError('No time information for this event.')
        sample = get_sample(config, nIOC, site, sdx, coords,
                            times[site]['start_h'], times[site]['end_h'], times[site]['lead_h'],
                            ncGTSM, ncERA5, stats_data, mask_layer)
        # stack sparse gauges, validity mask, ERA5 and GTSM along the channel dim
        inputs = {'A': torch.cat((torch.Tensor(sample['input']['sparse']), torch.Tensor(sample['input']['valid_mask']), torch.Tensor(sample['input']['era5']), torch.Tensor(sample['input']['gtsm'])), dim=1).unsqueeze(0),
                  'B': torch.cat((torch.Tensor(sample['target']['sparse']), torch.Tensor(sample['target']['gtsm'])), dim=0).unsqueeze(0),
                  'dates': torch.Tensor(sample['input']['td']).unsqueeze(0), 'masks': torch.Tensor(sample['input']['ls_mask']).unsqueeze(0), 'lead': torch.Tensor(sample['input']['td_lead']).unsqueeze(0)}
        with torch.no_grad():
            # compute predictions
            model.set_input(inputs)
            model.forward()
            out = model.fake_B
        # get predictions at gauge positions
        valid_mask = ~np.isnan(sample['target']['sparse'])[0, ...]
        valid_gauge = out[0, 0, 0, valid_mask]
        # un-standardize back to physical sea level using GESLA statistics
        valid_gauge_m = stats_data['std']['GESLA'] * valid_gauge + stats_data['mean']['GESLA']
        target_gauge = inputs['B'][0, 0, valid_mask]
        target_gauge_m = stats_data['std']['GESLA'] * target_gauge + stats_data['mean']['GESLA']
        # do some plotting
        if config.plot and site == 'Kalamata':
            plt.imsave(f'demo_{config.model}.png', out[0,0,-1, ...].cpu(), cmap='Blues')
            plt.imsave(f'demo_target.png', sample['target']['gtsm_unmasked'][0,...], cmap='Blues')
            plt.imsave(f'demo_mask.png', np.isnan(sample['target']['gtsm'][0,...]), cmap='binary')
        pred.append(valid_gauge_m.item())
        targ.append(target_gauge_m.item())
        print(f'{site} predictions {valid_gauge_m.item()}, lead time {int(inputs["lead"])} h')
        print(f'{site} gauge level {target_gauge_m.item()}')
    # aggregate residuals over all gauge sites
    res = np.array(pred) - np.array(targ)
    print(f"MAE: {np.mean(np.abs(res))}, MSE: {np.mean(res**2)}")
# Entry point: `config` is assembled at module level from the command line.
# The former trailing `exit()` call (the interactive `site` helper) was
# removed — the script terminates naturally once main() returns.
if __name__ == "__main__":
    main(config)