Skip to content

Commit

Permalink
#47 added _step logic for tendency based learning
Browse files Browse the repository at this point in the history
Co-authored-by: Jakob Schloer <[email protected]>
  • Loading branch information
Rilwan-Adewoyin committed Sep 4, 2024
1 parent ecdba39 commit 677320d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def _step(
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
return self.step_functions[self.prediction_mode](batch, batch_idx, validation_mode, in_place_proc)

# NOTE (jakob-schloer): Observation on nomenclature - is this _step_residual function only residual if the "self.model" has a residual structure???
# NOTE (rilwan-adewoying): Naming problems can maybe be solved by moving alot of this tendency logic in _step_tendency to advance_input

def _step_residual(
self,
batch: torch.Tensor,
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/training/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def model(self) -> GraphForecaster:
"graph_data": self.graph_data,
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
"statistics_tendencies": self.datamodule.statistics_tendencies,
}
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
Expand Down Expand Up @@ -328,7 +329,7 @@ def train(self) -> None:
# run a fixed no of batches per epoch (helpful when debugging)
limit_train_batches=self.config.dataloader.limit_batches.training,
limit_val_batches=self.config.dataloader.limit_batches.validation,
num_sanity_val_steps=4,
num_sanity_val_steps=self.config.training.num_sanity_val_steps,
accumulate_grad_batches=self.config.training.accum_grad_batches,
gradient_clip_val=self.config.training.gradient_clip.val,
gradient_clip_algorithm=self.config.training.gradient_clip.algorithm,
Expand Down

0 comments on commit 677320d

Please sign in to comment.