What is DOFA: DOFA is a unified multimodal foundation model for different data modalities in remote sensing and Earth observation.
Differences with existing foundation models: DOFA is pre-trained using five different data modalities in remote sensing and Earth observation. It can handle images with any number of input channels.
DOFA is inspired by neuroplasticity: Neuroplasticity is an important brain mechanism for adjusting to new experiences or environmental shifts. Inspired by this concept, we designed DOFA to emulate this mechanism for processing multimodal Earth observation (EO) data.
For more details, please refer to the paper "Neural Plasticity-Inspired Foundation Model for Observing the Earth Crossing Modalities" (arXiv:2403.15356).
Current approaches to building foundation models for remote sensing face several limitations:

- The learned multimodal representation may not effectively capture the relationships between different sensors.
- The performance of foundation models degrades when downstream tasks require data from unseen sensors with different numbers of spectral bands, different spatial resolutions, or different wavelength regimes.
- Developing individual, customized foundation models requires considerably more computing resources and human effort.
- The growing number of specialized foundation models makes it difficult to select the most appropriate one for a specific downstream task.
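One way to accept images with any number of channels is to generate the patch-embedding weights from each band's central wavelength instead of storing a fixed kernel per input channel; this is also why every DOFA call in the examples below takes a `wave_list`. The NumPy sketch below illustrates only that general idea and is not DOFA's implementation: the generator (a single random linear layer over a sinusoidal wavelength encoding), the token size `EMBED_DIM`, and all function names are hypothetical.

```python
import numpy as np

PATCH = 16      # patch size, as in the ViT-Base/16 backbone
EMBED_DIM = 32  # hypothetical (small) token dimension for the sketch
N_FREQ = 8      # number of sinusoidal frequencies in the wavelength encoding
rng = np.random.default_rng(0)

# Hypothetical weight generator: a single random linear layer mapping a
# wavelength encoding to one flattened (EMBED_DIM x PATCH*PATCH) kernel.
GEN_W = rng.normal(scale=0.02, size=(2 * N_FREQ, EMBED_DIM * PATCH * PATCH))


def wavelength_features(wavelength_um):
    """Sinusoidal encoding of a band's central wavelength (in µm)."""
    angles = wavelength_um * 2.0 ** np.arange(N_FREQ)
    return np.concatenate([np.sin(angles), np.cos(angles)])


def dynamic_patch_embed(image, wavelengths):
    """Turn a (C, H, W) image into (num_patches, EMBED_DIM) tokens for any C."""
    c, h, w = image.shape
    if c != len(wavelengths):
        raise ValueError("need exactly one wavelength per channel")
    # One generated kernel per band: (C, EMBED_DIM, PATCH*PATCH)
    kernels = np.stack([wavelength_features(wl) @ GEN_W for wl in wavelengths])
    kernels = kernels.reshape(c, EMBED_DIM, PATCH * PATCH)
    # Cut every band into non-overlapping PATCH x PATCH patches
    gh, gw = h // PATCH, w // PATCH
    patches = image[:, : gh * PATCH, : gw * PATCH].reshape(c, gh, PATCH, gw, PATCH)
    patches = patches.transpose(0, 1, 3, 2, 4).reshape(c, gh * gw, PATCH * PATCH)
    # Project each band's patches with its own kernel, then sum over bands
    return np.einsum("cnp,cdp->nd", patches, kernels)


s1 = rng.normal(size=(2, 224, 224))  # random stand-in for a VH/VV pair
s2 = rng.normal(size=(9, 224, 224))  # random stand-in for 9 Sentinel-2 bands
print(dynamic_patch_embed(s1, [3.75, 3.75]).shape)  # (196, 32)
print(dynamic_patch_embed(s2, [0.665, 0.56, 0.49, 0.705, 0.74, 0.783, 0.842, 1.61, 2.19]).shape)  # (196, 32)
```

A 224 x 224 input yields 14 x 14 = 196 tokens regardless of the channel count. Here the generator is random, so the tokens carry no meaning; the point is only that one set of parameters accepts a 2-band SAR image and a 9-band multispectral image alike.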
DOFA's requirements can be installed as follows:

```bash
pip install -r requirements.txt
```
Pre-trained model weights can be downloaded from Hugging Face.
Please refer to `demo.ipynb` for more details.
DOFA supports input images with any number of channels using a single pre-trained foundation model. The following examples show how to use one DOFA model for Sentinel-1 (SAR), Sentinel-2, and NAIP RGB data. Examples for Gaofen multispectral and hyperspectral data will be added soon.
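First, download the pre-trained weights:

```bash
python download_weights.py
```

Then load them into a ViT-Base/16 backbone:

```python
import torch
from models_dwv import vit_base_patch16

check_point = torch.load('./checkpoints/DOFA_ViT_base_e100.pth', map_location='cpu')
vit_model = vit_base_patch16()
msg = vit_model.load_state_dict(check_point, strict=False)
# Expected report:
# missing_keys=['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias'],
# unexpected_keys=['mask_token', 'norm.weight', 'norm.bias', 'projector.weight', 'projector.bias']
print(msg)
vit_model = vit_model.cuda().eval()  # eval mode for feature extraction
```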
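With `strict=False`, `load_state_dict` returns the missing and unexpected keys instead of raising an error. The missing keys in the loading comment are the classifier layers (`fc_norm`, `head`), which the checkpoint does not contain, so they stay randomly initialized and the output of `forward` is only meaningful after fine-tuning. A quick sanity check is to confirm that nothing else is missing; the helper `only_head_missing` and its prefix list are hypothetical, derived from those reported keys.

```python
def only_head_missing(missing_keys, allowed_prefixes=("fc_norm.", "head.")):
    """True if every missing key belongs to the classifier layers."""
    return all(key.startswith(allowed_prefixes) for key in missing_keys)


# Missing keys reported when loading DOFA_ViT_base_e100.pth with strict=False
missing = ["fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"]
print(only_head_missing(missing))  # True
# A hypothetical backbone key would signal a mismatched checkpoint
print(only_head_missing(missing + ["blocks.0.attn.qkv.weight"]))  # False
```

With the model loaded as above, the same check is `only_head_missing(msg.missing_keys)`.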
Now you can use the loaded single DOFA model to process image data from different modalities with any number of channels!
```python
# Step 1: Data preprocessing (normalization and resize)
import torch
import rasterio
import kornia as K
import numpy as np

# Per-channel statistics on the 0-255 scale of the 8-bit example images
# Sentinel-1 channel order: VH, VV
S1_MEAN = [166.36275909, 88.45542715]
S1_STD = [64.83126309, 43.07350145]
S2_MEAN = [114.1099739, 114.81779093, 126.63977424, 84.33539309,
           97.84789168, 103.94461911, 101.435633, 72.32804172,
           56.66528851]
S2_STD = [77.84352553, 69.96844919, 67.42465279, 64.57022983, 61.72545487,
          61.34187099, 60.29744676, 47.88519516, 42.55886798]
NAIP_MEAN = [123.675, 116.28, 103.53]  # ImageNet stats for now
NAIP_STD = [58.395, 57.12, 57.375]  # ImageNet stats for now
GAOFEN_MEAN = [123.94924583, 92.58088583, 97.28130189, 90.31526596]
GAOFEN_STD = [67.34487297, 62.8271046, 60.5856767, 60.3946299]


class DataAugmentation(torch.nn.Module):
    """Crop/resize to 224x224, then normalize each channel."""

    def __init__(self, mean, std):
        super().__init__()
        # Note: RandomResizedCrop is random, so repeated calls give different crops;
        # for deterministic inference, K.augmentation.Resize((224, 224)) is an alternative.
        self.transform = torch.nn.Sequential(
            K.augmentation.RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0)),
            K.augmentation.Normalize(mean=mean, std=std),
        )

    @torch.no_grad()
    def forward(self, x):
        # Kornia returns a batched tensor of shape (1, C, 224, 224)
        return self.transform(x)


s1_transform = DataAugmentation(mean=S1_MEAN, std=S1_STD)


def preprocess_s1(vh_path, vv_path):
    # Each single-band image is read as an array of shape (1, H, W)
    with rasterio.open(vh_path) as f1:
        vh = f1.read()
    with rasterio.open(vv_path) as f2:
        vv = f2.read()
    # Stack in VH, VV order to match S1_MEAN / S1_STD
    s1_img = np.concatenate((vh, vv), 0).astype('float32')
    s1_img = torch.from_numpy(s1_img)
    return s1_transform(s1_img).squeeze(0)  # (2, 224, 224)
```
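`K.augmentation.Normalize` standardizes each channel independently, x'_c = (x_c - mean_c) / std_c, which is why each statistics list needs exactly one entry per channel, in the same order as the stacked bands. The NumPy sketch below reproduces only this normalization step (not Kornia's code, and without the crop) using the Sentinel-1 statistics; `normalize_per_channel` is a hypothetical helper.

```python
import numpy as np

S1_MEAN = [166.36275909, 88.45542715]  # VH, VV
S1_STD = [64.83126309, 43.07350145]


def normalize_per_channel(img, mean, std):
    """Standardize a (C, H, W) array channel by channel."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    if img.shape[0] != mean.shape[0]:
        raise ValueError(f"expected {mean.shape[0]} channels, got {img.shape[0]}")
    # Broadcasting applies mean[c] and std[c] to every pixel of channel c
    return (img.astype(np.float32) - mean) / std


# An image whose pixels equal the channel means normalizes to zeros
img = np.stack([np.full((4, 4), 166.36275909), np.full((4, 4), 88.45542715)])
out = normalize_per_channel(img, S1_MEAN, S1_STD)
print(out.shape, float(np.abs(out).max()))  # (2, 4, 4) 0.0
```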
```python
import matplotlib.pyplot as plt

# Load Sentinel-1 images from the given example data
C = 2  # can be 2, 3, 4, 6, 9, 12, 13, 202, or any number if you can provide the wavelengths of the channels
vh_path = './data/s1/vh/1869_3575.png'
vv_path = './data/s1/vv/1869_3575.png'
s1_img = preprocess_s1(vh_path, vv_path)

fig, axes = plt.subplots(nrows=1, ncols=C, figsize=(10, 10))
for i, axis in enumerate(axes):
    axis.imshow(s1_img[i])
    axis.axis("off")

# Add a batch dimension: (1, 2, 224, 224)
s1_img = s1_img.unsqueeze(0).cuda()
wavelengths = [3.75, 3.75]  # one value per channel (VH, VV)
with torch.no_grad():
    out_feat = vit_model.forward_features(s1_img, wave_list=wavelengths)
    out_logits = vit_model.forward(s1_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)
```
```python
import glob

C = 9
s2_transform = DataAugmentation(mean=S2_MEAN, std=S2_STD)


def preprocess_s2(img_path):
    chs = []
    # glob() returns files in arbitrary order; sort them so the channel order is
    # reproducible, and check that it matches S2_MEAN / S2_STD and the wavelength list
    s2_files = sorted(glob.glob(f"{img_path}/*/*.png"))
    for path in s2_files:
        with rasterio.open(path) as f:
            chs.append(f.read())
    s2_img = np.concatenate(chs, 0).astype("float32")
    s2_img = torch.from_numpy(s2_img)
    return s2_transform(s2_img).squeeze(0)  # (9, 224, 224)


# Visualize the Sentinel-2 imagery
s2_img = preprocess_s2("data/s2/S2A_MSIL1C_20170528T050611_N0205_R076_T44NMM_20170528T050606")
fig, axes = plt.subplots(nrows=1, ncols=C, figsize=(10, 10))
for i, axis in enumerate(axes):
    axis.imshow(s2_img[i])
    axis.axis("off")
s2_img = s2_img.unsqueeze(0).cuda()  # (1, 9, 224, 224)
```
When reading these PNG files, rasterio may emit a `NotGeoreferencedWarning` ("Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned."). The example images carry no georeferencing, which does not affect the pixel values read here, so the warning can be ignored.
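Because the channel order produced by `glob` in `preprocess_s2` must line up with the wavelength list, it can be safer to derive `wavelengths` from the band names than to type it by hand. The sketch below assumes each band file sits in a folder named after its Sentinel-2 band (`B02`, `B03`, ...); that layout and the helper `wavelengths_from_paths` are hypothetical. The central wavelengths (in µm) are the nine Sentinel-2 values used in this demo.

```python
import os

# Central wavelengths (µm) of the nine Sentinel-2 bands used in this demo
S2_BAND_WAVELENGTHS = {
    "B02": 0.49, "B03": 0.56, "B04": 0.665, "B05": 0.705, "B06": 0.74,
    "B07": 0.783, "B08": 0.842, "B11": 1.61, "B12": 2.19,
}


def wavelengths_from_paths(paths):
    """Return one wavelength per band file, in the same order as `paths`.

    Hypothetical layout: ".../<band name>/<tile>.png", e.g. ".../B04/tile.png".
    """
    wavelengths = []
    for path in paths:
        band = os.path.basename(os.path.dirname(path))
        if band not in S2_BAND_WAVELENGTHS:
            raise KeyError(f"unknown band folder: {band!r}")
        wavelengths.append(S2_BAND_WAVELENGTHS[band])
    return wavelengths


paths = ["s2/B04/x.png", "s2/B03/x.png", "s2/B02/x.png", "s2/B11/x.png"]
print(wavelengths_from_paths(paths))  # [0.665, 0.56, 0.49, 1.61]
```

Computing the list from the same file list that is stacked into the tensor keeps channels and wavelengths aligned whatever order the files come in.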
```python
# Central wavelengths (µm), one per channel
wavelengths = [0.665, 0.56, 0.49, 0.705, 0.74, 0.783, 0.842, 1.61, 2.19]
with torch.no_grad():
    out_feat = vit_model.forward_features(s2_img, wave_list=wavelengths)
    out_logits = vit_model.forward(s2_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)

# Let's only keep the first 5 channels (and their wavelengths)
wavelengths = [0.665, 0.56, 0.49, 0.705, 0.74]
with torch.no_grad():
    out_feat = vit_model.forward_features(s2_img[:, :5, ...], wave_list=wavelengths)
    out_logits = vit_model.forward(s2_img[:, :5, ...], wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)
```
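Keeping the first five channels works because the wavelength list is truncated in the same way. For an arbitrary subset, index the channels and the wavelengths with the same indices. The NumPy sketch below uses a random stand-in for `s2_img` and a hypothetical helper `select_bands`; the same list indexing also works on a PyTorch tensor.

```python
import numpy as np

S2_WAVELENGTHS = [0.665, 0.56, 0.49, 0.705, 0.74, 0.783, 0.842, 1.61, 2.19]


def select_bands(img, wavelengths, indices):
    """Keep the given channels of a (B, C, H, W) array and their wavelengths."""
    if img.shape[1] != len(wavelengths):
        raise ValueError("need one wavelength per channel")
    return img[:, indices], [wavelengths[i] for i in indices]


s2_img = np.random.rand(1, 9, 224, 224).astype(np.float32)  # stand-in data
# Keep red, green, blue and near-infrared (the 0.842 µm band)
rgbn, rgbn_wl = select_bands(s2_img, S2_WAVELENGTHS, [0, 1, 2, 6])
print(rgbn.shape, rgbn_wl)  # (1, 4, 224, 224) [0.665, 0.56, 0.49, 0.842]
```

In the demo, the resulting pair would be passed as `vit_model.forward_features(rgbn, wave_list=rgbn_wl)` once `rgbn` is a CUDA tensor.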
```python
import cv2

naip_transform = DataAugmentation(mean=NAIP_MEAN, std=NAIP_STD)


def preprocess_rgb(img_path):
    with rasterio.open(img_path) as f:
        rgb_img = f.read().astype("float32")  # (3, H, W)
    rgb_img = torch.from_numpy(rgb_img)
    return naip_transform(rgb_img).squeeze(0)  # (3, 224, 224)


rgb_path = 'data/naip/36861_49963.png'
naip_img = preprocess_rgb(rgb_path)
# OpenCV loads images as BGR; convert to RGB for display
plt.imshow(cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB))
plt.axis("off")

naip_img = naip_img.unsqueeze(0).cuda()  # (1, 3, 224, 224)
wavelengths = [0.665, 0.56, 0.49]  # R, G, B
with torch.no_grad():
    out_feat = vit_model.forward_features(naip_img, wave_list=wavelengths)
    out_logits = vit_model.forward(naip_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)
```
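Usage for hyperspectral images is the same as for the other modalities: normalize and resize the bands as above, stack them into a (1, C, 224, 224) tensor, and pass one wavelength (in µm) per band as `wave_list`.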
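For a hyperspectral cube, the main extra work is the wavelength list, which needs one central wavelength per band. When only the spectral range and band count are at hand, evenly spaced band centers are a common approximation; real band centers should come from the sensor metadata whenever available. In the sketch below, the 0.4-2.5 µm range and the 202-band count are illustrative assumptions rather than the specification of a particular sensor, and both helpers are hypothetical.

```python
import numpy as np


def band_centers(start_um, stop_um, num_bands):
    """Centers (µm) of `num_bands` equal-width bands covering [start_um, stop_um]."""
    width = (stop_um - start_um) / num_bands
    return (start_um + width * (np.arange(num_bands) + 0.5)).tolist()


def nm_to_um(wavelengths_nm):
    """Convert nanometres to micrometres, the unit used for `wave_list`."""
    return [w / 1000.0 for w in wavelengths_nm]


wavelengths = band_centers(0.4, 2.5, 202)  # hypothetical sensor range
print(len(wavelengths), round(wavelengths[0], 4), round(wavelengths[-1], 4))
print(nm_to_um([490, 560, 665]))  # [0.49, 0.56, 0.665]
```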
Alternatively, DOFA can be used via the TorchGeo library:

```python
import torch
from torchgeo.models import DOFABase16_Weights, dofa_base_patch16_224

# Example NAIP image (wavelengths in µm)
x = torch.rand(2, 4, 224, 224)
wavelengths = [0.48, 0.56, 0.64, 0.81]

# Use pre-trained model weights
model = dofa_base_patch16_224(weights=DOFABase16_Weights.DOFA_MAE)

# Make a prediction (model may need to be fine-tuned first)
y = model(x, wavelengths)
```
If you find DOFA useful in your research, please cite our paper:
```bibtex
@article{xiong2024neural,
  title={Neural Plasticity-Inspired Foundation Model for Observing the {Earth} Crossing Modalities},
  author={Xiong, Zhitong and Wang, Yi and Zhang, Fahong and Stewart, Adam J and Hanna, Jo{\"e}lle and Borth, Damian and Papoutsis, Ioannis and Saux, Bertrand Le and Camps-Valls, Gustau and Zhu, Xiao Xiang},
  journal={arXiv preprint arXiv:2403.15356},
  year={2024}
}
```