Labels
bug (Something isn't working), data (Anything related to the datasets used in the project), data:io (Issues with the zarr output produced during inference/validation)
Description
What happened?
When I trained a model using the physical JEPA config and then tried to run inference, I got the following error:
Traceback (most recent call last):
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/run_train.py", line 79, in inference_from_args
trainer.inference(cf, devices, args.from_run_id, args.mini_epoch)
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 221, in inference
self.validate(0, self.test_cfg, self.batch_size_test_per_gpu)
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 584, in validate
write_output(
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/utils/validation_io.py", line 63, in write_output
for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: zip() argument 2 is shorter than argument 1
[3] > /users/walmikae/weathergen/WeatherGenerator/src/weathergen/utils/validation_io.py(63)write_output()
-> for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
The issue is that preds now contains two tensors, one for the latent representation and one for the physical values (the physical one is what should be used in this case), while targets only holds a single entry, so the strict zip fails. I tried to work around this manually by indexing into the correct tensor, but that triggered different errors in the evaluation package's IO.
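To illustrate the shape of the problem, here is a minimal, self-contained sketch; the tensor names, shapes, and list layout are assumptions for illustration, not the actual WeatherGenerator objects:

import torch

# Hypothetical stand-ins: with the physical JEPA config, preds carries
# both a physical and a latent tensor, while targets only carries the
# physical one.
physical_pred = torch.zeros(8, 256)
latent_pred = torch.zeros(8, 2048)
physical_target = torch.zeros(8, 256)

preds = [physical_pred, latent_pred]   # two entries (index 0 assumed physical)
targets = [physical_target]            # one entry

try:
    # Mirrors the failing line in validation_io.write_output:
    for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
        pass
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1

# The naive workaround I tried - index out only the physical prediction.
# The zip then succeeds, but the evaluation package IO fails further
# downstream:
for i_batch, (pred, target) in enumerate(zip(preds[:1], targets, strict=True)):
    assert pred.shape == target.shape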
What are the steps to reproduce the bug?
Branch: develop
HPC: Santis
run_id: j8faj4zo
command: uv run --offline inference --from_run_id=j8faj4zo --samples=32
Config:
streams_directory: "./config/streams/era5_1deg/"
streams: ???
general:
  istep: 0
  rank: ???
  world_size: ???
  multiprocessing_method: "fork"
  desc: ""
  run_id: ???
  run_history: []
train_log_freq:
  terminal: 10
  metrics: 20
  checkpoint: 250
data_loading:
  num_workers: 12
  rng_seed: ???
  memory_pinning: True
training_config:
  training_mode: ["masking", "student_teacher"]
  num_mini_epochs: 48
  samples_per_mini_epoch: 4096
  shuffle: True
  start_date: 1979-01-01T00:00
  end_date: 2022-12-31T00:00
  time_window_step: 06:00:00
  time_window_len: 06:00:00
  window_offset_prediction: 0
  learning_rate_scheduling:
    lr_start: 1e-6
    lr_max: 5e-5
    lr_final_decay: 1e-6
    lr_final: 0.0
    num_steps_warmup: 512
    num_steps_cooldown: 512
    policy_warmup: "cosine"
    policy_decay: "constant"
    policy_cooldown: "linear"
    parallel_scaling_policy: "sqrt"
  optimizer:
    grad_clip: 1.0
    weight_decay: 0.1
    log_grad_norms: False
    adamw:
      beta1: 0.975
      beta2: 0.9875
      eps: 2e-08
  losses: {
    "physical": {
      type: LossPhysical,
      weight: 0.5,
      loss_fcts: {
        "mse": {
          weight: 1.0,
          # complement: predict masked tokens from unmasked tokens
          # (standard masked-token-modeling paradigm)
          target_source_correspondence: { 0: { 0: "complement" } },
        },
      },
      target_and_aux_calc: "Physical",
    },
    "student-teacher": {
      type: LossLatentSSLStudentTeacher,
      weight: 0.5,
      loss_fcts: {
        "JEPA": {
          "weight": 8, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
          "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
          "dropout_rate": 0.1,
          target_source_correspondence: { 1: { 1: "subset" } },
        },
      },
      target_and_aux_calc: {
        "EMATeacher": {
          ema_ramp_up_ratio: 0.09,
          ema_halflife_in_thousands: 1e-3,
        }
      }
    },
  }
  # Target views: index 0 for physical, index 1 for teacher
  target_input: {
    "physical_target": {
      masking_strategy: "random",
      num_samples: 1,
      masking_strategy_config: {
        rate: 0.4, # 40% kept - same as cropping config
        rate_sampling: False,
      },
    },
    "teacher_random": {
      masking_strategy: "random",
      num_samples: 1,
      masking_strategy_config: {
        rate: 0.6, # teacher sees 60% (larger context)
        rate_sampling: False,
      },
    },
  }
  # Source views: index 0 for physical (complement), index 1 for student (subset)
  model_input: {
    "physical_source": {
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config: {
        rate: 0.4, # 40% kept - same as cropping config
        rate_sampling: False,
        diffusion_rn: True, # enable diffusion random noise
      },
      relationship: "complement", # predict masked tokens from unmasked
    },
    "student_random": {
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config: {
        rate: 0.4, # student sees 40% - same as cropping config
        rate_sampling: False,
      },
      relationship: "subset", # student view is subset of teacher view
    },
  }
  forecast:
    time_step: 06:00:00
    num_steps: 0
    policy: null
validation_config:
  samples_per_mini_epoch: 256
  shuffle: False
  start_date: 2023-10-01T00:00
  end_date: 2023-12-31T00:00
  validate_with_ema:
    enabled: True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3
  validate_before_training: 0
  output: {
    num_samples: 0,
    normalized_samples: True,
    streams: null,
  }
  losses: {
    "physical": {
      type: LossPhysical,
      weight: 1.0,
      loss_fcts: {
        "mse": {
          weight: 1.0,
        },
      },
    },
  }
Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.
No response