Draft pull request: feat/multi-step-pred, 22 commits
4e6c7f7  add multi_step_output (dietervdb-meteo, Nov 13, 2025)
413d90e  technically working multiple out (dietervdb-meteo, Nov 13, 2025)
286c40b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 13, 2025)
5841832  fix valid dates (dietervdb-meteo, Nov 14, 2025)
82ccdbc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2025)
c168043  tmp handling of trace (dietervdb-meteo, Nov 14, 2025)
fbdb630  fix type (dietervdb-meteo, Nov 14, 2025)
4a16932  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2025)
771d4db  updated docstring (dietervdb-meteo, Nov 14, 2025)
81f1b0d  add shape comment (dietervdb-meteo, Nov 14, 2025)
8a79217  clarify rollout step size (dietervdb-meteo, Nov 14, 2025)
7f7578f  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Nov 14, 2025)
9b1582a  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Nov 20, 2025)
260c736  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Nov 27, 2025)
48b0c98  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Dec 4, 2025)
ae2a220  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Dec 5, 2025)
5f98b8f  fix step entry of the state (dietervdb-meteo, Dec 5, 2025)
3feaaf8  temporarily disable tp output (dietervdb-meteo, Dec 5, 2025)
d0c9d86  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Jan 20, 2026)
82c2ee3  re-enable tp output (dietervdb-meteo, Jan 20, 2026)
d70ee69  Merge branch 'main' into feat/multi-step-pred (dietervdb-meteo, Jan 26, 2026)
44c02f2  Merge branch 'main' into feat/multi-step-pred (mc4117, Feb 2, 2026)
5 changes: 5 additions & 0 deletions src/anemoi/inference/checkpoint.py
@@ -500,6 +500,11 @@ def multi_step_input(self) -> Any:
"""Get the multi-step input."""
return self._metadata.multi_step_input

+ @property
+ def multi_step_output(self) -> Any:
+     """Get the multi-step output."""
+     return self._metadata.multi_step_output

###########################################################################
# Data retrieval
###########################################################################
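The new property mirrors multi_step_input and simply delegates to the checkpoint metadata. A hypothetical usage sketch (the constructor call and path are assumptions for illustration, not from the diff):

from anemoi.inference.checkpoint import Checkpoint

ckpt = Checkpoint("inference-model.ckpt")  # hypothetical path
window = ckpt.multi_step_input    # past steps consumed per forward pass
horizon = ckpt.multi_step_output  # future steps produced per forward pass
print(f"model maps {window} input times to {horizon} output times")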
6 changes: 6 additions & 0 deletions src/anemoi/inference/metadata.py
@@ -281,6 +281,12 @@ def multi_step_input(self) -> int:
"""Number of past steps needed for the initial conditions tensor."""
return self._config_training.multistep_input

+ @cached_property
+ def multi_step_output(self) -> int:
+     """Number of future steps predicted by a single model forward pass."""
+     # For backward compatibility we set a default of 1
+     return self._config_training.get("multistep_output", 1)

@cached_property
def prognostic_output_mask(self) -> IntArray:
"""Return the prognostic output mask."""
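The default of 1 keeps older checkpoints working: training configs written before this feature have no multistep_output key, so the property falls back to single-step behaviour. A minimal sketch of the lookup, with plain dicts standing in for the real training config object (an assumption for illustration):

# Sketch only: plain dicts stand in for the checkpoint's training config.
old_config = {"multistep_input": 2}                         # pre-feature checkpoint
new_config = {"multistep_input": 2, "multistep_output": 4}  # multi-step checkpoint

def multi_step_output(config) -> int:
    # For backward compatibility the default is 1
    return config.get("multistep_output", 1)

assert multi_step_output(old_config) == 1  # legacy: one step per forward pass
assert multi_step_output(new_config) == 4  # new: four steps per forward pass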
164 changes: 114 additions & 50 deletions src/anemoi/inference/runner.py
@@ -638,7 +638,7 @@ def predict_step(
def forecast_stepper(
self, start_date: datetime.datetime, lead_time: datetime.timedelta
) -> Generator[tuple[datetime.timedelta, datetime.datetime, datetime.datetime, bool], None, None]:
"""Generate step and date variables for the forecast loop.
"""Generate step and date variables for the forecast autoregressive loop.

Parameters
----------
@@ -658,16 +658,27 @@ def forecast_stepper(
is_last_step : bool
True if it's the last step of the forecast
"""
- steps = lead_time // self.checkpoint.timestep
+ rollout_step_size = self.checkpoint.timestep * self.checkpoint.multi_step_output
+ steps = lead_time // rollout_step_size

LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
LOG.info(
"Lead time: %s, time stepping: %s Forecasting %s steps through %s autoregressive steps of %s predictions each.",
lead_time,
self.checkpoint.timestep,
self.checkpoint.multi_step_output * steps,
steps,
self.checkpoint.multi_step_output,
)

for s in range(steps):
- step = (s + 1) * self.checkpoint.timestep
- valid_date = start_date + step
- next_date = valid_date
+ step = (s + 1) * rollout_step_size
+ valid_dates = [
+     start_date + s * rollout_step_size + self.checkpoint.timestep * (i + 1)
+     for i in range(self.checkpoint.multi_step_output)
+ ]
+ next_dates = valid_dates
is_last_step = s == steps - 1
- yield step, valid_date, next_date, is_last_step
+ yield step, valid_dates, next_dates, is_last_step

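A worked example may help with the date arithmetic above. With a 6-hour timestep, multi_step_output = 2 and a 24-hour lead time, the rollout step size is 12 hours, so the loop runs twice and each iteration yields two valid dates. A self-contained sketch of the same logic (all values hypothetical):

import datetime

timestep = datetime.timedelta(hours=6)    # assumed checkpoint.timestep
multi_step_output = 2                     # assumed checkpoint.multi_step_output
lead_time = datetime.timedelta(hours=24)
start_date = datetime.datetime(2026, 1, 1)

rollout_step_size = timestep * multi_step_output  # 12 hours
steps = lead_time // rollout_step_size            # 2 autoregressive steps

for s in range(steps):
    step = (s + 1) * rollout_step_size
    valid_dates = [
        start_date + s * rollout_step_size + timestep * (i + 1)
        for i in range(multi_step_output)
    ]
    print(step, [d.isoformat() for d in valid_dates])
# 12:00:00 ['2026-01-01T06:00:00', '2026-01-01T12:00:00']
# 1 day, 0:00:00 ['2026-01-01T18:00:00', '2026-01-02T00:00:00']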
def forecast(
self, lead_time: str, input_tensor_numpy: FloatArray, input_state: State
@@ -721,16 +732,19 @@ def forecast(
if self.verbosity > 0:
self._print_input_tensor("First input tensor", input_tensor_torch)

- for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
- title = f"Forecasting step {step} ({date})"

- new_state["date"] = date
- new_state["previous_step"] = new_state.get("step")
- new_state["step"] = step
+ for s, (step, dates, next_dates, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
+ dates_str = "(" + ", ".join(str(d) for d in dates) + ")"
+ title = (
+     f"Forecasting autoregressive step {s}: horizon {step}, freq. {self.checkpoint.timestep} {dates_str}"
+ )

if self.trace:
+ # TODO(dieter): check what below is about and how to handle date(s)
self.trace.write_input_tensor(
- date,
+ dates[-1],
s,
input_tensor_torch.cpu().numpy(),
variable_to_input_tensor_index,
@@ -740,34 +754,52 @@

# Predict next state of atmosphere
with torch.inference_mode(), amp_ctx, ProfilingLabel("Predict step", self.use_profiler), Timer(title):
- y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
+ # TODO(dieter) what are these kwargs about? maybe related to interpolator?? check out
+ y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=dates[-1])
+ # (batch, [time], ensemble, grid/values, variables) -> (time, values, variables)
+ ndim = y_pred.ndim
+ if ndim != 5:
+     # for backwards compatibility
+     outputs = torch.squeeze(y_pred, dim=(0, 1)).unsqueeze(0)
+ else:
+     outputs = torch.squeeze(y_pred, dim=(0, 2))

- output = torch.squeeze(y_pred, dim=(0, 1)) # shape: (values, variables)
+ new_states = [] # TODO(dieter) not clear if needed, but some forcings might need the new states

- # Update state
- with ProfilingLabel("Updating state (CPU)", self.use_profiler):
-     for i in range(output.shape[1]):
-         new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
+ for i in range(self.checkpoint.multi_step_output):
+     new_state["date"] = dates[i]
+     new_state["previous_step"] = new_state.get("step")
+     new_state["step"] = step + (1 + i - self.checkpoint.multi_step_output) * self.checkpoint.timestep

- if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
-     self._print_output_tensor("Output tensor", output.cpu().numpy())
+     output = outputs[i, ...]

- if self.trace:
-     self.trace.write_output_tensor(
-         date,
-         s,
-         output.cpu().numpy(),
-         self.checkpoint.output_tensor_index_to_variable,
-         self.checkpoint.timestep,
-     )
+     # Update state
+     with ProfilingLabel("Updating state (CPU)", self.use_profiler):
+         for i in range(output.shape[1]):
+             new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]

- yield new_state
+     if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
+         self._print_output_tensor("Output tensor", output.cpu().numpy())

- # No need to prepare next input tensor if we are at the last step
+     if self.trace:
+         # TODO(dieter): check what below is about and how to handle date(s)
+         self.trace.write_output_tensor(
+             dates[-1],
+             s,
+             output.cpu().numpy(),
+             self.checkpoint.output_tensor_index_to_variable,
+             self.checkpoint.timestep,
+         )

+     yield new_state
+     new_states.append(new_state)

+ # No need to prepare next input tensor if we are at the last autoregressive step
if is_last_step:
break

- self.output_state_hook(new_state)
+ # TODO(dieter) support this hook with multi-step output
+ # self.output_state_hook(new_state)

# Update tensor for next iteration
with ProfilingLabel("Update tensor for next step", self.use_profiler):
Expand All @@ -779,11 +811,21 @@ def forecast(

del y_pred # Recover memory

+ # TODO(dieter):
+ # how do forcings use the new_state(s)?
+ # ComputedForcings only uses it to get latlons
+ # For CoupledForcings unclear how it is used: don't worry just yet about supporting it
+ # ConstantForcings irrelevant
+ # BoundaryForcings currently only work from dataset, there load_forcings_state takes state as argument but doesn't use it
+ # so for now we can get away with:
+ new_state = new_states[-1]
+ #################################

input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
- input_tensor_torch, new_state, next_date, check
+ input_tensor_torch, new_state, next_dates, check
)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
- input_tensor_torch, new_state, next_date, check
+ input_tensor_torch, new_state, next_dates, check
)

if not check.all():
@@ -836,8 +878,10 @@ def copy_prognostic_fields_to_input_tensor(

prognostic_fields = torch.index_select(y_pred, dim=-1, index=pmask_out)

- input_tensor_torch = input_tensor_torch.roll(-1, dims=1)
- input_tensor_torch[:, -1, :, pmask_in] = prognostic_fields
+ input_tensor_torch = input_tensor_torch.roll(-self.checkpoint.multi_step_output, dims=1)

+ for i in range(1, self.checkpoint.multi_step_output + 1):
+     input_tensor_torch[:, -i, :, pmask_in] = prognostic_fields[:, -i, ...]

pmask_in_np = pmask_in.detach().cpu().numpy()
if check[pmask_in_np].any():
@@ -857,7 +901,11 @@
return input_tensor_torch

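Rolling by -multi_step_output instead of -1 discards the oldest multi_step_output input times in one go, and the loop then writes the freshly predicted times into the last slots. A toy illustration of the indexing, with a hypothetical 4-step input window and two predicted steps:

import torch

k = 2  # assumed multi_step_output
x = torch.arange(4.0).reshape(1, 4, 1, 1)            # input times [0, 1, 2, 3]
pred = torch.tensor([4.0, 5.0]).reshape(1, 2, 1, 1)  # two newly predicted times

x = x.roll(-k, dims=1)       # time axis now [2, 3, 0, 1]; first two slots are stale
for i in range(1, k + 1):
    x[:, -i] = pred[:, -i]   # overwrite stale slots with the new predictions
print(x.flatten().tolist())  # [2.0, 3.0, 4.0, 5.0]: the window advanced by k steps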
def add_dynamic_forcings_to_input_tensor(
- self, input_tensor_torch: "torch.Tensor", state: State, date: datetime.datetime, check: BoolArray
+ self,
+ input_tensor_torch: "torch.Tensor",
+ state: State,
+ dates: list[datetime.datetime],
+ check: BoolArray,
) -> "torch.Tensor":
"""Add dynamic forcings to the input tensor.

@@ -867,7 +915,7 @@ def add_dynamic_forcings_to_input_tensor(
The input tensor.
state : State
The state.
- date : datetime.datetime
+ dates : list[datetime.datetime]
The date.
check : BoolArray
The check array.
@@ -881,6 +929,7 @@ def add_dynamic_forcings_to_input_tensor(
if self.hacks:
if "dynamic_forcings_date" in self.development_hacks:
date = self.development_hacks["dynamic_forcings_date"]
+ dates = [date]
warnings.warn(f"🧑‍💻 Using `dynamic_forcings_date` hack: {date} 🧑‍💻")

# TODO: check if there were not already loaded as part of the input state
@@ -889,15 +938,20 @@ def add_dynamic_forcings_to_input_tensor(
# batch is always 1

for source in self.dynamic_forcings_inputs:
- forcings = source.load_forcings_array([date], state) # shape: (variables, dates, values)
+ forcings = source.load_forcings_array(dates, state) # shape: (variables, dates, values)

- forcings = np.squeeze(forcings, axis=1) # Drop the dates dimension
+ forcings = np.swapaxes(forcings, 0, 1) # shape: (dates, variable, values)

- forcings = np.swapaxes(forcings[np.newaxis, np.newaxis, ...], -2, -1) # shape: (1, 1, values, variables)
+ forcings = np.swapaxes(
+     forcings[np.newaxis, :, np.newaxis, ...], -2, -1
+ ) # shape: (1, dates, 1, values, variables)

forcings = torch.from_numpy(forcings).to(self.device) # Copy to device

- input_tensor_torch[:, -1, :, source.mask] = forcings # Copy forcings to last 'multi_step_input' row
+ for i in range(1, self.checkpoint.multi_step_output + 1):
+     input_tensor_torch[:, -i, :, source.mask] = forcings[
+         :, -i, ...
+     ] # Copy forcings to corresponding 'multi_step_input' row

assert not check[source.mask].any() # Make sure we are not overwriting some values
check[source.mask] = True
@@ -912,7 +966,11 @@ def add_dynamic_forcings_to_input_tensor(
return input_tensor_torch

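The reshaping above takes the forcings from load_forcings_array's (variables, dates, values) layout to the input tensor's (batch, dates, ensemble, values, variables) layout. A numpy sketch with made-up sizes:

import numpy as np

n_vars, n_dates, n_values = 3, 2, 40320  # hypothetical sizes
forcings = np.zeros((n_vars, n_dates, n_values))

forcings = np.swapaxes(forcings, 0, 1)  # (dates, variables, values)
forcings = np.swapaxes(
    forcings[np.newaxis, :, np.newaxis, ...], -2, -1
)  # (1, dates, 1, values, variables)
assert forcings.shape == (1, n_dates, 1, n_values, n_vars)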
def add_boundary_forcings_to_input_tensor(
- self, input_tensor_torch: "torch.Tensor", state: State, date: datetime.datetime, check: BoolArray
+ self,
+ input_tensor_torch: "torch.Tensor",
+ state: State,
+ dates: list[datetime.datetime],
+ check: BoolArray,
) -> "torch.Tensor":
"""Add boundary forcings to the input tensor.

@@ -922,7 +980,7 @@ def add_boundary_forcings_to_input_tensor(
The input tensor.
state : State
The state.
- date : datetime.datetime
+ dates : list[datetime.datetime]
The date.
check : BoolArray
The check array.
@@ -936,14 +994,20 @@ def add_boundary_forcings_to_input_tensor(
# batch is always 1
sources = self.boundary_forcings_inputs
for source in sources:
- forcings = source.load_forcings_array([date], state) # shape: (variables, dates, values)
+ forcings = source.load_forcings_array(dates, state) # shape: (variables, dates, values)

- forcings = np.squeeze(forcings, axis=1) # Drop the dates dimension
+ forcings = np.swapaxes(forcings, 0, 1) # shape: (dates, variable, values)

- forcings = np.swapaxes(forcings[np.newaxis, np.newaxis, ...], -2, -1) # shape: (1, 1, values, variables)
+ forcings = np.swapaxes(
+     forcings[np.newaxis, :, np.newaxis, ...], -2, -1
+ ) # shape: (1, dates, 1, values, variables)
forcings = torch.from_numpy(forcings).to(self.device) # Copy to device
- total_mask = np.ix_([0], [-1], source.spatial_mask, source.variables_mask)
- input_tensor_torch[total_mask] = forcings # Copy forcings to last 'multi_step_input' row

+ for i in range(1, self.checkpoint.multi_step_output + 1):
+     total_mask = np.ix_([0], [-i], source.spatial_mask, source.variables_mask)
+     input_tensor_torch[total_mask] = forcings[
+         :, -i, ...
+     ] # Copy forcings to corresponding 'multi_step_input' row

for n in source.variables_mask:
self._input_kinds[self._input_tensor_by_name[n]] = Kind(boundary=True, forcing=True, **source.kinds)
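np.ix_ builds an open-mesh index, so each loop iteration above writes one (batch, time, values, variables) block: batch 0, time slot -i, the spatial points in source.spatial_mask and the variables in source.variables_mask. A small numpy illustration with hypothetical masks:

import numpy as np

x = np.zeros((1, 4, 6, 5))  # (batch, time, values, variables)
spatial_mask = [0, 2, 3]    # hypothetical boundary points
variables_mask = [1, 4]     # hypothetical boundary variables

total_mask = np.ix_([0], [-1], spatial_mask, variables_mask)
x[total_mask] = 1.0         # writes a (1, 1, 3, 2) block into the last time slot

assert x[0, -1][np.ix_(spatial_mask, variables_mask)].sum() == 6.0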