66
77
88class Hbv_2 (torch .nn .Module ):
9- """HBV 2.0 ~ .
9+ """HBV 2.0.
1010
11- Multi-component, multiscale , differentiable PyTorch HBV model with rainfall
11+ Multi-component, multi-scale , differentiable PyTorch HBV model with rainfall
1212 runoff simulation on unit basins.
1313
1414 Authors
1515 -------
16- - Yalan Song, Leo Lonzarich
16+ - Yalan Song, Leo Lonzarich, Wencong Yang
1717 - (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
1818 - (HBV-light Version 2) Seibert, 2005
1919 (https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
@@ -48,14 +48,16 @@ def __init__(
4848 self .dynamic_params = []
4949 self .dy_drop = 0.0
5050 self .variables = ['prcp' , 'tmean' , 'pet' ]
51- self .routing = True
51+ self .routing = False
52+ self .lenF = 15
5253 self .comprout = False
54+ self .muwts = None
5355 self .nearzero = 1e-5
5456 self .nmul = 1
5557 self .cache_states = False
5658 self .device = device
5759
58- self .states , self ._states_cache = None , None
60+ self .states , self ._state_cache = None , None
5961
6062 self .state_names = [
6163 'SNOWPACK' , # Snowpack storage
@@ -124,7 +126,7 @@ def __init__(
124126 self .comprout = config .get ('comprout' , self .comprout )
125127 self .nearzero = config .get ('nearzero' , self .nearzero )
126128 self .nmul = config .get ('nmul' , self .nmul )
127- self .cache_states = config .get ('cache_states' , False )
129+ self .cache_states = config .get ('cache_states' , self . cache_states )
128130 self ._set_parameters ()
129131
130132 def _init_states (self , ngrid : int ) -> tuple [torch .Tensor ]:
@@ -145,7 +147,7 @@ def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
145147 tuple[torch.Tensor, ...]
146148 A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
147149 """
148- return self ._states_cache
150+ return self ._state_cache
149151
150152 def load_states (
151153 self ,
@@ -380,10 +382,10 @@ def forward(
380382 )
381383
382384 # State caching
383- self ._states_cache = [ s . detach () for s in states ]
385+ self ._state_cache = states
384386
385387 if self .cache_states :
386- self .states = self ._states_cache
388+ self .states = tuple ( s [ - 1 ]. detach () for s in self ._state_cache )
387389
388390 return fluxes
389391
@@ -398,6 +400,8 @@ def _PBM(
398400 ) -> Union [tuple , dict [str , torch .Tensor ]]:
399401 """Run through process-based model (PBM).
400402
403+ Flux outputs are in mm/day.
404+
401405 Parameters
402406 ----------
403407 forcing
@@ -449,6 +453,13 @@ def _PBM(
449453 SWE_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
450454 capillary_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
451455
456+ # NOTE: new for MTS -- Save model states for all time steps.
457+ SNOWPACK_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
458+ MELTWATER_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
459+ SM_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
460+ SUZ_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
461+ SLZ_sim = torch .zeros (Pm .size (), dtype = torch .float32 , device = self .device )
462+
452463 param_dict = {}
453464 for t in range (nsteps ):
454465 # Get dynamic parameter values per timestep.
@@ -541,6 +552,7 @@ def _PBM(
541552 Q2 = param_dict ['parK2' ] * SLZ
542553 SLZ = SLZ - Q2
543554
555+ # --- Outputs ---
544556 Qsimmu [t , :, :] = Q0 + Q1 + Q2
545557 Q0_sim [t , :, :] = Q0
546558 Q1_sim [t , :, :] = Q1
@@ -555,6 +567,13 @@ def _PBM(
555567 tosoil_sim [t , :, :] = tosoil
556568 PERC_sim [t , :, :] = PERC
557569
570+ # NOTE: new for MTS -- Save model states for all time steps.
571+ SNOWPACK_sim [t , :, :] = SNOWPACK
572+ MELTWATER_sim [t , :, :] = MELTWATER
573+ SM_sim [t , :, :] = SM
574+ SUZ_sim [t , :, :] = SUZ
575+ SLZ_sim [t , :, :] = SLZ
576+
558577 # Get the average or weighted average using learned weights.
559578 if self .muwts is None :
560579 Qsimavg = Qsimmu .mean (- 1 )
@@ -574,7 +593,7 @@ def _PBM(
574593 UH = uh_gamma (
575594 self .routing_param_dict ['route_a' ].repeat (nsteps , 1 ).unsqueeze (- 1 ),
576595 self .routing_param_dict ['route_b' ].repeat (nsteps , 1 ).unsqueeze (- 1 ),
577- lenF = 15 ,
596+ lenF = self . lenF ,
578597 )
579598 rf = torch .unsqueeze (Qsim , - 1 ).permute ([1 , 2 , 0 ]) # [gages,vars,time]
580599 UH = UH .permute ([1 , 2 , 0 ]) # [gages,vars,time]
@@ -603,11 +622,11 @@ def _PBM(
603622 Qs = torch .unsqueeze (Qsimavg , - 1 )
604623 Q0_rout = Q1_rout = Q2_rout = None
605624
606- states = (SNOWPACK , MELTWATER , SM , SUZ , SLZ )
625+ states = (SNOWPACK_sim , MELTWATER_sim , SM_sim , SUZ_sim , SLZ_sim )
607626
608627 if self .initialize :
609628 # If initialize is True, only return warmed-up storages.
610- return states
629+ return {}, states
611630 else :
612631 # Baseflow index (BFI) calculation
613632 BFI_sim = (
0 commit comments