Skip to content

Commit

Permalink
Merge pull request #67 from yurujaja/regression_fix
Browse files Browse the repository at this point in the history
Minor regression code and command fixes
  • Loading branch information
RituYadav92 authored Sep 27, 2024
2 parents b4d5663 + b2db422 commit 774d4c4
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
encoder=prithvi \
decoder=reg_upernet \
preprocessing=reg_default \
criterion=cross_entropy \
criterion=mse \
task=regression
```
To use SatlasNet encoder, the `configs/encoder/satlasnet_si.yaml` is required.
Expand All @@ -230,7 +230,7 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
encoder=prithvi \
decoder=reg_upernet_mt_ltae \
preprocessing=reg_default \
criterion=cross_entropy \
criterion=mse \
task=regression
```

Expand Down
4 changes: 2 additions & 2 deletions configs/dataset/biomassters.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
_target_: pangaea.datasets.biomassters.BioMassters
dataset_name: BioMassters
root_path: ./data/Biomassters
# download_url: #https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars/resolve/main/hls_burn_scars.tar.gz?download=true
download_url:
auto_download: False
img_size: 256
temporal: 6 #6 (summer month use if multi_temp is 1)
temp: 6 #6 (select month to use if single temporal (multi_temp : 1))
multi_temporal: 12
multi_modal: True

Expand Down
7 changes: 2 additions & 5 deletions pangaea/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):
imgs_s1, imgs_s2, mask = [], [], []
if multi_temporal==1:
month_list = [temp]
else:
month_list = list(range(12))
else:
month_list = list(range(int(multi_temporal)))

for month in month_list:

Expand Down Expand Up @@ -47,10 +47,8 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

imgs_s1 = np.stack(imgs_s1, axis=1)
imgs_s2 = np.stack(imgs_s2, axis=1)

return imgs_s1, imgs_s2, mask

# @DATASET_REGISTRY.register()
class BioMassters(GeoFMDataset):
def __init__(
self,
Expand Down Expand Up @@ -122,7 +120,6 @@ def __init__(
data_max=data_max,
download_url=download_url,
auto_download=auto_download,
# temp = temp,
)

self.root_path = root_path
Expand Down
1 change: 1 addition & 0 deletions pangaea/decoders/upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def __init__(
finetune=finetune,
)

self.model_name = "Reg_UPerNet"
if not self.finetune:
for param in self.encoder.parameters():
param.requires_grad = False
Expand Down
3 changes: 2 additions & 1 deletion pangaea/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
], f"Invalid precision {precision}, use 'fp32', 'fp16' or 'bfp16'."
self.enable_mixed_precision = precision != "fp32"
self.precision = torch.float16 if (precision == "fp16") else torch.bfloat16
self.scaler = torch.GradScaler("cuda", enabled=self.enable_mixed_precision)
# self.scaler = torch.GradScaler("cuda", enabled=self.enable_mixed_precision)
self.scaler = torch.cuda.amp.GradScaler("cuda", enabled=self.enable_mixed_precision)

self.start_epoch = 0

Expand Down

0 comments on commit 774d4c4

Please sign in to comment.