Skip to content

Commit 3759c64

Browse files
authored
Merge pull request #25 from mhpi/multimodel_dev
Data Loader Update
2 parents 1df026b + 5d8a2b6 commit 3759c64

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+1609
-1341
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,3 +43,4 @@ deltaModel/conf/observations/*
4343
archive/
4444
runs/
4545
results/
46+
validation/

deltaModel/__main__.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,16 @@
44

55
import hydra
66
import torch
7-
from core.data.dataset_loading import get_dataset_dict
7+
from omegaconf import DictConfig
8+
89
from core.utils import initialize_config, print_config, set_randomseed
10+
from core.utils.module_loaders import load_data_loader, load_trainer
911
from models.model_handler import ModelHandler as dModel
10-
from omegaconf import DictConfig
1112
from trainers.trainer import Trainer
1213

1314
log = logging.getLogger(__name__)
1415

1516

16-
1717
@hydra.main(
1818
version_base='1.3',
1919
config_path='conf/',
@@ -31,15 +31,22 @@ def main(config: DictConfig) -> None:
3131
print_config(config)
3232

3333
### Create/Load differentiable model ###
34-
model = dModel(config, verbose=True) #.to(config['device'])
34+
model = dModel(config, verbose=True)
3535

3636
### Process datasets ###
37-
log.info("Processing datasets...")
38-
train_dataset = get_dataset_dict(config, train=True)
39-
eval_dataset = get_dataset_dict(config, train=False)
37+
log.info("Loading dataset...")
38+
data_loader = load_data_loader(config['data_loader'])
39+
data_loader = data_loader(config, test_split=True, overwrite=False)
4040

4141
### Create Trainer object ###
42-
trainer = Trainer(config, model, train_dataset, eval_dataset, verbose=True)
42+
trainer = load_trainer(config['trainer'])
43+
trainer = trainer(
44+
config,
45+
model,
46+
data_loader.train_dataset,
47+
data_loader.eval_dataset,
48+
verbose=True
49+
)
4350

4451
mode = config['mode']
4552
if mode == 'train':

deltaModel/conf/config.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ class ObservationConfig(BaseModel):
105105
name: str = 'not_defined'
106106
train_path: str = 'not_defined'
107107
test_path: str = 'not_defined'
108-
start_time_all: str = 'not_defined'
108+
start_time: str = 'not_defined'
109109
end_time_all: str = 'not_defined'
110110
forcings_all: List[str] = Field(default_factory=list, description="List of dynamic input variables.")
111111
attributes_all: List[str] = Field(default_factory=list, description="List of static input variables.")
@@ -137,6 +137,9 @@ class Config(BaseModel):
137137
random_seed: int = 0
138138
device: str = 'cpu'
139139
gpu_id: int = 0
140+
data_loader: str = 'none'
141+
data_sampler: str = 'none'
142+
trainer: str = 'none'
140143
save_path: str
141144
train: TrainingConfig
142145
test: TestingConfig
@@ -182,6 +185,9 @@ def check_device(cls, values):
182185
random_seed=42,
183186
device='cuda',
184187
gpu_id=0,
188+
data_loader='base_loader',
189+
data_sampler='base_sampler',
190+
trainer='trainer',
185191
save_path='../results',
186192
train={
187193
'start_time': '2000/01/01',

deltaModel/conf/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,10 @@ random_seed: 111111
1111
device: cuda
1212
gpu_id: 0
1313

14+
data_loader: base_loader
15+
data_sampler: base_sampler
16+
trainer: trainer
17+
1418
save_path: ../results
1519

1620

deltaModel/conf/observations/none.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@ name: None
22
train_path: /path/to/training/data
33
test_path: /path/to/validation/data
44

5-
start_time_all: 2000/01/01
6-
end_time_all: 2024/12/31
5+
start_time: 2000/01/01
6+
end_time: 2024/12/31
77

88
forcings_all: [
99
x1_var,

deltaModel/core/calc/metrics.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,10 @@
1515
class Metrics(BaseModel):
1616
"""Metrics for model evaluation.
1717
18-
Using pydantic BaseModel for validation.
18+
Using Pydantic BaseModel for validation.
1919
Metrics are calculated at each grid point and are listed below.
2020
2121
Adapted from Tadd Bindas, Yalan Song, Farshid Rahmani.
22-
23-
Note: Considering conversion to torch.nn.module for codebase consistency.
2422
"""
2523
model_config = ConfigDict(arbitrary_types_allowed=True)
2624
pred: npt.NDArray[np.float32]
@@ -70,9 +68,11 @@ def __init__(
7068

7169
super().__init__(pred=pred, target=target)
7270

73-
def model_post_init(self, __context: Any):
71+
def model_post_init(self, __context: Any) -> Any:
7472
"""Calculate metrics.
75-
73+
74+
This method is called after the model is initialized.
75+
7676
Parameters
7777
----------
7878
__context : Any
@@ -208,7 +208,7 @@ def validate_pred(cls, metrics: Any) -> Any:
208208
raise ValueError(msg)
209209
return metrics
210210

211-
def calc_statistics(self, *args, **kwargs) -> Dict[str, Dict[str, float]]:
211+
def calc_stats(self, *args, **kwargs) -> Dict[str, Dict[str, float]]:
212212
"""Calculate aggregate statistics of metrics."""
213213
stats = {}
214214
model_dict = self.model_dump()
@@ -225,15 +225,15 @@ def calc_statistics(self, *args, **kwargs) -> Dict[str, Dict[str, float]]:
225225
}
226226
return stats
227227

228-
def dump_agg_statistics(self, path: str) -> None:
228+
def model_dump_agg_stats(self, path: str) -> None:
229229
"""Dump aggregate statistics (median, mean, std) to json or csv.
230230
231231
Parameters
232232
----------
233233
path : str
234234
Path to save file.
235235
"""
236-
stats = self.calc_statistics()
236+
stats = self.calc_stats()
237237

238238
if path.endswith('.json'):
239239
with open(path, 'w') as f:
@@ -271,7 +271,7 @@ def dump_metrics(self, path: str) -> None:
271271
"""
272272
# Save aggregate statistics
273273
save_path = os.path.join(path, 'metrics_agg.json')
274-
self.dump_agg_statistics(save_path)
274+
self.model_dump_agg_stats(save_path)
275275

276276
# Save raw metrics
277277
save_path = os.path.join(path, f'metrics.json')
@@ -280,16 +280,6 @@ def dump_metrics(self, path: str) -> None:
280280
with open(save_path, "w") as f:
281281
json.dump(json_dat, f)
282282

283-
@property
284-
def ngrid(self) -> int:
285-
"""Calculate number of items in grid."""
286-
return self.pred.shape[0]
287-
288-
@property
289-
def nt(self) -> int:
290-
"""Calculate number of time steps."""
291-
return self.pred.shape[1]
292-
293283
def tile_mean(self, data: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
294284
"""Calculate mean of target.
295285
@@ -305,6 +295,16 @@ def tile_mean(self, data: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
305295
"""
306296
return np.tile(np.nanmean(data, axis=1), (self.nt, 1)).transpose()
307297

298+
@property
299+
def ngrid(self) -> int:
300+
"""Calculate number of items in grid."""
301+
return self.pred.shape[0]
302+
303+
@property
304+
def nt(self) -> int:
305+
"""Calculate number of time steps."""
306+
return self.pred.shape[1]
307+
308308
@staticmethod
309309
def _bias(
310310
pred: npt.NDArray[np.float32],

0 commit comments

Comments
 (0)