
Part of network cannot be sharded during loading if not used #1750

@kctezcan

Description


Describe the task. It can be a feature, documentation, etc.

In model_interface.py we are loading the model in load_model().

If I am fine-tuning an ERA5 model and use ERA5 only as forcing, the ERA5-specific decoder parameters in the checkpoint have no counterpart in the sharded model's state dict. There is therefore no sharding information (device mesh and placements) for them, so they cannot be sharded while loading the model, and loading fails with an error.
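To see which parameters are affected, one can compare the checkpoint's keys with the keys of the model being fine-tuned. Below is a minimal sketch in plain Python, assuming both state dicts are available as dicts keyed by parameter name. The function name and the key names are hypothetical and not taken from the actual model; integers stand in for tensors.

```python
def keys_without_sharding_info(checkpoint_sd, model_sd):
    """Return checkpoint keys that have no counterpart in the (sharded) model.

    For these keys, model_sd.get(key) returns None, so there is no device mesh
    or placement that could be passed to distribute_tensor().
    """
    return sorted(k for k in checkpoint_sd if k not in model_sd)


# Hypothetical key names, for illustration only; 0 stands in for a tensor.
checkpoint_sd = {
    "encoder.ERA5.weight": 0,
    "decoder.ERA5.weight": 0,
    "decoder.target.weight": 0,
}
model_sd = {
    "encoder.ERA5.weight": 0,
    "decoder.target.weight": 0,
}

print(keys_without_sharding_info(checkpoint_sd, model_sd))
```

Any key this reports is one for which the original loop would try to read `device_mesh` from `None`.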

To fix this, the parameter loop can check whether each parameter can be sharded; if it cannot, the full tensor is loaded instead:

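import torch
from torch.distributed.tensor import distribute_tensor  # older PyTorch: torch.distributed._tensor

# Inside load_model(): `cf` is the config, `model` the FSDP-wrapped model and
# `params` the full (unsharded) state dict read from the checkpoint.
is_model_sharded = cf.with_ddp and cf.with_fsdp
if is_model_sharded:
    meta_sharded_sd = model.state_dict()
    maybe_sharded_sd = {}
    for param_name, full_tensor in params.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        if sharded_meta_param is None:
            # The parameter is in the checkpoint but not in the current model
            # (e.g. the ERA5 decoder when ERA5 is only a forcing), so there is
            # no device mesh / placement to shard it with: keep the full tensor.
            print(f"sharded meta param is None for {param_name}")
            maybe_sharded_sd[param_name] = torch.nn.Parameter(full_tensor)
        else:
            # Shard the full tensor with the mesh and placements of the
            # matching parameter in the FSDP-wrapped model.
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor)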

This solves the problem and the model can be loaded as usual again.
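The branching in the loop can also be factored into a small helper that can be tested without a process group. This is a sketch, not the code in model_interface.py: the function name `build_maybe_sharded_state_dict` and the injected `shard_fn` (standing in for a call to `distribute_tensor` with the parameter's `device_mesh` and `placements`) are hypothetical. The helper also returns the names that were kept unsharded, so they can be logged once instead of printed per parameter.

```python
from typing import Any, Callable, Dict, List, Tuple


def build_maybe_sharded_state_dict(
    params: Dict[str, Any],
    meta_sharded_sd: Dict[str, Any],
    shard_fn: Callable[[Any, Any], Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Shard every checkpoint tensor the model knows about; keep the rest whole.

    shard_fn(full_tensor, meta_param) is a hypothetical hook. In the real loop
    it would call distribute_tensor(full_tensor, meta_param.device_mesh,
    meta_param.placements).
    """
    maybe_sharded_sd: Dict[str, Any] = {}
    unsharded: List[str] = []
    for name, full_tensor in params.items():
        meta_param = meta_sharded_sd.get(name)
        if meta_param is None:
            # No sharding metadata for this name: fall back to the full tensor.
            maybe_sharded_sd[name] = full_tensor
            unsharded.append(name)
        else:
            maybe_sharded_sd[name] = shard_fn(full_tensor, meta_param)
    return maybe_sharded_sd, unsharded


# Usage with stand-in values (lists and strings) instead of tensors and DTensors.
params = {"decoder.ERA5.weight": [1, 2], "decoder.target.weight": [3, 4]}
meta = {"decoder.target.weight": "mesh+placements"}
sd, kept_whole = build_maybe_sharded_state_dict(
    params, meta, shard_fn=lambda t, m: ("sharded", tuple(t))
)
print(sd)
print(kept_whole)
```

In the actual loading code, each value would still be wrapped in `torch.nn.Parameter` before being loaded into the model.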

HedgeDoc URL, if you are keeping notes, plots, or logs in HedgeDoc.

No response

Area

  • datasets, data readers, data preparation and transfer
  • model
  • science
  • infrastructure and engineering
  • evaluation, export and visualization
  • documentation

Metadata


Assignees

Labels

bug: Something isn't working
model: Related to model training or definition (not generic infra)
model:pretrain

Type

Projects

Status

Todo

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
