
Support for multiple dataset checkpoints #402

@gmertes

Description


We need to support checkpoints that use the new multi-dataset format (ecmwf/anemoi-core#594).

The goal is to do this in a backwards-compatible way and support both legacy single-dataset and new multi-dataset style checkpoints. The plan is to do it in two steps. Step 1 will be merged together with the PR in core to maintain feature parity; step 2 is a longer-term task.

Step 1

Provide support for multi-dataset checkpoints, but only those trained with a single dataset in and out. This is equivalent to the current style of checkpoints, except that dataset-related metadata is now contained under a new default dictionary key, data.

Supporting this is straightforward: we only have to map the old metadata to the new data entry, since the contents of the metadata remain the same. For this we create a wrapper around the current Metadata class that redirects all existing attributes to the new data entry. For the rest of the code this is transparent and fully backwards compatible.
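For illustration, a minimal sketch of what such a wrapper could look like, assuming the dataset metadata sits under a "data" key as described above; the class name and the dict-based access are placeholders, not the actual anemoi-inference implementation:

```python
class LegacyMetadataView:
    """Expose a new-style checkpoint's single dataset entry through the
    legacy single-dataset attribute interface."""

    def __init__(self, metadata: dict):
        self._metadata = metadata
        # New-style checkpoints nest the dataset metadata under "data";
        # legacy checkpoints keep the same fields at the top level.
        self._data = metadata.get("data", metadata)

    def __getattr__(self, name):
        # Redirect legacy attribute access (e.g. view.variables) to the
        # nested dataset entry, so the rest of the code does not change.
        try:
            return self._data[name]
        except KeyError:
            raise AttributeError(name) from None
```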

This allows us to merge the PR in core and not lose functionality in inference while we work on step 2.

Step 1.5

Together with the multi-dataset checkpoints, we are also updating the checkpoint metadata with better inference metadata. Instead of 'reverse engineering' things like the name_to_index or the list of prognostic variables from the data indices, we can now get them directly from the metadata.

We can start using these new entries in the wrapper described in Step 1, updating the old attributes where applicable.
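As a hedged example of what that could look like inside the wrapper (the key name prognostic_variables is an assumption about the new inference metadata, and legacy_fallback stands in for the existing reverse-engineering logic):

```python
def prognostic_variables(metadata: dict, legacy_fallback) -> list:
    """Return the prognostic variable names for one dataset."""
    data = metadata.get("data", metadata)
    if "prognostic_variables" in data:
        # New-style checkpoints carry the list directly in the metadata.
        return list(data["prognostic_variables"])
    # Legacy checkpoints: reverse-engineer the list from the data indices,
    # as the current code already does.
    return legacy_fallback(data)
```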

Step 2

Support the "real" multi-dataset case of N datasets in and out (e.g. CERRA + ERA in and out). This requires a refactor of all the logic related to input and output handling, as well as the rollout loop. We will need N inputs and outputs. Every tensor we currently create for each variable class (prognostic, diagnostic or forcing) will now have to be created N times.

We plan to create a new object, provisionally called the TensorHandler. The TensorHandler contains all the logic related to the input, output and creation of tensors for a single dataset. We create N TensorHandlers, and each one only has a view into the metadata of its corresponding dataset. This is where we use the Metadata wrapper from Step 1. The TensorHandler knows nothing about multi-datasets; it only knows about a single dataset.

In the runner, we then call all N TensorHandler objects to prepare the input tensors for each dataset and perform the rollout.
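A rough sketch of the idea (all names, signatures and the rollout structure are provisional placeholders, not a final design):

```python
class TensorHandler:
    """All input/output tensor logic for ONE dataset. It only sees the
    metadata view of its own dataset and knows nothing about multi-datasets."""

    def __init__(self, name: str, metadata_view):
        self.name = name
        self.metadata = metadata_view

    def prepare_input(self, state):
        """Build the input tensor (prognostic + forcing) for this dataset."""
        raise NotImplementedError

    def update_state(self, state, output):
        """Write this dataset's prognostic/diagnostic outputs back into its state."""
        raise NotImplementedError


def run_rollout(model, metadata_views: dict, states: dict, steps: int) -> dict:
    """One TensorHandler per dataset; each rollout step prepares N inputs,
    calls the model once, and updates the N per-dataset states."""
    handlers = {name: TensorHandler(name, view) for name, view in metadata_views.items()}
    for _ in range(steps):
        inputs = {name: h.prepare_input(states[name]) for name, h in handlers.items()}
        outputs = model.predict_step(inputs)  # dict of tensors, one per dataset
        for name, h in handlers.items():
            h.update_state(states[name], outputs[name])
    return states
```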

Extra things to consider:

  • The State will also have to be reworked. We have two options: either add an extra dictionary level to the current State object, with one entry per dataset, or create a new multi-dataset state object that contains N State objects.
  • The config will also need an extra level to specify separate inputs and outputs per dataset. How we do this is TBD; the simplest option is to add yet another level of dictionary.
  • The predict_step interface has also changed (a dictionary of tensors instead of a single tensor). To maintain backwards compatibility with single-dataset and old checkpoints, we will support both formats with a switch based on the checkpoint format (see the sketch after this list).
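For the last point, a minimal sketch of what that switch could look like; the multi_dataset flag and the shape of the legacy call are assumptions for illustration only:

```python
def call_predict_step(model, inputs: dict, multi_dataset: bool):
    """Dispatch to the right predict_step interface based on the checkpoint format."""
    if multi_dataset:
        # New interface: a dictionary of tensors, one entry per dataset.
        return model.predict_step(inputs)
    # Legacy interface: a single tensor in, a single tensor out.
    (name, tensor), = inputs.items()
    return {name: model.predict_step(tensor)}
```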
