Skip to content

Commit 4ef6e3e

Browse files
Fix missing checkpoint for forecast engine (#1771)
* implemented * remove eval in interface * lint * incoporate requested changes * fix imports * Fixed missing checkpoint --------- Co-authored-by: moritzhauschulz <[email protected]>
1 parent 9db66cb commit 4ef6e3e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/weathergen/model/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import torch
2020
import torch.nn as nn
21+
from torch.utils.checkpoint import checkpoint
2122

2223
from weathergen.common.config import Config
2324
from weathergen.datasets.batch import ModelBatch
@@ -585,7 +586,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
585586
for step in batch.get_output_idxs():
586587
# apply forecasting engine (if present)
587588
if self.forecast_engine:
588-
tokens = self.forecast_engine(tokens, step)
589+
tokens = checkpoint(self.forecast_engine, tokens, step)
589590

590591
# decoder predictions
591592
output = self.predict_decoders(model_params, step, tokens, batch, output)

0 commit comments

Comments
 (0)