Skip to content

Commit

Permalink
Merge pull request #81 from yurujaja/spacenet7
Browse files Browse the repository at this point in the history
SpaceNet 7
  • Loading branch information
VMarsocci authored Oct 2, 2024
2 parents 257660c + 1ca6e96 commit 1ef146d
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 50 deletions.
2 changes: 1 addition & 1 deletion configs/dataset/spacenet7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ multi_modal: False

# classes
ignore_index: -1
num_classes: 2
num_classes: 1
classes:
- Background
- Building
Expand Down
2 changes: 1 addition & 1 deletion configs/dataset/spacenet7_domainshift.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ multi_modal: False

# classes
ignore_index: -1
num_classes: 2
num_classes: 1
classes:
- Background
- Building
Expand Down
5 changes: 3 additions & 2 deletions configs/dataset/spacenet7cd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ j_split: 512
multi_temporal: 2
multi_modal: False

dataset_multiplier: 1 # multiplies sample in dataset during training
dataset_multiplier: 5 # multiplies sample in dataset during training
minimum_temporal_gap: 5

# classes
ignore_index: -1
num_classes: 2
num_classes: 1
classes:
- No Change
- Change
Expand Down
48 changes: 48 additions & 0 deletions configs/dataset/spacenet7cd_domainshift.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
_target_: pangaea.datasets.spacenet7.SN7CD
dataset_name: SN7CD
root_path: ./data/spacenet7
download_url: https://drive.google.com/uc?id=1BADSEjxYKFZZlM-tEkRUfHvHi5XdaVV9
auto_download: True

img_size: 256 # the image size is used to tile the SpaceNet 7 images (1024, 1024)
domain_shift: True
# parameters for within-scene splits (no domain shift)
i_split: 768
j_split: 512

multi_temporal: 2
multi_modal: False

dataset_multiplier: 5 # multiplies sample in dataset during training
minimum_temporal_gap: 5

# classes
ignore_index: -1
num_classes: 1
classes:
- No Change
- Change
distribution: # TODO: update for CD
- 0.92530769
- 0.07469231

# data stats
bands:
optical:
- B4 # Band 1 (Red)
- B3 # Band 2 (Green)
- B2 # Band 3 (Blue)
data_mean:
optical:
- 121.826
- 106.52838
- 78.372116
data_std:
optical:
- 56.717068
- 44.517075
- 40.451515
data_min:
optical: [0.0, 0.0, 0.0]
data_max:
optical: [255.0, 255.0, 255.0]
4 changes: 4 additions & 0 deletions configs/decoder/seg_siamunet_conc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: pangaea.decoders.unet.SiamConcUNet

num_classes: ${dataset.num_classes}
finetune: ${finetune}
4 changes: 4 additions & 0 deletions configs/decoder/seg_siamunet_diff.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: pangaea.decoders.unet.SiamDiffUNet

num_classes: ${dataset.num_classes}
finetune: ${finetune}
File renamed without changes.
84 changes: 40 additions & 44 deletions pangaea/datasets/spacenet7.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
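i_split (int): index along the first spatial axis separating the train region (below i_split) from the val/test region in within-scene splits.
j_split (int): index along the second spatial axis separating the val region (below j_split) from the test region in within-scene splits. #ISSUES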
"""
super(AbstractSN7, self).__init__(
super().__init__(
split=split,
dataset_name=dataset_name,
multi_modal=multi_modal,
Expand All @@ -174,9 +174,6 @@ def __init__(
data_max=data_max,
download_url=download_url,
auto_download=auto_download,
# domain_shift=domain_shift,
# i_split=i_split,
# j_split=j_split,
)


Expand All @@ -185,17 +182,15 @@ def __init__(
with open(metadata_file, 'r') as f:
self.metadata = json.load(f)

self.img_size = 1024 # size of the SpaceNet 7 images
# unpacking config
self.tile_size = img_size # size used for tiling the images
assert self.img_size % self.tile_size == 0
self.sn7_img_size = 1024 # size of the SpaceNet 7 images
self.img_size = img_size # size used for tiling the images
assert self.sn7_img_size % self.img_size == 0

self.data_mean = data_mean
self.data_std = data_std
self.data_min = data_min
self.data_max = data_max
self.classes = classes
self.img_size = img_size
self.distribution = distribution
self.num_classes = self.class_num = num_classes
self.ignore_index = ignore_index
Expand All @@ -220,7 +215,7 @@ def load_planet_mosaic(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'images_masked'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif'
with rasterio.open(str(file), mode='r') as src:
img = src.read(out_shape=(1024, 1024), resampling=rasterio.enums.Resampling.nearest)
img = src.read(out_shape=(self.sn7_img_size, self.sn7_img_size), resampling=rasterio.enums.Resampling.nearest)
# 4th band (last one) is the alpha band
img = img[:-1]
return img.astype(np.float32)
Expand All @@ -229,7 +224,7 @@ def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'labels_raster'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif'
with rasterio.open(str(file), mode='r') as src:
label = src.read(out_shape=(1024, 1024), resampling=rasterio.enums.Resampling.nearest)
label = src.read(out_shape=(self.sn7_img_size, self.sn7_img_size), resampling=rasterio.enums.Resampling.nearest)
label = (label > 0).squeeze()
return label.astype(np.int64)

Expand Down Expand Up @@ -306,7 +301,7 @@ def __init__(
):
"""Initialize the SpaceNet dataset for building mapping.
"""
super(SN7MAPPING, self).__init__(
super().__init__(
split=split,
dataset_name=dataset_name,
multi_modal=multi_modal,
Expand Down Expand Up @@ -353,15 +348,15 @@ def __init__(
'month': timestamp['month'],
}
# tiling the timestamps
for i in range(0, self.img_size, self.tile_size):
for j in range(0, self.img_size, self.tile_size):
for i in range(0, self.sn7_img_size, self.img_size):
for j in range(0, self.sn7_img_size, self.img_size):
item['i'] = i
item['j'] = j
self.items.append(dict(item))

else: # within-scenes split
assert self.i_split % self.tile_size == 0 and self.j_split % self.tile_size == 0
assert self.tile_size <= self.i_split and self.tile_size <= self.j_split
assert self.i_split % self.img_size == 0 and self.j_split % self.img_size == 0
assert self.img_size <= self.i_split and self.img_size <= self.j_split
self.aoi_ids = list(self.sn7_aois)
for aoi_id in self.aoi_ids:
timestamps = list(self.metadata[aoi_id])
Expand All @@ -374,18 +369,18 @@ def __init__(
}
if split == 'train':
i_min, i_max = 0, self.i_split
j_min, j_max = 0, self.img_size
j_min, j_max = 0, self.sn7_img_size
elif split == 'val':
i_min, i_max = self.i_split, self.img_size
i_min, i_max = self.i_split, self.sn7_img_size
j_min, j_max = 0, self.j_split
elif split == 'test':
i_min, i_max = self.i_split, self.img_size
j_min, j_max = self.j_split, self.img_size
i_min, i_max = self.i_split, self.sn7_img_size
j_min, j_max = self.j_split, self.sn7_img_size
else:
raise Exception('Unkown split')
# tiling the timestamps
for i in range(i_min, i_max, self.tile_size):
for j in range(j_min, j_max, self.tile_size):
for i in range(i_min, i_max, self.img_size):
for j in range(j_min, j_max, self.img_size):
item['i'] = i
item['j'] = j
self.items.append(dict(item))
Expand All @@ -403,8 +398,8 @@ def __getitem__(self, index):

# cut to tile
i, j = item['i'], item['j']
image = image[:, i:i + self.tile_size, j:j + self.tile_size]
target = target[i:i + self.tile_size, j:j + self.tile_size]
image = image[:, i:i + self.img_size, j:j + self.img_size]
target = target[i:i + self.img_size, j:j + self.img_size]

image = torch.from_numpy(image)
target = torch.from_numpy(target)
Expand Down Expand Up @@ -446,16 +441,16 @@ def __init__(
domain_shift: bool,
i_split: int,
j_split: int,
eval_mode: bool,
dataset_multiplier: int,
minimum_temporal_gap: int,
):
"""Initialize the SpaceNet dataset for change detection.
...
eval_mode (bool): whether the dataset is used for evaluation; True for the val and test splits.
dataset_multiplier (int): multiplies the number of samples in the dataset during training.
"""
super(SN7MAPPING, self).__init__(
super().__init__(
split=split,
dataset_name=dataset_name,
multi_modal=multi_modal,
Expand All @@ -476,16 +471,14 @@ def __init__(
domain_shift=domain_shift,
i_split=i_split,
j_split=j_split,
# eval_mode=eval_mode,
# dataset_multiplier=dataset_multiplier,

)

self.T = self.multi_temporal
assert self.T > 1

self.eval_mode = eval_mode
self.multiplier = 1 if eval_mode else dataset_multiplier
self.eval_mode = False if split == 'train' else True
self.multiplier = 1 if self.eval_mode else dataset_multiplier
self.min_gap = minimum_temporal_gap

self.split = split
self.items = []
Expand All @@ -504,32 +497,32 @@ def __init__(
for aoi_id in self.aoi_ids:
item = { 'aoi_id': aoi_id }
# tiling the timestamps
for i in range(0, self.img_size, self.tile_size):
for j in range(0, self.img_size, self.tile_size):
for i in range(0, self.sn7_img_size, self.img_size):
for j in range(0, self.sn7_img_size, self.img_size):
item['i'] = i
item['j'] = j
self.items.append(dict(item))

else: # within-scenes split
assert self.i_split % self.tile_size == 0 and self.j_split % self.tile_size == 0
assert self.tile_size <= self.i_split and self.tile_size <= self.j_split
assert self.i_split % self.img_size == 0 and self.j_split % self.img_size == 0
assert self.img_size <= self.i_split and self.img_size <= self.j_split
self.aoi_ids = list(self.sn7_aois)
for aoi_id in self.aoi_ids:
item = { 'aoi_id': aoi_id }
if split == 'train':
i_min, i_max = 0, self.i_split
j_min, j_max = 0, self.img_size
j_min, j_max = 0, self.sn7_img_size
elif split == 'val':
i_min, i_max = self.i_split, self.img_size
i_min, i_max = self.i_split, self.sn7_img_size
j_min, j_max = 0, self.j_split
elif split == 'test':
i_min, i_max = self.i_split, self.img_size
j_min, j_max = self.j_split, self.img_size
i_min, i_max = self.i_split, self.sn7_img_size
j_min, j_max = self.j_split, self.sn7_img_size
else:
raise Exception('Unkown split')
# tiling the timestamps
for i in range(i_min, i_max, self.tile_size):
for j in range(j_min, j_max, self.tile_size):
for i in range(i_min, i_max, self.img_size):
for j in range(j_min, j_max, self.img_size):
item['i'] = i
item['j'] = j
self.items.append(dict(item))
Expand All @@ -549,7 +542,10 @@ def __getitem__(self, index):
t_values = list(np.linspace(0, len(timestamps), self.T, endpoint=False, dtype=int))
else:
if self.T == 2:
t_values = [0, -1]
# t_values = [0, -1]
t1 = np.random.randint(0, len(timestamps) - self.min_gap)
t2 = np.random.randint(t1 + self.min_gap, len(timestamps))
t_values = [t1, t2]
else: # randomly add intermediate timestamps
t_values = [0] + sorted(np.random.randint(1, len(timestamps) - 1, size=self.T - 2)) + [-1]

Expand All @@ -570,8 +566,8 @@ def __getitem__(self, index):

# cut to tile
i, j = item['i'], item['j']
image = image[:, :, i:i + self.tile_size, j:j + self.tile_size]
target = target[i:i + self.tile_size, j:j + self.tile_size]
image = image[:, :, i:i + self.img_size, j:j + self.img_size]
target = target[i:i + self.img_size, j:j + self.img_size]

# weight for oversampling
weight = torch.empty(target.shape)
Expand Down
2 changes: 0 additions & 2 deletions pangaea/decoders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def __init__(
encoder: Encoder,
num_classes: int,
finetune: bool,
strategy: str,
):
super().__init__(
encoder=encoder,
Expand All @@ -143,7 +142,6 @@ def __init__(
encoder: Encoder,
num_classes: int,
finetune: bool,
strategy: str,
):
super().__init__(
encoder=encoder,
Expand Down

0 comments on commit 1ef146d

Please sign in to comment.