Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
**/__pycache__/
*.pyc

old_files/
pretrained/
Expand All @@ -15,4 +16,4 @@ data/mados/splits/*
!data/mados/splits/tiny_X.txt

.vscode
.idea
.idea
12 changes: 4 additions & 8 deletions datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
import torch
import pandas as pd
import pathlib
import rasterio
from tifffile import imread
from os.path import join as opj
from .utils import read_tif
import tifffile
from utils.registry import DATASET_REGISTRY

def read_imgs(multi_temporal, temp , fname, data_dir, img_size):
Expand All @@ -22,7 +19,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s1_filepath = data_dir.joinpath(s1_fname)
if s1_filepath.exists():
img_s1 = imread(s1_filepath)
img_s1 = tifffile.imread(s1_filepath)
m = img_s1 == -9999
img_s1 = img_s1.astype('float32')
img_s1 = np.where(m, 0, img_s1)
Expand All @@ -31,7 +28,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s2_filepath = data_dir.joinpath(s2_fname)
if s2_filepath.exists():
img_s2 = imread(s2_filepath)
img_s2 = tifffile.imread(s2_filepath)
img_s2 = img_s2.astype('float32')
else:
img_s2 = np.zeros((img_size, img_size) + (11,), dtype='float32')
Expand Down Expand Up @@ -77,8 +74,7 @@ def __getitem__(self, index):
fname = str(chip_id)+'_agbm.tif'

imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features, self.img_size)
with rasterio.open(self.dir_labels.joinpath(fname)) as lbl:
target = lbl.read(1)
target = tifffile.imread(self.dir_labels.joinpath(fname), key=0)
target = np.nan_to_num(target)

imgs_s1 = torch.from_numpy(imgs_s1).float()
Expand Down
1 change: 0 additions & 1 deletion datasets/fivebillionpixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
import torch
import numpy as np
import rasterio
import random
from glob import glob

Expand Down
10 changes: 3 additions & 7 deletions datasets/hlsburnscars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import time
import torch
import numpy as np
import rasterio
import tifffile
from glob import glob

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

Expand Down Expand Up @@ -41,11 +40,8 @@ def __len__(self):
return len(self.image_list)

def __getitem__(self, index):
with rasterio.open(self.image_list[index]) as src:
image = src.read()
with rasterio.open(self.target_list[index]) as src:
target = src.read(1)

image = tifffile.imread(self.image_list[index])
target = tifffile.imread(self.target_list[index], key=0)
image = torch.from_numpy(image)
target = torch.from_numpy(target.astype(np.int64))

Expand Down
96 changes: 48 additions & 48 deletions datasets/mados.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''
"""
Adapted from: https://github.com/gkakogeorgiou/mados
'''
"""

import os
import time
Expand All @@ -10,16 +10,10 @@
import zipfile

from glob import glob
import rasterio
import tifffile
import numpy as np

import warnings

warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

from .utils import DownloadProgressBar
from utils.registry import DATASET_REGISTRY
Expand All @@ -29,39 +23,45 @@
# MADOS DATASET #
###############################################################


@DATASET_REGISTRY.register()
class MADOS(torch.utils.data.Dataset):
def __init__(self, cfg, split, is_train=True):
    """Index the band files and class maps that belong to ``split``.

    Args:
        cfg: mapping providing ``root_path``, ``data_mean``, ``data_std``
            and ``classes``.
        split: split name; ``<root_path>/splits/<split>_X.txt`` lists the
            crop names to keep.
        is_train: stored unchanged on the instance as ``self.is_train``.

    After construction ``self.image_list[i]`` holds the band file paths of
    crop ``i`` (sorted with ``self.get_band``) and ``self.target_list[i]``
    holds the path of the matching ``_L2R_cl_`` file.
    """
    self.root_path = cfg["root_path"]
    self.data_mean = cfg["data_mean"]
    self.data_std = cfg["data_std"]
    self.classes = cfg["classes"]
    self.class_num = len(self.classes)
    self.split = split
    self.is_train = is_train

    # Names of the crops that make up the requested split.
    self.ROIs_split = np.genfromtxt(
        os.path.join(self.root_path, "splits", f"{split}_X.txt"), dtype="str"
    )

    self.image_list = []
    self.target_list = []

    # Every entry directly under root_path is treated as a tile directory.
    self.tiles = sorted(glob(os.path.join(self.root_path, "*")))

    for tile in self.tiles:
        # Crop suffixes are taken from the class-map files found in <tile>/10.
        splits = [
            f.split("_cl_")[-1] for f in glob(os.path.join(tile, "10", "*_cl_*"))
        ]

        for crop in splits:
            crop_name = os.path.basename(tile) + "_" + crop.split(".tif")[0]

            if crop_name in self.ROIs_split:
                # Collect this crop's band files from every sub-directory of
                # the tile and order them by the key get_band extracts.
                all_bands = glob(os.path.join(tile, "*", "*L2R_rhorc*_" + crop))
                all_bands = sorted(all_bands, key=self.get_band)

                self.image_list.append(all_bands)

                cl_path = os.path.join(
                    tile, "10", os.path.basename(tile) + "_L2R_cl_" + crop
                )
                self.target_list.append(cl_path)

def __len__(self):
Expand All @@ -72,42 +72,39 @@ def getnames(self):

def __getitem__(self, index):
    """Load one sample: the stacked band image and its class map.

    Args:
        index: position in ``self.image_list`` / ``self.target_list``.

    Returns:
        dict with keys ``"image"`` (``{"optical": tensor}``), ``"target"``
        (int64 tensor holding the stored labels minus one) and
        ``"metadata"`` (empty dict).
    """
    band_paths = self.image_list[index]
    current_image = []
    for path in band_paths:
        # The name of the directory holding the band is parsed as an integer;
        # dividing by 10 gives the factor used to upsample that band.
        # NOTE(review): presumably the directory name is the resolution in
        # metres (10/20/60) -- confirm against the dataset layout.
        upscale_factor = int(os.path.basename(os.path.dirname(path))) // 10

        band = tifffile.imread(path)
        # interpolate() needs leading batch and channel dimensions.
        # Assumes each band file decodes to a 2-D floating-point array --
        # TODO confirm.
        band_tensor = torch.from_numpy(band)
        band_tensor.unsqueeze_(0).unsqueeze_(0)
        band_tensor = torch.nn.functional.interpolate(
            band_tensor, scale_factor=upscale_factor, mode="nearest"
        ).squeeze_(0)
        current_image.append(band_tensor)

    # Each entry keeps a leading dimension of size 1, so concatenating along
    # dim 0 stacks the bands.
    image = torch.cat(current_image)
    # NaN pixels are replaced by 0.
    invalid_mask = torch.isnan(image)
    image[invalid_mask] = 0

    target = tifffile.imread(self.target_list[index])
    target = torch.from_numpy(target.astype(np.int64))
    # Shift the stored labels down by one.
    target = target - 1

    output = {
        "image": {
            "optical": image,
        },
        "target": target,
        "metadata": {},
    }

    return output

@staticmethod
def get_band(path):
return int(path.split('_')[-2])
return int(path.split("_")[-2])

@staticmethod
def download(dataset_config: dict, silent=False):
Expand All @@ -128,15 +125,17 @@ def download(dataset_config: dict, silent=False):
try:
urllib.request.urlretrieve(url, output_path / temp_file_name, pbar)
except urllib.error.HTTPError as e:
print('Error while downloading dataset: The server couldn\'t fulfill the request.')
print('Error code: ', e.code)
print(
"Error while downloading dataset: The server couldn't fulfill the request."
)
print("Error code: ", e.code)
return
except urllib.error.URLError as e:
print('Error while downloading dataset: Failed to reach a server.')
print('Reason: ', e.reason)
print("Error while downloading dataset: Failed to reach a server.")
print("Reason: ", e.reason)
return

with zipfile.ZipFile(output_path / temp_file_name, 'r') as zip_ref:
with zipfile.ZipFile(output_path / temp_file_name, "r") as zip_ref:
print(f"Extracting to {output_path} ...")
# Remove top-level dir in ZIP file for nicer data dir structure
members = []
Expand All @@ -155,4 +154,5 @@ def get_splits(dataset_config):
dataset_train = MADOS(cfg=dataset_config, split="train", is_train=True)
dataset_val = MADOS(cfg=dataset_config, split="val", is_train=False)
dataset_test = MADOS(cfg=dataset_config, split="test", is_train=False)
return dataset_train, dataset_val, dataset_test
return dataset_train, dataset_val, dataset_test

14 changes: 6 additions & 8 deletions datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch
from einops import rearrange
from omegaconf import OmegaConf
Expand Down Expand Up @@ -142,17 +142,15 @@ def __getitem__(self, i):

for modality in self.modalities:
if modality == "aerial":
    path = os.path.join(
        self.path,
        "DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_"
        + str(name)
        + ".tif",
    )
    # The tensor conversion wraps only the decoded image; nb_split and part
    # are arguments of split_image, not of torch.FloatTensor.
    # NOTE(review): rasterio's read() returned bands first; tifffile may
    # return the band axis last for this file -- confirm split_image still
    # receives the layout it expects.
    aerial = torch.FloatTensor(tifffile.imread(path))
    output["aerial"] = split_image(aerial, self.nb_split, part)
elif modality == "s1-median":
modality_name = "s1a"
images = split_image(
Expand Down
14 changes: 5 additions & 9 deletions datasets/sen1floods11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import geopandas
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch

from .utils import download_bucket_concurrently
Expand Down Expand Up @@ -59,16 +59,12 @@ def _get_date(self, index):
return date_np

def __getitem__(self, index):
with rasterio.open(self.s2_image_list[index]) as src:
s2_image = src.read()
s2_image = tifffile.imread(self.s2_image_list[index])

with rasterio.open(self.s1_image_list[index]) as src:
s1_image = src.read()
# Convert the missing values (clouds etc.)
s1_image = np.nan_to_num(s1_image)
s1_image = tifffile.imread(self.s1_image_list[index])
s1_image = np.nan_to_num(s1_image)

with rasterio.open(self.target_list[index]) as src:
target = src.read(1)
target = tifffile.imread(self.target_list[index], key=0)

timestamp = self._get_date(index)

Expand Down
15 changes: 8 additions & 7 deletions datasets/spacenet7.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import json
from glob import glob
import rasterio
import cv2
import tifffile
import numpy as np

import torch
Expand Down Expand Up @@ -132,17 +133,17 @@ def __len__(self) -> int:
def load_planet_mosaic(self, aoi_id: str, year: int, month: int) -> np.ndarray:
    """Read one monthly mosaic, resize it and drop its alpha band.

    Args:
        aoi_id: name of the AOI directory under ``<root_path>/train``.
        year: year used in the mosaic file name.
        month: month used in the mosaic file name (zero-padded to 2 digits).

    Returns:
        float32 array with the band axis first and the last band removed,
        resized to ``self.img_size`` x ``self.img_size``.
    """
    folder = self.root_path / 'train' / aoi_id / 'images_masked'
    file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif'
    img = tifffile.imread(file)
    # Nearest-neighbour resampling, so no new pixel values are introduced.
    img = cv2.resize(img, dsize=(self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
    # Move the last axis to the front, then drop the final band.
    # 4th band (last one) is alpha band
    # Assumes the file decodes with the band axis last -- TODO confirm.
    img = img.transpose(2, 0, 1)[:-1]
    return img.astype(np.float32)

def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray:
    """Read one building raster and turn it into a 0/1 label map.

    Args:
        aoi_id: name of the AOI directory under ``<root_path>/train``.
        year: year used in the label file name.
        month: month used in the label file name (zero-padded to 2 digits).

    Returns:
        int64 array that is 1 where the resized raster is greater than 0
        and 0 elsewhere, with size-1 axes removed.
    """
    folder = self.root_path / 'train' / aoi_id / 'labels_raster'
    file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif'
    label = tifffile.imread(file)
    # Nearest-neighbour resampling keeps the raster's original values.
    label = cv2.resize(label, dsize=(self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
    # Any positive value counts as building.
    label = (label > 0).squeeze()
    return label.astype(np.int64)

Expand Down
Loading
Loading