diff --git a/configs/dataset/xview2.yaml b/configs/dataset/xview2.yaml index 4fd027c..cac06a8 100644 --- a/configs/dataset/xview2.yaml +++ b/configs/dataset/xview2.yaml @@ -6,6 +6,7 @@ auto_download: False img_size: 1024 multi_temporal: False multi_modal: False +oversample_building_damage: True # classes ignore_index: -1 diff --git a/pangaea/datasets/xview2.py b/pangaea/datasets/xview2.py index 2b415c7..820da6e 100644 --- a/pangaea/datasets/xview2.py +++ b/pangaea/datasets/xview2.py @@ -39,6 +39,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, + oversample_building_damage: bool ): """Initialize the xView2 dataset. Link: https://xview2.org/dataset @@ -70,6 +71,7 @@ def __init__( e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} download_url (str): url to download the dataset. auto_download (bool): whether to download the dataset automatically. + oversample_building_damage (bool): whether to oversample images with building damage in the train split. """ super(xView2, self).__init__( split=split, @@ -105,6 +107,7 @@ def __init__( self.ignore_index = ignore_index self.download_url = download_url self.auto_download = auto_download + self.oversample_building_damage = oversample_building_damage self.all_files = self.get_all_files() @@ -124,7 +127,7 @@ def get_all_files(self) -> Sequence[str]: if self.split != "test": train_val_idcs = self.get_stratified_train_val_split(all_files) - if self.split == "train": + if self.split == "train" and self.oversample_building_damage: train_val_idcs[self.split] = self.oversample_building_files(all_files, train_val_idcs[self.split])