@@ -39,6 +39,7 @@ def __init__(
39
39
data_max : dict [str , list [str ]],
40
40
download_url : str ,
41
41
auto_download : bool ,
42
+ oversample_building_damage : bool
42
43
):
43
44
"""Initialize the xView2 dataset.
44
45
Link: https://xview2.org/dataset
@@ -70,6 +71,7 @@ def __init__(
70
71
e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]}
71
72
download_url (str): url to download the dataset.
72
73
auto_download (bool): whether to download the dataset automatically.
74
+ oversample_building_damage (bool): whether to oversample images with building damage
73
75
"""
74
76
super (xView2 , self ).__init__ (
75
77
split = split ,
@@ -105,6 +107,7 @@ def __init__(
105
107
self .ignore_index = ignore_index
106
108
self .download_url = download_url
107
109
self .auto_download = auto_download
110
+ self .oversample_building_damage = oversample_building_damage
108
111
109
112
self .all_files = self .get_all_files ()
110
113
@@ -124,7 +127,7 @@ def get_all_files(self) -> Sequence[str]:
124
127
if self .split != "test" :
125
128
train_val_idcs = self .get_stratified_train_val_split (all_files )
126
129
127
- if self .split == "train" :
130
+ if self .split == "train" and self . oversample_building_damage :
128
131
train_val_idcs [self .split ] = self .oversample_building_files (all_files , train_val_idcs [self .split ])
129
132
130
133
0 commit comments