Skip to content

Commit bc5fab6

Browse files
authored
ngen mts (#33)
* update hbv_2 for mts * update dhbv2 models for ngen mts * add mts hbv2 * update mts to make daily parameters optional * fix caching * updates to latest code release * cleanup * update gh actions * remove print statement * update license check * new license check * patch license agreement
1 parent 9263a69 commit bc5fab6

File tree

9 files changed

+1343
-25
lines changed

9 files changed

+1343
-25
lines changed

.github/workflows/lint.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ jobs:
1616

1717
steps:
1818
- name: Checkout code
19-
uses: actions/checkout@v5.0.0
19+
uses: actions/checkout@v6.0.1
2020

2121
- name: Install uv + Python
22-
uses: astral-sh/[email protected].2
22+
uses: astral-sh/[email protected].6
2323
with:
2424
python-version: ${{ matrix.python-version }}
2525
enable-cache: true

.github/workflows/pytest.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ jobs:
1616

1717
steps:
1818
- name: Checkout code
19-
uses: actions/checkout@v5.0.0
19+
uses: actions/checkout@v6.0.1
2020

2121
- name: Install uv + Python
22-
uses: astral-sh/[email protected].2
22+
uses: astral-sh/[email protected].6
2323
with:
2424
python-version: ${{ matrix.python-version }}
2525
enable-cache: true

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
- id: check-toml
1111

1212
- repo: https://github.com/astral-sh/ruff-pre-commit
13-
rev: v0.14.2
13+
rev: v0.14.8
1414
hooks:
1515
- id: ruff-check
1616
types_or: [python, pyi, jupyter]
@@ -20,7 +20,7 @@ repos:
2020
args: [--config, pyproject.toml]
2121

2222
- repo: https://github.com/kynan/nbstripout
23-
rev: 0.8.1
23+
rev: 0.8.2
2424
hooks:
2525
- id: nbstripout
2626
types: [jupyter]

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# HydroDL2: Differentiable Hydrological Models
22

3-
[![Python](https://img.shields.io/badge/python-3.9%20%7C%203.12%20%7C%203.13-blue)](https://www.python.org/downloads/)
4-
[![PyTorch](https://img.shields.io/badge/PyTorch-2.9.0-EE4C2C?logo=pytorch)](https://pytorch.org/)
3+
[![Python](https://img.shields.io/badge/python-3.9%20%7C%203.12%20%7C%203.13-blue?labelColor=333333)](https://www.python.org/downloads/)
4+
[![PyTorch Version](https://img.shields.io/badge/dynamic/json?label=PyTorch&query=info.version&url=https%3A%2F%2Fpypi.org%2Fpypi%2Ftorch%2Fjson&logo=pytorch&color=EE4C2C&logoColor=F900FF&labelColor=333333)](https://pypi.org/project/torch/)
55

6-
[![Build](https://github.com/mhpi/hydrodl2/actions/workflows/pytest.yaml/badge.svg?branch=master)](https://github.com/mhpi/hydrodl2/actions/workflows/pytest.yaml/)
7-
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
6+
[![Build](https://img.shields.io/github/actions/workflow/status/mhpi/generic_deltamodel/pytest.yaml?branch=master&logo=github&label=tests&labelColor=333333)](https://github.com/mhpi/generic_deltamodel/actions/workflows/pytest.yaml)
7+
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json&labelColor=333333)](https://github.com/astral-sh/ruff)
88

99
---
1010

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ license = { file = "LICENSE" }
1212
authors = [
1313
{ name = "Leo Lonzarich", email = "[email protected]" },
1414
{ name = "Yalan Song", email = "[email protected]" },
15+
{ name = "Wencong Yang", email = "[email protected]" },
1516
{ name = "Tadd Bindas", email = "[email protected]" },
1617
]
1718
maintainers = [

src/hydrodl2/__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
# src/hydrodl2/__init__.py
12
import logging
23
import os
4+
import sys
35
from datetime import datetime
46
from importlib.resources import files
57
from pathlib import Path
@@ -69,7 +71,13 @@ def _check_license_agreement():
6971

7072
print("-" * 40)
7173

72-
response = input("Do you agree to these terms? Type 'Yes' to continue: ")
74+
try:
75+
response = input("Do you agree to these terms? Type 'Yes' to continue: ")
76+
except EOFError:
77+
# If we get here, it means we're in an environment that can't take
78+
# input -- default to no agreement.
79+
print("\n[!] No terminal detected. Skipping license prompt.")
80+
return
7381

7482
if response.strip().lower() in ['yes', 'y']:
7583
try:
@@ -92,7 +100,23 @@ def _check_license_agreement():
92100
raise SystemExit(1)
93101

94102

103+
def _should_skip_license():
104+
"""Returns True if the license check should be bypassed."""
105+
# 1. Check if we are in a Non-Interactive shell (No user to type 'Yes')
106+
if not sys.stdin.isatty():
107+
return True
108+
109+
# 2. Check for common 'Silent' or 'CI' flags
110+
if os.environ.get('CI') or os.environ.get('NGEN_SILENT'):
111+
return True
112+
113+
# 3. Check for the Docker flag (as a fallback)
114+
if os.path.exists('/.dockerenv'):
115+
return True
116+
117+
return False
118+
119+
95120
# This only runs once when package is first imported.
96-
if not os.environ.get('CI'):
97-
# Skip license check in CI envs (e.g., GitHub Actions)
121+
if not _should_skip_license():
98122
_check_license_agreement()

src/hydrodl2/models/hbv/hbv_2.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77

88
class Hbv_2(torch.nn.Module):
9-
"""HBV 2.0 ~.
9+
"""HBV 2.0.
1010
11-
Multi-component, multiscale, differentiable PyTorch HBV model with rainfall
11+
Multi-component, multi-scale, differentiable PyTorch HBV model with rainfall
1212
runoff simulation on unit basins.
1313
1414
Authors
1515
-------
16-
- Yalan Song, Leo Lonzarich
16+
- Yalan Song, Leo Lonzarich, Wencong Yang
1717
- (Original NumPy HBV ver.) Beck et al., 2020 (http://www.gloh2o.org/hbv/).
1818
- (HBV-light Version 2) Seibert, 2005
1919
(https://www.geo.uzh.ch/dam/jcr:c8afa73c-ac90-478e-a8c7-929eed7b1b62/HBV_manual_2005.pdf).
@@ -48,14 +48,16 @@ def __init__(
4848
self.dynamic_params = []
4949
self.dy_drop = 0.0
5050
self.variables = ['prcp', 'tmean', 'pet']
51-
self.routing = True
51+
self.routing = False
52+
self.lenF = 15
5253
self.comprout = False
54+
self.muwts = None
5355
self.nearzero = 1e-5
5456
self.nmul = 1
5557
self.cache_states = False
5658
self.device = device
5759

58-
self.states, self._states_cache = None, None
60+
self.states, self._state_cache = None, None
5961

6062
self.state_names = [
6163
'SNOWPACK', # Snowpack storage
@@ -124,7 +126,7 @@ def __init__(
124126
self.comprout = config.get('comprout', self.comprout)
125127
self.nearzero = config.get('nearzero', self.nearzero)
126128
self.nmul = config.get('nmul', self.nmul)
127-
self.cache_states = config.get('cache_states', False)
129+
self.cache_states = config.get('cache_states', self.cache_states)
128130
self._set_parameters()
129131

130132
def _init_states(self, ngrid: int) -> tuple[torch.Tensor]:
@@ -145,7 +147,7 @@ def get_states(self) -> Optional[tuple[torch.Tensor, ...]]:
145147
tuple[torch.Tensor, ...]
146148
A tuple containing the states (SNOWPACK, MELTWATER, SM, SUZ, SLZ).
147149
"""
148-
return self._states_cache
150+
return self._state_cache
149151

150152
def load_states(
151153
self,
@@ -380,10 +382,10 @@ def forward(
380382
)
381383

382384
# State caching
383-
self._states_cache = [s.detach() for s in states]
385+
self._state_cache = states
384386

385387
if self.cache_states:
386-
self.states = self._states_cache
388+
self.states = tuple(s[-1].detach() for s in self._state_cache)
387389

388390
return fluxes
389391

@@ -398,6 +400,8 @@ def _PBM(
398400
) -> Union[tuple, dict[str, torch.Tensor]]:
399401
"""Run through process-based model (PBM).
400402
403+
Flux outputs are in mm/day.
404+
401405
Parameters
402406
----------
403407
forcing
@@ -449,6 +453,13 @@ def _PBM(
449453
SWE_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
450454
capillary_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
451455

456+
# NOTE: new for MTS -- Save model states for all time steps.
457+
SNOWPACK_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
458+
MELTWATER_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
459+
SM_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
460+
SUZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
461+
SLZ_sim = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device)
462+
452463
param_dict = {}
453464
for t in range(nsteps):
454465
# Get dynamic parameter values per timestep.
@@ -541,6 +552,7 @@ def _PBM(
541552
Q2 = param_dict['parK2'] * SLZ
542553
SLZ = SLZ - Q2
543554

555+
# --- Outputs ---
544556
Qsimmu[t, :, :] = Q0 + Q1 + Q2
545557
Q0_sim[t, :, :] = Q0
546558
Q1_sim[t, :, :] = Q1
@@ -555,6 +567,13 @@ def _PBM(
555567
tosoil_sim[t, :, :] = tosoil
556568
PERC_sim[t, :, :] = PERC
557569

570+
# NOTE: new for MTS -- Save model states for all time steps.
571+
SNOWPACK_sim[t, :, :] = SNOWPACK
572+
MELTWATER_sim[t, :, :] = MELTWATER
573+
SM_sim[t, :, :] = SM
574+
SUZ_sim[t, :, :] = SUZ
575+
SLZ_sim[t, :, :] = SLZ
576+
558577
# Get the average or weighted average using learned weights.
559578
if self.muwts is None:
560579
Qsimavg = Qsimmu.mean(-1)
@@ -574,7 +593,7 @@ def _PBM(
574593
UH = uh_gamma(
575594
self.routing_param_dict['route_a'].repeat(nsteps, 1).unsqueeze(-1),
576595
self.routing_param_dict['route_b'].repeat(nsteps, 1).unsqueeze(-1),
577-
lenF=15,
596+
lenF=self.lenF,
578597
)
579598
rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0]) # [gages,vars,time]
580599
UH = UH.permute([1, 2, 0]) # [gages,vars,time]
@@ -603,11 +622,11 @@ def _PBM(
603622
Qs = torch.unsqueeze(Qsimavg, -1)
604623
Q0_rout = Q1_rout = Q2_rout = None
605624

606-
states = (SNOWPACK, MELTWATER, SM, SUZ, SLZ)
625+
states = (SNOWPACK_sim, MELTWATER_sim, SM_sim, SUZ_sim, SLZ_sim)
607626

608627
if self.initialize:
609628
# If initialize is True, only return warmed-up storages.
610-
return states
629+
return {}, states
611630
else:
612631
# Baseflow index (BFI) calculation
613632
BFI_sim = (

0 commit comments

Comments
 (0)