Skip to content

Commit

Permalink
Merge pull request #109 from yurujaja/xview-oversampling
Browse files Browse the repository at this point in the history
Xview2: oversampling images with building damage
  • Loading branch information
SebastianGer authored Oct 25, 2024
2 parents 5ac49ab + 60ceeb9 commit 3fdd72b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
1 change: 1 addition & 0 deletions configs/dataset/xview2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ auto_download: False
img_size: 1024
multi_temporal: False
multi_modal: False
oversample_building_damage: True

# classes
ignore_index: -1
Expand Down
64 changes: 62 additions & 2 deletions pangaea/datasets/xview2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Sources:
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/base_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/supervised_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/legacy/datasets.py

from typing import Sequence, Dict, Any, Union, Literal, Tuple
import time
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
data_max: dict[str, list[str]],
download_url: str,
auto_download: bool,
oversample_building_damage: bool
):
"""Initialize the xView2 dataset.
Link: https://xview2.org/dataset
Expand Down Expand Up @@ -69,6 +71,7 @@ def __init__(
e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]}
download_url (str): url to download the dataset.
auto_download (bool): whether to download the dataset automatically.
oversample_building_damage (bool): whether to oversample images with building damage
"""
super(xView2, self).__init__(
split=split,
Expand Down Expand Up @@ -104,9 +107,9 @@ def __init__(
self.ignore_index = ignore_index
self.download_url = download_url
self.auto_download = auto_download
self.oversample_building_damage = oversample_building_damage

self.all_files = self.get_all_files()


def get_all_files(self) -> Sequence[str]:
all_files = []
Expand All @@ -123,7 +126,13 @@ def get_all_files(self) -> Sequence[str]:

if self.split != "test":
train_val_idcs = self.get_stratified_train_val_split(all_files)

if self.split == "train" and self.oversample_building_damage:
train_val_idcs[self.split] = self.oversample_building_files(all_files, train_val_idcs[self.split])


all_files = [all_files[i] for i in train_val_idcs[self.split]]


return all_files

Expand All @@ -140,6 +149,34 @@ def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[i
stratify=disaster_names)
return {"train": train_idxs, "val": val_idxs}

def oversample_building_files(self, all_files, train_idxs):
# Oversamples buildings on the image-level, by including each image with any building pixels twice in the training set.
file_classes = []
for i, fn in enumerate(all_files):
fl = np.zeros((4,), dtype=bool)
# Only read images that are included in train_idxs
if i in train_idxs:
msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'),
cv2.IMREAD_UNCHANGED)
for c in range(1, 5):
fl[c - 1] = c in msk1
file_classes.append(fl)
file_classes = np.asarray(file_classes)

new_train_idxs = []
for i in train_idxs:
new_train_idxs.append(i)
# If any building damage was present in the image, add the image to the training set a second time.
if file_classes[i, 1:].max():
new_train_idxs.append(i)
# If minor or medium damage were present, add it a third time, since these two classes are very hard to detect.
# Source: https://github.com/DIUx-xView/xView2_first_place/blob/master/train34_cls.py
if file_classes[i, 1:3].max():
new_train_idxs.append(i)
train_idxs = np.asarray(new_train_idxs)
return train_idxs


def __len__(self) -> int:
return len(self.all_files)

Expand Down Expand Up @@ -206,4 +243,27 @@ def download(self, silent=False):
tar.extractall(output_path)
print("done.")

os.remove(output_path / temp_file_name)
os.remove(output_path / temp_file_name)

if __name__=="__main__":
dataset = xView2(
split="train",
dataset_name="xView2",
root_path="./data/xView2",
download_url="https://the-dataset-is-not-publicly-available.com",
auto_download=False,
img_size=1024,
multi_temporal=False,
multi_modal=False,
classes=["No building", "No damage","Minor damage","Major damage","Destroyed"],
num_classes=5,
ignore_index=-1,
bands=["B4", "B3", "B2"],
distribution = [0.9415, 0.0448, 0.0049, 0.0057, 0.0031],
data_mean=[66.7703, 88.4452, 85.1047],
data_std=[48.3066, 51.9129, 62.7612],
data_min=[0.0, 0.0, 0.0],
data_max=[255, 255, 255],
)
x,y = dataset[0]
print(x["optical"].shape, y.shape)

0 comments on commit 3fdd72b

Please sign in to comment.