-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added solar aware MAE, solar limb masking in datamodule, and adaption…
… of benchmarking tools
- Loading branch information
1 parent
47ce577
commit 77fd6ad
Showing
11 changed files
with
1,111 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The autoreload extension is already loaded. To reload it, use:\n", | ||
" %reload_ext autoreload\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"import pytorch_lightning as pl\n", | ||
"import torch\n", | ||
"import wandb\n", | ||
"from sdofm import utils\n", | ||
"from sdofm.datasets import SDOMLDataModule, DimmedSDOMLDataModule\n", | ||
"from sdofm.pretraining import SAMAE" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import omegaconf\n", | ||
"cfg = omegaconf.OmegaConf.load(\"../experiments/pretrain_tiny_mae.yaml\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n", | ||
"[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n", | ||
"[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"data_module = SDOMLDataModule(\n", | ||
" hmi_path=None,\n", | ||
" aia_path=os.path.join(\n", | ||
" cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia\n", | ||
" ),\n", | ||
" eve_path=None,\n", | ||
" components=cfg.data.sdoml.components,\n", | ||
" wavelengths=cfg.data.sdoml.wavelengths,\n", | ||
" ions=cfg.data.sdoml.ions,\n", | ||
" frequency=cfg.data.sdoml.frequency,\n", | ||
" batch_size=cfg.model.opt.batch_size,\n", | ||
" num_workers=cfg.data.num_workers,\n", | ||
" val_months=cfg.data.month_splits.val,\n", | ||
" test_months=cfg.data.month_splits.test,\n", | ||
" holdout_months=cfg.data.month_splits.holdout,\n", | ||
" cache_dir=os.path.join(\n", | ||
" cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache\n", | ||
" ),\n", | ||
")\n", | ||
"data_module.setup()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = SAMAE(\n", | ||
" **cfg.model.mae,\n", | ||
" **cfg.model.samae,\n", | ||
" optimiser=cfg.model.opt.optimiser,\n", | ||
" lr=cfg.model.opt.learning_rate,\n", | ||
" weight_decay=cfg.model.opt.weight_decay,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"GPU available: True (cuda), used: True\n", | ||
"TPU available: False, using: 0 TPU cores\n", | ||
"IPU available: False, using: 0 IPUs\n", | ||
"HPU available: False, using: 0 HPUs\n", | ||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", | ||
"\n", | ||
" | Name | Type | Params\n", | ||
"-----------------------------------------------------------------\n", | ||
"0 | autoencoder | SolarAwareMaskedAutoencoderViT3D | 3.3 M \n", | ||
"-----------------------------------------------------------------\n", | ||
"3.0 M Trainable params\n", | ||
"262 K Non-trainable params\n", | ||
"3.3 M Total params\n", | ||
"13.005 Total estimated model params size (MB)\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "2e5a818f36c54a2ab890c11acfd02d01", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Sanity Checking: | | 0/? [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "37f9ef68d4334ae098f777aa1c20c99c", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Training: | | 0/? [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/envs/sdofm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"trainer = pl.Trainer(\n", | ||
" devices=1, accelerator=cfg.experiment.accelerator, max_epochs=cfg.model.opt.epochs\n", | ||
")\n", | ||
"trainer.fit(model=model, datamodule=data_module)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "sdofm", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.