Bug: Inference error when running physical jepa pretrained model #1759

@wael-mika

Description

What happened?

When I trained a model using the physical JEPA config and then tried to run inference, I got the following error:

Traceback (most recent call last):
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/run_train.py", line 79, in inference_from_args
    trainer.inference(cf, devices, args.from_run_id, args.mini_epoch)
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 221, in inference
    self.validate(0, self.test_cfg, self.batch_size_test_per_gpu)
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 584, in validate
    write_output(
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/utils/validation_io.py", line 63, in write_output
    for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: zip() argument 2 is shorter than argument 1
[3] > /users/walmikae/weathergen/WeatherGenerator/src/weathergen/utils/validation_io.py(63)write_output()
-> for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
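
For reference, `zip(..., strict=True)` raises exactly this `ValueError` as soon as one argument runs out before the other. A minimal standalone illustration with hypothetical stand-ins for the prediction/target lists (not the real WeatherGenerator tensors):

```python
# Minimal illustration of the failure mode (hypothetical stand-ins):
preds = ["latent_pred", "physical_pred"]  # two prediction groups
targets = ["physical_target"]             # only one target group

for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
    print(i_batch, pred, target)
# prints: 0 latent_pred physical_target
# then raises: ValueError: zip() argument 2 is shorter than argument 1
```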

The issue is that `preds` now contains two tensors, one for the latent values and one for the physical values, and it is the physical one that should be used in this case. I tried to debug it manually by indexing the correct tensor, but that triggered different errors in the evaluation package IO.
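
A minimal sketch of this kind of manual indexing (the position of the physical tensor inside `preds` is an assumption; as noted above, this only moves the failure further down into the evaluation IO):

```python
def select_physical_preds(preds, targets):
    """Pick the prediction group that matches the physical targets.

    Assumption: when two groups are returned (latent + physical),
    the physical one is the second entry; otherwise preds is kept as-is.
    """
    if len(preds) != len(targets) and len(preds) == 2:
        return preds[1]
    return preds

# Hypothetical use inside validation_io.write_output, just before the failing loop:
#   preds = select_physical_preds(preds, targets)
#   for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
#       ...
```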

What are the steps to reproduce the bug?

Branch: Develop
HPC: Santis
run_id: j8faj4zo
command: uv run --offline inference --from_run_id=j8faj4zo --samples=32

Config:


streams_directory: "./config/streams/era5_1deg/"
streams: ???

general:
  istep: 0
  rank: ???
  world_size: ???
  multiprocessing_method: "fork"
  desc: ""
  run_id: ???
  run_history: []

train_log_freq:
  terminal: 10
  metrics: 20
  checkpoint: 250

data_loading:
  num_workers: 12
  rng_seed: ???
  memory_pinning: True


training_config:

  training_mode: ["masking", "student_teacher"]

  num_mini_epochs: 48
  samples_per_mini_epoch: 4096
  shuffle: True

  start_date: 1979-01-01T00:00
  end_date: 2022-12-31T00:00

  time_window_step: 06:00:00
  time_window_len: 06:00:00

  window_offset_prediction: 0

  learning_rate_scheduling:
    lr_start: 1e-6
    lr_max: 5e-5
    lr_final_decay: 1e-6
    lr_final: 0.0
    num_steps_warmup: 512
    num_steps_cooldown: 512
    policy_warmup: "cosine"
    policy_decay: "constant"
    policy_cooldown: "linear"
    parallel_scaling_policy: "sqrt"

  optimizer:
    grad_clip: 1.0
    weight_decay: 0.1
    log_grad_norms: False
    adamw:
      beta1: 0.975
      beta2: 0.9875
      eps: 2e-08

  losses: {
    "physical": {
      type: LossPhysical,
      weight: 0.5,
      loss_fcts: {
        "mse": {
          weight: 1.0,
          # complement: predict masked tokens from unmasked tokens
          # Standard masked token modeling paradigm
          target_source_correspondence: { 0 : { 0 : "complement"} },
        },
      },
      target_and_aux_calc: "Physical",
    },
    "student-teacher": {
      type: LossLatentSSLStudentTeacher,
      weight: 0.5,
      loss_fcts: {
        "JEPA": {
          "weight": 8, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
          "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
          "dropout_rate": 0.1,
          target_source_correspondence: { 1 : { 1 : "subset"} },
        },
      },
      target_and_aux_calc: { "EMATeacher" :
        { ema_ramp_up_ratio : 0.09,
          ema_halflife_in_thousands: 1e-3,
        }
      }
    },
  }

  # Target views: index 0 for physical, index 1 for teacher
  target_input: {
    "physical_target": {
      masking_strategy: "random",
      num_samples: 1,
      masking_strategy_config: {
        rate: 0.4,                         # 40% kept - same as cropping config
        rate_sampling: False,
      },
    },
    "teacher_random": {
      masking_strategy: "random",
      num_samples: 1,
      masking_strategy_config: {
        rate: 0.6,                         # Teacher sees 60% (larger context)
        rate_sampling: False,
      },
    },
  }

  # Source views: index 0 for physical (complement), index 1 for student (subset)
  model_input: {
    "physical_source": {
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config: {
        rate: 0.4,                         # 40% kept - same as cropping config
        rate_sampling: False,
        diffusion_rn: True,                # Enable diffusion random noise
      },
      relationship: "complement",          # Predict masked tokens from unmasked
    },
    "student_random": {
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config: {
        rate: 0.4,                         # Student sees 40% - same as cropping config
        rate_sampling: False,
      },
      relationship: "subset",              # Student view is subset of teacher view
    },
  }

  forecast:
    time_step: 06:00:00
    num_steps: 0
    policy: null


validation_config:

  samples_per_mini_epoch: 256
  shuffle: False

  start_date: 2023-10-01T00:00
  end_date: 2023-12-31T00:00

  validate_with_ema:
    enabled: True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3

  validate_before_training: 0

  output: {
    num_samples: 0,
    normalized_samples: True,
    streams: null,
  }

  losses: {
    "physical": {
      type: LossPhysical,
      weight: 1.0,
      loss_fcts: {
        "mse": {
          weight: 1.0,
        },
      },
    },
  }

Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.

No response

Metadata

Labels

bug: Something isn't working
data: Anything related to the datasets used in the project
data:io: Issues with the zarr output produced during inference/validation
