fixed dataset downloads (#70)
* fixed downloads for all datasets

* fixed sen1floods11

* Update download signature

---------

Co-authored-by: gle-bellier <[email protected]>
VMarsocci and gle-bellier authored Sep 27, 2024
1 parent 7aa6e09 commit b4d5663
Showing 12 changed files with 41 additions and 92 deletions.
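
Note: the recurring change below is the signature of each dataset's download hook. A minimal before/after sketch of the pattern (dict access vs. instance attributes, as in the hunks that follow):

```python
# Before: a staticmethod that read paths out of a config dict.
@staticmethod
def download(dataset_config: dict, silent=False):
    root_path = dataset_config["root_path"]

# After: the first parameter is the dataset instance itself, so the
# hook reads self.root_path / self.download_url instead of a dict.
@staticmethod
def download(self, silent=False):
    root_path = self.root_path
```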
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
@@ -234,7 +234,7 @@ We have designed the repo to allow for using your own datasets with minimal effort
         }

     @staticmethod
-    def download(dataset_config, silent=False):
+    def download(self, silent=False):
         # Implement if your dataset requires downloading
         pass
 ```
18 changes: 9 additions & 9 deletions pangaea/datasets/ai4smallfarms.py
@@ -82,13 +82,6 @@ def __init__(
             auto_download=auto_download,
         )

-        self.root_path = pathlib.Path(root_path)
-        self.split = split
-        self.image_dir = self.root_path.joinpath(f"sentinel-2-asia/{split}/images")
-        self.mask_dir = self.root_path.joinpath(f"sentinel-2-asia/{split}/masks")
-        self.image_list = sorted(glob(str(self.image_dir.joinpath("*.tif"))))
-        self.mask_list = sorted(glob(str(self.mask_dir.joinpath("*.tif"))))
-
         self.data_mean = data_mean
         self.data_std = data_std
         self.data_min = data_min
@@ -101,6 +94,13 @@ def __init__(
         self.download_url = download_url
         self.auto_download = auto_download

+        self.root_path = pathlib.Path(root_path)
+        self.split = split
+        self.image_dir = self.root_path.joinpath(f"sentinel-2-asia/{split}/images")
+        self.mask_dir = self.root_path.joinpath(f"sentinel-2-asia/{split}/masks")
+        self.image_list = sorted(glob(str(self.image_dir.joinpath("*.tif"))))
+        self.mask_list = sorted(glob(str(self.mask_dir.joinpath("*.tif"))))
+
     def __len__(self):
         return len(self.image_list)

@@ -133,8 +133,8 @@ def __getitem__(self, index):
         }

     @staticmethod
-    def download(dataset_config: dict, silent=False):
-        root_path = pathlib.Path(dataset_config["root_path"])
+    def download(self, silent=False):
+        root_path = pathlib.Path(self.root_path)

         # Create the root directory if it does not exist
         if not root_path.exists():
5 changes: 4 additions & 1 deletion pangaea/datasets/base.py
@@ -1,6 +1,6 @@
 import torch
 from torch.utils.data import Dataset
-
+import os

 class GeoFMDataset(Dataset):
     """Base class for all datasets."""
@@ -72,6 +72,9 @@ def __init__(
         self.download_url = download_url
         self.auto_download = auto_download

+        if not os.path.exists(self.root_path):
+            self.download(self)
+
     def __len__(self) -> int:
         """Returns the length of the dataset.
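
Note: the subclasses keep their @staticmethod decorators while renaming the first parameter to self, which is why this base constructor calls self.download(self) rather than self.download(). A self-contained sketch of how that call resolves (illustrative classes, not the repo's real ones):

```python
import os

class Base:
    def __init__(self, root_path: str):
        self.root_path = root_path
        if not os.path.exists(self.root_path):
            # download is a staticmethod on the subclass, so attribute
            # lookup yields a plain function and the instance has to be
            # passed explicitly as its first argument.
            self.download(self)

class Child(Base):
    @staticmethod
    def download(self, silent=False):
        # 'self' is an ordinary parameter name here; it holds the
        # instance only because the caller passed it in.
        print(f"downloading into {self.root_path}")

Child("/path/that/does/not/exist")  # triggers the download hook
```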
9 changes: 1 addition & 8 deletions pangaea/datasets/biomassters.py
@@ -176,12 +176,5 @@ def __getitem__(self, index):
         }

-    # @staticmethod
-    # def get_splits(dataset_config):
-    #     dataset_train = BioMassters(cfg=dataset_config, split='train')
-    #     dataset_val = BioMassters(cfg=dataset_config, split='val')
-    #     dataset_test = BioMassters(cfg=dataset_config, split='test')
-    #     return dataset_train, dataset_val, dataset_test
-
     # @staticmethod
-    # def download(dataset_config:dict, silent=False):
+    # def download(self, silent=False):
     #     pass
15 changes: 4 additions & 11 deletions pangaea/datasets/croptypemapping.py
@@ -222,24 +222,17 @@ def pad_or_crop(self, tensor):
         # else:
         #     tensor = tensor[..., :self.grid_size]
         return tensor

-    # @staticmethod
-    # def get_splits(dataset_config):
-    #     dataset_train = CropTypeMappingSouthSudan(cfg=dataset_config, split="train")
-    #     dataset_val = CropTypeMappingSouthSudan(cfg=dataset_config, split="val")
-    #     dataset_test = CropTypeMappingSouthSudan(cfg=dataset_config, split="test")
-    #     return dataset_train, dataset_val, dataset_test
-
     @staticmethod
-    def download(dataset_config: dict, silent=False):
-        if os.path.exists(dataset_config["root_path"]):
+    def download(self, silent=False):
+        if os.path.exists(self.root_path):
             if not silent:
                 print("CropTypeMapping Dataset folder exists, skipping downloading dataset.")
             return

-        output_path = dataset_config["root_path"]
+        output_path = self.root_path
         os.makedirs(output_path, exist_ok=True)
-        url = dataset_config["download_url"]
+        url = self.download_url

         temp_file = os.path.join(output_path, "archive.tar.gz")

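
Note: the remainder of CropTypeMappingSouthSudan.download is collapsed above; judging from the temp_file setup, it fetches a tar.gz archive into root_path and unpacks it. A hedged sketch of that pattern (fetch_archive is a hypothetical helper, not the repo's code):

```python
import os
import tarfile
import urllib.request

def fetch_archive(url: str, output_path: str) -> None:
    # Download the archive into the dataset root, unpack it, then
    # remove the temporary file.
    os.makedirs(output_path, exist_ok=True)
    temp_file = os.path.join(output_path, "archive.tar.gz")
    urllib.request.urlretrieve(url, temp_file)
    with tarfile.open(temp_file, "r:gz") as tar:
        tar.extractall(output_path)
    os.remove(temp_file)
```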
2 changes: 0 additions & 2 deletions pangaea/datasets/fivebillionpixels.py
@@ -20,9 +20,7 @@
 from pangaea.datasets.utils import DownloadProgressBar

 from pangaea.datasets.base import GeoFMDataset
-# from utils.registry import DATASET_REGISTRY

-# @DATASET_REGISTRY.register()
 class FiveBillionPixels(GeoFMDataset):
     def __init__(
         self,
14 changes: 5 additions & 9 deletions pangaea/datasets/hlsburnscars.py
@@ -110,14 +110,13 @@ def __init__(
         self.download_url = download_url
         self.auto_download = auto_download

-        # ISSUE
         self.split_mapping = {'train': 'training', 'val': 'validation', 'test': 'validation'}

         all_files = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*merged.tif')))
         all_targets = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*mask.tif')))

         if self.split != "test":
-            split_indices = self.get_stratified_train_val_split(all_files)
+            split_indices = self.get_train_val_split(all_files)
             if self.split == "train":
                 indices = split_indices["train"]
             else:
@@ -130,16 +129,13 @@ def __init__(


     @staticmethod
-    def get_stratified_train_val_split(all_files, split) -> Tuple[Sequence[int], Sequence[int]]:
+    def get_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[int]]:

         # Fixed stratified sample to split data into train/val.
         # This keeps 90% of datapoints belonging to an individual event in the training set and puts the remaining 10% in the validation set.
-        # disaster_names = list(
-        #     map(lambda path: pathlib.Path(path).name.split("_")[0], all_files))
         train_idxs, val_idxs = train_test_split(np.arange(len(all_files)),
                                                 test_size=0.1,
                                                 random_state=23,
-                                                # stratify=disaster_names
                                                 )
         return {"train": train_idxs, "val": val_idxs}

@@ -185,9 +181,9 @@ def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[int]]:
         return {"train": train_idxs, "val": val_idxs}

     @staticmethod
-    def download(dataset_config:dict, silent=False):
-        output_path = pathlib.Path(dataset_config["root_path"])
-        url = dataset_config["download_url"]
+    def download(self, silent=False):
+        output_path = pathlib.Path(self.root_path)
+        url = self.download_url

         try:
             os.makedirs(output_path, exist_ok=False)
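
Note: with the stratify argument commented out, the renamed get_train_val_split reduces to a plain 90/10 random split with a fixed seed. An equivalent standalone snippet (assuming scikit-learn and NumPy, as the hunk does; the file list is a stand-in):

```python
import numpy as np
from sklearn.model_selection import train_test_split

all_files = [f"event_{i}.tif" for i in range(100)]  # stand-in file list
train_idxs, val_idxs = train_test_split(
    np.arange(len(all_files)), test_size=0.1, random_state=23
)
split_indices = {"train": train_idxs, "val": val_idxs}
print(len(split_indices["train"]), len(split_indices["val"]))  # 90 10
```

Since stratification is disabled, the 90/10 proportion applies to the file list as a whole rather than per event; the surviving comment about individual events describes the stratified variant.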
15 changes: 4 additions & 11 deletions pangaea/datasets/mados.py
@@ -181,9 +181,9 @@ def get_band(path):
         return int(path.split('_')[-2])

     @staticmethod
-    def download(dataset_config: dict, silent=False):
-        output_path = pathlib.Path(dataset_config["root_path"])
-        url = dataset_config["download_url"]
+    def download(self, silent=False):
+        output_path = pathlib.Path(self.root_path)
+        url = self.download_url

         existing_dirs = list(output_path.glob("Scene_*"))
         if existing_dirs:
@@ -219,11 +219,4 @@ def download(dataset_config: dict, silent=False):
                 zip_ref.extractall(output_path, members)
             print("done.")

-            (output_path / temp_file_name).unlink()
-
-    # @staticmethod
-    # def get_splits(dataset_config):
-    #     dataset_train = MADOS(cfg=dataset_config, split="train")
-    #     dataset_val = MADOS(cfg=dataset_config, split="val")
-    #     dataset_test = MADOS(cfg=dataset_config, split="test")
-    #     return dataset_train, dataset_val, dataset_test
+            (output_path / temp_file_name).unlink()
19 changes: 6 additions & 13 deletions pangaea/datasets/sen1floods11.py
@@ -10,7 +10,6 @@
 from pangaea.datasets.utils import download_bucket_concurrently
 from pangaea.datasets.base import GeoFMDataset

-# @DATASET_REGISTRY.register()
 class Sen1Floods11(GeoFMDataset):

     def __init__(
@@ -65,6 +64,9 @@ def __init__(
             auto_download (bool): whether to download the dataset automatically.
             gcs_bucket (str): subset for downloading the dataset.
         """
+
+        self.gcs_bucket = gcs_bucket
+
         super(Sen1Floods11, self).__init__(
             split=split,
             dataset_name=dataset_name,
@@ -83,7 +85,6 @@ def __init__(
             data_max=data_max,
             download_url=download_url,
             auto_download=auto_download,
-            # gcs_bucket=gcs_bucket,
         )

         self.root_path = root_path
@@ -101,7 +102,6 @@ def __init__(
         self.ignore_index = ignore_index
         self.download_url = download_url
         self.auto_download = auto_download
-        self.gcs_bucket = gcs_bucket

         self.split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'}

@@ -167,19 +167,12 @@ def __getitem__(self, index):
         }
         return output

-    # @staticmethod
-    # def get_splits(dataset_config):
-    #     dataset_train = Sen1Floods11(dataset_config, split="train")
-    #     dataset_val = Sen1Floods11(dataset_config, split="val")
-    #     dataset_test = Sen1Floods11(dataset_config, split="test")
-    #     return dataset_train, dataset_val, dataset_test
-
     @staticmethod
-    def download(dataset_config: dict, silent=False):
-        if os.path.exists(dataset_config["root_path"]):
+    def download(self, silent=False):
+        if os.path.exists(self.root_path):
             if not silent:
                 print("Sen1Floods11 Dataset folder exists, skipping downloading dataset.")
             return
-        download_bucket_concurrently(dataset_config["gcs_bucket"], dataset_config["root_path"])
+        download_bucket_concurrently(self.gcs_bucket, self.root_path)


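Note: self.gcs_bucket now has to be assigned before super().__init__, because the base constructor may call download(), which reads self.gcs_bucket; assigning it afterwards would raise AttributeError whenever root_path is missing. A sketch of the ordering constraint (constructor simplified to the two relevant arguments):

```python
from pangaea.datasets.base import GeoFMDataset

class Sen1Floods11Sketch(GeoFMDataset):
    def __init__(self, root_path, gcs_bucket, **kwargs):
        # Must be set before super().__init__, which can trigger
        # download(self) when root_path does not exist yet.
        self.gcs_bucket = gcs_bucket
        super().__init__(root_path=root_path, **kwargs)
```
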
6 changes: 3 additions & 3 deletions pangaea/datasets/spacenet7.py
@@ -244,8 +244,8 @@ def get_band(path):
         return int(path.split('_')[-2])

     @staticmethod
-    def download(dataset_config: dict, silent=False):
-        output_path = Path(dataset_config["root_path"])
+    def download(self, silent=False):
+        output_path = Path(self.root_path)

         if not output_path.exists():
             output_path.mkdir()
@@ -255,7 +255,7 @@ def download(dataset_config: dict, silent=False):
             return

         # download from Google Drive
-        url = dataset_config["download_url"]
+        url = self.download_url
         tar_file = output_path / f'spacenet7.tar.gz'
         gdown.download(url, str(tar_file), quiet=False)

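
Note: gdown.download takes the source URL and an output path, and quiet=False keeps its progress bar. A minimal usage example (the Drive URL is a placeholder, not the dataset's actual download_url):

```python
from pathlib import Path

import gdown

output_path = Path("spacenet7_data")
output_path.mkdir(exist_ok=True)
url = "https://drive.google.com/uc?id=FILE_ID"  # placeholder
tar_file = output_path / "spacenet7.tar.gz"
gdown.download(url, str(tar_file), quiet=False)
```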
2 changes: 1 addition & 1 deletion pangaea/datasets/utae_dynamicen.py
@@ -229,5 +229,5 @@ def __getitem__(self, index):
     #     return dataset_train, dataset_val, dataset_test

     @staticmethod
-    def download(dataset_config: dict, silent=False):
+    def download(self, silent=False):
         pass
26 changes: 3 additions & 23 deletions pangaea/datasets/xview2.py
@@ -174,31 +174,11 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, Any, str]]:
             'metadata': {"filename":fn}
         }

-        # return {
-        #     'image': {
-        #         't0' : {
-        #             'optical': img_pre,
-        #         },
-        #         't1': {
-        #             'optical': img_post,
-        #         },
-        #     },
-        #     'target': msk,
-        #     'metadata': {"filename":fn}
-        # }
-
-    # @staticmethod
-    # def get_splits(dataset_config):
-    #     dataset_train = xView2(cfg=dataset_config, split="train")
-    #     dataset_val = xView2(cfg=dataset_config, split="val")
-    #     dataset_test = xView2(cfg=dataset_config, split="test")
-    #     return dataset_train, dataset_val, dataset_test
-

     @staticmethod
-    def download(dataset_config:dict, silent=False):
-        output_path = pathlib.Path(dataset_config["root_path"])
-        url = dataset_config["download_url"]
+    def download(self, silent=False):
+        output_path = pathlib.Path(self.root_path)
+        url = self.download_url

         try:
             os.makedirs(output_path, exist_ok=False)
