
Errors when trying to do inference with fine-tuned JEPA model #1724

@shmh40


What happened?

Inference with a fine-tuned JEPA model fails. The setup: JEPA is pre-trained on ERA5 (no target channels, since the JEPA training objective is in latent space), and the model is then fine-tuned to output physical predictions for synop. At inference, even if streams_output is specified, the code still tries to generate predictions for ERA5. Since ERA5 has no target channels, the preds are None (the targets are an empty tensor). The code then breaks in write_output in validation_io (around line 63), where the targets and the preds are zipped together. The information about which streams to write is only used later in the code, so ERA5 predictions are still attempted, even though we have (and in this case want) no target channels for ERA5.
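One possible direction, sketched here only as an illustration, is to apply the streams_output selection before any per-stream predictions are collected. The function name `select_output_streams` and the plain list-of-names representation are assumptions made for this sketch and are not taken from the WeatherGenerator code.

```python
def select_output_streams(stream_names, streams_output=None):
    """Return the stream names that should be written.

    Hypothetical helper: an empty or missing streams_output is treated
    as "write every stream", which is an assumption of this sketch.
    """
    if not streams_output:
        return list(stream_names)
    wanted = set(streams_output)
    # Preserve the original stream order while dropping unrequested streams.
    return [name for name in stream_names if name in wanted]


# ERA5 has no target channels in this setup, so it is filtered out early.
selected = select_output_streams(["ERA5", "SurfaceCombined"], ["SurfaceCombined"])
print(selected)
```

With a filter like this applied up front, a stream without target channels would never reach the step that zips preds and targets.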

What are the steps to reproduce the bug?

Run inference with a model that was JEPA pre-trained on ERA5 and then fine-tuned for physical SurfaceCombined prediction. E.g., on Santis run:

uv run inference --from-run-id lu3qz4fj --samples 4 --streams-output SurfaceCombined

Error:

Loss is 0.0, likely incorrect configuration. Check stream support time and training configuration.
  0%|                                                                                                                                                       | 0/4 [00:09<?, ?it/s]
Traceback (most recent call last):
  File "/users/shickman/work/wg_base/develop/WeatherGenerator/src/weathergen/run_train.py", line 79, in inference_from_args
    trainer.inference(cf, devices, args.from_run_id, args.mini_epoch)
  File "/users/shickman/work/wg_base/develop/WeatherGenerator/src/weathergen/train/trainer.py", line 221, in inference
    self.validate(0, self.test_cfg, self.batch_size_test_per_gpu)
  File "/users/shickman/work/wg_base/develop/WeatherGenerator/src/weathergen/train/trainer.py", line 584, in validate
    write_output(
  File "/users/shickman/work/wg_base/develop/WeatherGenerator/src/weathergen/utils/validation_io.py", line 63, in write_output
    for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable

Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.

No response

Labels: bug (Something isn't working), data:io (Issues with the zarr output produced during inference/validation)
