
Commit e9dae8a

Merge pull request #25 from spaceml-org/vertexai
New Limb Masking + Nonlinear DL + F10.7 Experiment
2 parents 8c6ebd7 + 41be0b5, commit e9dae8a

35 files changed: +49,749 −547 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ wandb
 output
 outputs
 *.tar
+notebooks/imgs_for_google_emb/*
 notebooks/camera_ready/*/lightning_logs

 # aux directories

experiments/for_google_emb.yaml

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# default.yaml

# MODEL SUMMARY
#   | Name        | Type                   | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M     Trainable params
# 4.7 M     Non-trainable params
# 333 M     Total params
# 1,335.838 Total estimated model params size (MB)

# general
log_level: 'DEBUG'
experiment:
  name: "mae-helioprojected-2011" # generate random name in wandb
  project: "sdofm"
  task: "pretrain" # options: train, evaluate (not implemented)
  model: "samae"
  backbone_checkpoint: null
  seed: 0
  disable_cuda: false
  resuming: false
  wandb:
    enable: true
    entity: "fdlx"
    group: "sdofm-phase1"
    job_type: "pretrain"
    tags: []
    notes: ""
    output_directory: "wandb_output"
    log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epoches)
  gcp_storage: # this will checkpoint all epoches, perhaps clean up this config
    enabled: true
    bucket: "sdofm-checkpoints"
    fold: null
  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
  checkpoint: null # this is the wandb run_id of the checkpoint to load
  device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
  precision: '16' #-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
  log_n_batches: 1000 # log every n training batches
  save_results: true # save full results to file and wandb
  accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
  profiler: null #'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
  distributed:
    enabled: true
    world_size: "auto" # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
  log_every_n_steps: 50

# dataset configuration
data:
  min_date: '2015-02-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
  max_date: '2015-05-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
  month_splits: # non selected months will form training set
    train: [1] #,2,3,4,5,6,7,8,9,10]
    val: [2]
    test: [3,4]
    holdout: []
  num_workers: 16 # set appropriately for your machine
  prefetch_factor: 3
  num_frames: 1 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE
  drop_frame_dim: false
  # output_directory: "wandb_output"
  sdoml:
    base_directory: "/mnt/sdoml"
    sub_directory:
      hmi: "HMI.zarr"
      aia: "AIA.zarr"
      eve: "EVE_legacy.zarr"
      cache: "cache"
    components: null # null for select all magnetic components ["Bx", "By", "Bz"]
    wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
    frequency: '12min' # smallest is 12min
    mask_with_hmi_threshold: null # None/null for no mask, float for threshold
  feature_engineering:
    enabled: true
    dclass: 'HelioProjected'

# model configurations
model:
  # PRETRAINERS
  mae:
    img_size: 512
    patch_size: 16
    num_frames: 1
    tubelet_size: 1
    in_chans: 9
    embed_dim: 128
    depth: 24
    num_heads: 16
    decoder_embed_dim: 512
    decoder_depth: 8
    decoder_num_heads: 16
    mlp_ratio: 4.0
    norm_layer: 'LayerNorm'
    norm_pix_loss: False
    masking_ratio: 0.75
  samae:
    # uses all parameters as in mae plus these
    masking_type: "solar_aware" # 'random' or 'solar_aware'
    active_region_mu_degs: 15.73
    active_region_std_degs: 6.14
    active_region_scale: 1.0
    active_region_abs_lon_max_degs: 60
    active_region_abs_lat_max_degs: 60
  nvae:
    use_se: true
    res_dist: true
    num_x_bits: 8
    num_latent_scales: 3 # 5
    num_groups_per_scale: 1 # 16
    num_latent_per_group: 1 # 10
    ada_groups: true
    min_groups_per_scale: 1
    num_channels_enc: 30
    num_channels_dec: 30
    num_preprocess_blocks: 2 # 1
    num_preprocess_cells: 2
    num_cell_per_cond_enc: 2
    num_postprocess_blocks: 2 # 1
    num_postprocess_cells: 2
    num_cell_per_cond_dec: 2
    num_mixture_dec: 1
    num_nf: 2
    kl_anneal_portion: 0.3
    kl_const_portion: 0.0001
    kl_const_coeff: 0.0001
    # learning_rate: 1e-2
    # weight_decay: 3e-4
    weight_decay_norm_anneal: true
    weight_decay_norm_init: 1.
    weight_decay_norm: 1e-2

  # FINE-TUNERS
  degragation:
    num_neck_filters: 32
    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
    loss: "mse" # options: "mse", "heteroscedastic"
    freeze_encoder: true

  # ML optimization arguments:
  opt:
    loss: "mse" # options: "mae", "mse", "mape"
    scheduler: "constant" #other options: "cosine", "plateau", "exp"
    scheduler_warmup: 0
    batch_size: 1
    learning_rate: 0.0001
    weight_decay: 3e-4 # 0.0
    optimiser: "adam"
    epochs: 2
    patience: 2

# hydra configuration
# hydra:
# sweeper:
# params:
# model.mae.embed_dim: 256, 512
# model.mae.masking_ratio: 0.5, 0.75
# model.samae.masking_type: "random", "solar_aware"

hydra:
  mode: RUN
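
For orientation, the model.mae block above implies the usual ViT-style patch bookkeeping. A minimal sketch of the token counts, assuming square patching and uniform random masking as in the reference MAE (the repository's MaskedAutoencoderViT3D may differ in detail):

# Illustrative token arithmetic for the mae block above (assumption:
# standard ViT patching and uniform random masking).
img_size, patch_size = 512, 16
num_frames, tubelet_size = 1, 1
masking_ratio = 0.75

patches_per_side = img_size // patch_size                        # 32
tokens = (patches_per_side ** 2) * (num_frames // tubelet_size)  # 1024
visible = int(tokens * (1 - masking_ratio))                      # 256
print(tokens, visible)                                           # 1024 256

At this ratio the encoder attends over only 256 of the 1024 tokens per 9-channel frame; the lighter decoder (decoder_embed_dim: 512, decoder_depth: 8) is trained to reconstruct the 768 masked ones.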

experiments/pretrain_32.2M_mae_HP_r512_e128_p16.yaml

Lines changed: 6 additions & 2 deletions
@@ -37,7 +37,7 @@ experiment:
   evaluate: false # skip training and only evaluate (requires checkpoint to be set)
   checkpoint: null # this is the wandb run_id of the checkpoint to load
   device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
-  precision: '16' #-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
+  precision: '16-mixed' #-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
   log_n_batches: 1000 # log every n training batches
   save_results: true # save full results to file and wandb
   accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
@@ -73,6 +73,7 @@ data:
     ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
     frequency: '12min' # smallest is 12min
     mask_with_hmi_threshold: null # None/null for no mask, float for threshold
+    apply_mask: false
   feature_engineering:
     enabled: true
     dclass: 'HelioProjected'
@@ -148,9 +149,12 @@ model:
     learning_rate: 0.0001
     weight_decay: 3e-4 # 0.0
     optimiser: "adam"
-    epochs: 2
+    epochs: 100
     patience: 2

+  misc:
+    limb_mask: false
+
 # hydra configuration
 # hydra:
 # sweeper:
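
The apply_mask and limb_mask switches added above gate the new limb masking referenced in the PR title; in this file both default to false, and the config that follows enables them. The masking code itself is not part of this diff, so purely as an illustration of the idea, and not the repository's implementation, an off-limb mask for an SDOML-style 512x512 frame can be built from an assumed disk radius in pixels (the real radius should come from the instrument metadata, e.g. R_SUN):

import numpy as np

# Illustrative only -- not the repository's implementation. Zero out pixels
# beyond an assumed solar-disk radius; ~200 px is the right order for SDOML
# frames downsampled to 512x512, but the true value belongs in the metadata.
def apply_limb_mask(frame: np.ndarray, radius_px: float = 200.0) -> np.ndarray:
    h, w = frame.shape[-2:]
    yy, xx = np.mgrid[:h, :w]
    r2 = (yy - (h - 1) / 2) ** 2 + (xx - (w - 1) / 2) ** 2
    return frame * (r2 <= radius_px ** 2)  # mask broadcasts over channel dims

frame = np.random.rand(9, 512, 512)  # 9 channels, matching in_chans: 9
masked = apply_limb_mask(frame)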
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
# default.yaml

# MODEL SUMMARY
#   | Name        | Type                   | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M     Trainable params
# 4.7 M     Non-trainable params
# 333 M     Total params
# 1,335.838 Total estimated model params size (MB)

# general
log_level: 'DEBUG'
experiment:
  name: "mae-log-helioprojected-limbmasked-2011subset-r512-e128-p16" # generate random name in wandb
  project: "sdofm"
  task: "pretrain" # options: train, evaluate (not implemented)
  model: "mae"
  backbone_checkpoint: null
  seed: 0
  disable_cuda: false
  resuming: false
  wandb:
    enable: true
    entity: "fdlx"
    group: "sdofm-phase1"
    job_type: "pretrain"
    tags: []
    notes: ""
    output_directory: "wandb_output"
    log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epoches)
  gcp_storage: # this will checkpoint all epoches, perhaps clean up this config
    enabled: true
    bucket: "sdofm-checkpoints"
    fold: null
  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
  checkpoint: null # this is the wandb run_id of the checkpoint to load
  device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
  precision: '16-mixed' #-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
  log_n_batches: 1000 # log every n training batches
  save_results: true # save full results to file and wandb
  accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
  profiler: null #'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
  distributed:
    enabled: true
    world_size: "auto" # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
  log_every_n_steps: 50

# dataset configuration
data:
  min_date: '2011-01-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
  max_date: '2011-03-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
  month_splits: # non selected months will form training set
    train: [1] #,2,3,4,5,6,7,8,9,10]
    val: [2]
    test: [3]
    holdout: []
  num_workers: 16 # set appropriately for your machine
  prefetch_factor: 3
  num_frames: 1 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE
  drop_frame_dim: false
  # output_directory: "wandb_output"
  sdoml:
    base_directory: "/mnt/sdoml"
    sub_directory:
      hmi: "HMI.zarr"
      aia: "AIA.zarr"
      eve: "EVE_legacy.zarr"
      cache: "cache"
    components: null # null for select all magnetic components ["Bx", "By", "Bz"]
    wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
    frequency: '12min' # smallest is 12min
    mask_with_hmi_threshold: null # None/null for no mask, float for threshold
    apply_mask: true
  feature_engineering:
    enabled: true
    dclass: 'Log'

# model configurations
model:
  # PRETRAINERS
  mae:
    img_size: 512
    patch_size: 16
    num_frames: 1
    tubelet_size: 1
    in_chans: 9
    embed_dim: 128
    depth: 24
    num_heads: 16
    decoder_embed_dim: 512
    decoder_depth: 8
    decoder_num_heads: 16
    mlp_ratio: 4.0
    norm_layer: 'LayerNorm'
    norm_pix_loss: False
    masking_ratio: 0.75
  samae:
    # uses all parameters as in mae plus these
    masking_type: "solar_aware" # 'random' or 'solar_aware'
    active_region_mu_degs: 15.73
    active_region_std_degs: 6.14
    active_region_scale: 1.0
    active_region_abs_lon_max_degs: 60
    active_region_abs_lat_max_degs: 60
  nvae:
    use_se: true
    res_dist: true
    num_x_bits: 8
    num_latent_scales: 3 # 5
    num_groups_per_scale: 1 # 16
    num_latent_per_group: 1 # 10
    ada_groups: true
    min_groups_per_scale: 1
    num_channels_enc: 30
    num_channels_dec: 30
    num_preprocess_blocks: 2 # 1
    num_preprocess_cells: 2
    num_cell_per_cond_enc: 2
    num_postprocess_blocks: 2 # 1
    num_postprocess_cells: 2
    num_cell_per_cond_dec: 2
    num_mixture_dec: 1
    num_nf: 2
    kl_anneal_portion: 0.3
    kl_const_portion: 0.0001
    kl_const_coeff: 0.0001
    # learning_rate: 1e-2
    # weight_decay: 3e-4
    weight_decay_norm_anneal: true
    weight_decay_norm_init: 1.
    weight_decay_norm: 1e-2

  # FINE-TUNERS
  degragation:
    num_neck_filters: 32
    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
    loss: "mse" # options: "mse", "heteroscedastic"
    freeze_encoder: true

  # ML optimization arguments:
  opt:
    loss: "mse" # options: "mae", "mse", "mape"
    scheduler: "constant" #other options: "cosine", "plateau", "exp"
    scheduler_warmup: 0
    batch_size: 1
    learning_rate: 0.0001
    weight_decay: 3e-4 # 0.0
    optimiser: "adam"
    epochs: 2
    patience: 2

  misc:
    limb_mask: true

# hydra configuration
# hydra:
# sweeper:
# params:
# model.mae.embed_dim: 256, 512
# model.mae.masking_ratio: 0.5, 0.75
# model.samae.masking_type: "random", "solar_aware"

hydra:
  mode: RUN
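
Both new experiment files share this schema, so they can be inspected or overridden the same way. A minimal sketch using OmegaConf (the config library underneath Hydra), pointed at the experiments/for_google_emb.yaml path that appears earlier in this commit; the dotted override simply mirrors the commented sweeper params and is not a documented CLI:

from omegaconf import OmegaConf

# Load one of the new experiment configs and read a few of the fields above.
cfg = OmegaConf.load("experiments/for_google_emb.yaml")
print(cfg.experiment.model)         # "samae"
print(cfg.model.mae.masking_ratio)  # 0.75

# Override a key with a dotted path, the same addressing used by the
# commented hydra sweeper (e.g. model.mae.embed_dim: 256, 512).
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["model.mae.embed_dim=256"]))
print(cfg.model.mae.embed_dim)      # 256

In the actual training entrypoint Hydra composes and resolves the config itself; the snippet only shows how the keys discussed above are addressed.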

0 commit comments
