Skip to content

Commit

Permalink
Adds oversampling of images with building damage as default setting f…
Browse files Browse the repository at this point in the history
…or xview train set
  • Loading branch information
SebastianGer committed Oct 24, 2024
1 parent 102c575 commit f932a8b
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions pangaea/datasets/xview2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Sources:
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/base_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/datasets/supervised_dataset.py
# - https://github.com/PaulBorneP/Xview2_Strong_Baseline/blob/master/legacy/datasets.py

from typing import Sequence, Dict, Any, Union, Literal, Tuple
import time
Expand Down Expand Up @@ -106,7 +107,6 @@ def __init__(
self.auto_download = auto_download

self.all_files = self.get_all_files()


def get_all_files(self) -> Sequence[str]:
all_files = []
Expand All @@ -123,7 +123,13 @@ def get_all_files(self) -> Sequence[str]:

if self.split != "test":
train_val_idcs = self.get_stratified_train_val_split(all_files)

if self.split == "train":
train_val_idcs[self.split] = self.oversample_building_files(all_files, train_val_idcs[self.split])


all_files = [all_files[i] for i in train_val_idcs[self.split]]


return all_files

Expand All @@ -140,6 +146,34 @@ def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[i
stratify=disaster_names)
return {"train": train_idxs, "val": val_idxs}

def oversample_building_files(self, all_files, train_idxs):
    """Oversample training images that contain building damage.

    Oversampling happens on the image level: every index in ``train_idxs``
    is kept once, repeated a second time if its post-disaster mask contains
    any of the labels 2, 3 or 4, and repeated a third time if the mask
    contains label 2 or 3. The weighting follows
    https://github.com/DIUx-xView/xView2_first_place/blob/master/train34_cls.py

    NOTE(review): labels 2/3/4 are presumed to be minor damage / major
    damage / destroyed, matching the order of the dataset's class list --
    confirm against the mask encoding.

    Args:
        all_files: Indexable collection of image paths. The mask path is
            derived by replacing '/images/' with '/masks/' and
            '_pre_disaster' with '_post_disaster'.
        train_idxs: Indices into ``all_files`` that form the training split.

    Returns:
        A numpy array with the oversampled indices; the relative order of
        ``train_idxs`` is preserved and repeats are adjacent.

    Raises:
        FileNotFoundError: If a mask cannot be read by OpenCV.
    """
    new_train_idxs = []
    # Iterate the training indices directly instead of scanning every file
    # and testing `i in train_idxs`, which is a linear search per file.
    for i in train_idxs:
        mask_path = (
            all_files[i]
            .replace('/images/', '/masks/')
            .replace('_pre_disaster', '_post_disaster')
        )
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of an opaque TypeError further down.
            raise FileNotFoundError(f"Could not read damage mask: {mask_path}")

        # present[k] is True when label k + 1 occurs anywhere in the mask.
        present = [label in mask for label in range(1, 5)]

        copies = 1
        if any(present[1:]):
            # Some damage label (2, 3 or 4) is present.
            copies += 1
        if any(present[1:3]):
            # Labels 2 and 3 are very hard to detect, so weight them more.
            copies += 1
        new_train_idxs.extend([i] * copies)

    return np.asarray(new_train_idxs)


def __len__(self) -> int:
    """Return the number of entries in ``self.all_files``."""
    files = self.all_files
    return len(files)

Expand Down Expand Up @@ -206,4 +240,27 @@ def download(self, silent=False):
tar.extractall(output_path)
print("done.")

os.remove(output_path / temp_file_name)
os.remove(output_path / temp_file_name)

if __name__ == "__main__":
    # Manual smoke test: build the training split from a local copy of the
    # data and print the shapes of the first sample.
    config = {
        "split": "train",
        "dataset_name": "xView2",
        "root_path": "./data/xView2",
        "download_url": "https://the-dataset-is-not-publicly-available.com",
        "auto_download": False,
        "img_size": 1024,
        "multi_temporal": False,
        "multi_modal": False,
        "classes": [
            "No building",
            "No damage",
            "Minor damage",
            "Major damage",
            "Destroyed",
        ],
        "num_classes": 5,
        "ignore_index": -1,
        "bands": ["B4", "B3", "B2"],
        "distribution": [0.9415, 0.0448, 0.0049, 0.0057, 0.0031],
        "data_mean": [66.7703, 88.4452, 85.1047],
        "data_std": [48.3066, 51.9129, 62.7612],
        "data_min": [0.0, 0.0, 0.0],
        "data_max": [255, 255, 255],
    }
    dataset = xView2(**config)

    first_sample = dataset[0]
    x, y = first_sample
    print(x["optical"].shape, y.shape)

0 comments on commit f932a8b

Please sign in to comment.