66: add encoders' input/output shapes to docstrings (#75)
* Update base encoder forward docstring with shapes

* Update SegUPerNet forward docstring

* Update SegMTUPerNet forward docstring with input/output shapes
gle-bellier authored Sep 29, 2024
1 parent 19d4310 commit 80ea64d
Showing 2 changed files with 61 additions and 9 deletions.
32 changes: 30 additions & 2 deletions pangaea/decoders/upernet.py
@@ -173,8 +173,19 @@ def _forward_feature(self, inputs):
feats = self.fpn_bottleneck(fpn_outs)
return feats

def forward(self, img, output_shape=None):
"""Forward function."""
def forward(self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None) -> torch.Tensor:
"""Compute the segmentation output.
Args:
img (dict[str, torch.Tensor]): input data structured as a dictionary:
img = {modality1: tensor1, modality2: tensor2, ...}, e.g. img = {"optical": tensor1, "sar": tensor2}.
with tensor1 and tensor2 of shape (B C T=1 H W) with C the number of encoder's bands for the given modality.
output_shape (torch.Size | None, optional): output's spatial dims (H, W) (equal to the target spatial dims).
Defaults to None.
Returns:
torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H', W') corresponding to output_shape.
"""

# img[modality] of shape [B C T=1 H W]
if self.encoder.multi_temporal:
@@ -206,6 +217,8 @@ def forward(self, img, output_shape=None):
# fixed bug just for optical single modality
if output_shape is None:
output_shape = img[list(img.keys())[0]].shape[-2:]

# interpolate to the target spatial dims
output = F.interpolate(output, size=output_shape, mode="bilinear")

return output
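
The output_shape fallback above simply reuses the spatial dims of the first modality in the input dict. A minimal, self-contained sketch of that shape convention (hypothetical tensors, not part of the commit):

import torch

# Hypothetical single-temporal input: batch of 4 optical images, 6 bands, T=1, 224x224 pixels.
img = {"optical": torch.randn(4, 6, 1, 224, 224)}  # (B C T=1 H W)

# Same fallback as in the diff: take the (H, W) of the first modality.
output_shape = img[list(img.keys())[0]].shape[-2:]
print(output_shape)  # torch.Size([224, 224]) -> the (H', W') of the (B, num_classes, H', W') output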
@@ -249,6 +262,19 @@ def __init__(
def forward(
self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None
) -> torch.Tensor:
"""Compute the segmentation output for multi-temporal data.
Args:
img (dict[str, torch.Tensor]): input data structured as a dictionary:
img = {modality1: tensor1, modality2: tensor2, ...}, e.g. img = {"optical": tensor1, "sar": tensor2}.
with tensor1 and tensor2 of shape (B C T H W) with C the number of encoder's bands for the given modality,
and T the number of time steps.
output_shape (torch.Size | None, optional): output's spatial dims (H, W) (equal to the target spatial dims).
Defaults to None.
Returns:
torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H', W') corresponding to output_shape.
"""
# If the encoder handles multi-temporal data, we feed it the input directly
if self.encoder.multi_temporal:
if not self.finetune:
@@ -293,6 +319,8 @@ def forward(

if output_shape is None:
output_shape = img[list(img.keys())[0]].shape[-2:]

# interpolate to the target spatial dims
output = F.interpolate(output, size=output_shape, mode="bilinear")

return output
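
For the multi-temporal decoder the same fallback and bilinear upsampling apply; a hedged sketch of the final interpolation step with illustrative shapes (the raw decoder output is faked here, since building a full SegMTUPerNet is out of scope):

import torch
import torch.nn.functional as F

# Hypothetical multi-temporal input: 2 samples, 6 bands, 4 time steps, 128x128 pixels.
img = {"optical": torch.randn(2, 6, 4, 128, 128)}  # (B C T H W)

# Stand-in for the decoder's low-resolution logits (num_classes=10 is illustrative).
output = torch.randn(2, 10, 32, 32)

# As in the diff: default to the first modality's spatial dims, then upsample.
output_shape = img[list(img.keys())[0]].shape[-2:]
output = F.interpolate(output, size=output_shape, mode="bilinear")
print(output.shape)  # torch.Size([2, 10, 128, 128]) -> (B, num_classes, H', W')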
38 changes: 31 additions & 7 deletions pangaea/encoders/base.py
@@ -11,11 +11,26 @@


class DownloadProgressBar:
def __init__(self, text="Downloading..."):
"""Download progress bar.
"""
def __init__(self, text: str = "Downloading...") -> None:
"""Initialize the DownloadProgressBar.
Args:
text (str, optional): progress bar text. Defaults to "Downloading...".
"""
self.pbar = None
self.text = text

def __call__(self, block_num, block_size, total_size):
def __call__(self, block_num: int, block_size: int, total_size: int) -> None:
"""Update the progress bar.
Args:
block_num (int): number of blocks transferred so far.
block_size (int): size of each block in bytes.
total_size (int): total size of the download.
"""

if self.pbar is None:
self.pbar = tqdm.tqdm(
desc=self.text,
@@ -116,21 +131,30 @@ def freeze(self) -> None:
for param in self.parameters():
param.requires_grad = False

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Compute the forward pass of the encoder.
def forward(self, x: dict[str, torch.Tensor]) -> list[torch.Tensor]:
"""Forward pass of the encoder.
Args:
x (torch.Tensor): input image.
x (dict[str, torch.Tensor]): encoder's input structured as a dictionary:
x = {modality1: tensor1, modality2: tensor2, ...}, e.g. x = {"optical": tensor1, "sar": tensor2}.
If the encoder is multi-temporal (self.multi_temporal==True), input tensor shape is (B C T H W) with C the
number of bands required by the encoder for the given modality and T the number of time steps. If the
encoder is not multi-temporal, input tensor shape is (B C H W) with C the number of bands required by the
encoder for the given modality.
Raises:
NotImplementedError: raised if the method is not implemented.
Returns:
torch.Tensor: embedding generated by the encoder.
list[torch.Tensor]: list of the embeddings for each modality. For single-temporal encoders, the list's
elements are of shape (B, embed_dim, H', W'). For multi-temporal encoders, the list's elements are of shape
(B, C', T, H', W') with T the number of time steps if the encoder does not have any time-merging strategy,
else (B, C', H', W') if the encoder has a time-merging strategy (where C'==self.output_dim).
"""
raise NotImplementedError
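
To make the contract concrete, here is a minimal toy sketch (a plain nn.Module standing in for a single-temporal encoder; ToyEncoder and its conv projection are invented for illustration and are not part of pangaea):

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    multi_temporal = False  # single-temporal: expects (B C H W) per modality

    def __init__(self, in_bands: int = 6, embed_dim: int = 64) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_bands, embed_dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x: dict[str, torch.Tensor]) -> list[torch.Tensor]:
        # One embedding per modality, each of shape (B, embed_dim, H', W').
        return [self.conv(t) for t in x.values()]

enc = ToyEncoder()
feats = enc({"optical": torch.randn(2, 6, 128, 128)})
print(feats[0].shape)  # torch.Size([2, 64, 64, 64])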

def download_model(self) -> None:
"""Download the model if the weights are not already downloaded.
"""
if self.download_url and not os.path.isfile(self.encoder_weights):
# TODO: change this path
os.makedirs("pretrained_models", exist_ok=True)
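
As an aside, the (block_num, block_size, total_size) signature documented above for DownloadProgressBar.__call__ is exactly the reporthook protocol of urllib.request.urlretrieve, so an instance can be passed straight in. A hedged usage sketch (URL and filename are placeholders; this is not necessarily how download_model wires it up):

import os
import urllib.request

os.makedirs("pretrained_models", exist_ok=True)
url = "https://example.com/encoder_weights.pth"  # placeholder URL
urllib.request.urlretrieve(
    url,
    "pretrained_models/encoder_weights.pth",
    DownloadProgressBar("Downloading encoder weights"),  # class defined in base.py above
)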
