Skip to content

Commit 60ceeb9

Browse files
committed
Makes it optional to oversample images with building damage
1 parent f932a8b commit 60ceeb9

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

configs/dataset/xview2.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ auto_download: False
66
img_size: 1024
77
multi_temporal: False
88
multi_modal: False
9+
oversample_building_damage: True
910

1011
# classes
1112
ignore_index: -1

pangaea/datasets/xview2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
data_max: dict[str, list[str]],
4040
download_url: str,
4141
auto_download: bool,
42+
oversample_building_damage: bool
4243
):
4344
"""Initialize the xView2 dataset.
4445
Link: https://xview2.org/dataset
@@ -70,6 +71,7 @@ def __init__(
7071
e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]}
7172
download_url (str): url to download the dataset.
7273
auto_download (bool): whether to download the dataset automatically.
74+
oversample_building_damage (bool): whether to oversample images with building damage
7375
"""
7476
super(xView2, self).__init__(
7577
split=split,
@@ -105,6 +107,7 @@ def __init__(
105107
self.ignore_index = ignore_index
106108
self.download_url = download_url
107109
self.auto_download = auto_download
110+
self.oversample_building_damage = oversample_building_damage
108111

109112
self.all_files = self.get_all_files()
110113

@@ -124,7 +127,7 @@ def get_all_files(self) -> Sequence[str]:
124127
if self.split != "test":
125128
train_val_idcs = self.get_stratified_train_val_split(all_files)
126129

127-
if self.split == "train":
130+
if self.split == "train" and self.oversample_building_damage:
128131
train_val_idcs[self.split] = self.oversample_building_files(all_files, train_val_idcs[self.split])
129132

130133

0 commit comments

Comments
 (0)