diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py
index 5646ff8f..a725a221 100644
--- a/pangaea/decoders/upernet.py
+++ b/pangaea/decoders/upernet.py
@@ -173,8 +173,19 @@ def _forward_feature(self, inputs):
         feats = self.fpn_bottleneck(fpn_outs)
 
         return feats
-    def forward(self, img, output_shape=None):
-        """Forward function."""
+    def forward(self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None) -> torch.Tensor:
+        """Compute the segmentation output.
+
+        Args:
+            img (dict[str, torch.Tensor]): input data structured as a dictionary:
+                img = {modality1: tensor1, modality2: tensor2, ...}, e.g. img = {"optical": tensor1, "sar": tensor2},
+                where tensor1 and tensor2 have shape (B C T=1 H W), with C the number of encoder's bands for the given modality.
+            output_shape (torch.Size | None, optional): output's spatial dims (H, W) (equal to the target spatial dims).
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H', W') corresponding to the output_shape.
+        """
         # img[modality] of shape [B C T=1 H W]
 
         if self.encoder.multi_temporal:
@@ -206,6 +217,8 @@ def forward(self, img, output_shape=None):
         # fixed bug just for optical single modality
         if output_shape is None:
             output_shape = img[list(img.keys())[0]].shape[-2:]
+
+        # interpolate to the target spatial dims
         output = F.interpolate(output, size=output_shape, mode="bilinear")
 
         return output
@@ -249,6 +262,19 @@ def __init__(
     def forward(
         self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None
     ) -> torch.Tensor:
+        """Compute the segmentation output for multi-temporal data.
+
+        Args:
+            img (dict[str, torch.Tensor]): input data structured as a dictionary:
+                img = {modality1: tensor1, modality2: tensor2, ...}, e.g. img = {"optical": tensor1, "sar": tensor2},
+                where tensor1 and tensor2 have shape (B C T H W), with C the number of encoder's bands for the given
+                modality and T the number of time steps.
+            output_shape (torch.Size | None, optional): output's spatial dims (H, W) (equal to the target spatial dims).
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H', W') corresponding to the output_shape.
+        """
         # If the encoder handles multi_temporal we feed it with the input
         if self.encoder.multi_temporal:
             if not self.finetune:
@@ -293,6 +319,8 @@ def forward(
 
         if output_shape is None:
             output_shape = img[list(img.keys())[0]].shape[-2:]
+
+        # interpolate to the target spatial dims
         output = F.interpolate(output, size=output_shape, mode="bilinear")
 
         return output
diff --git a/pangaea/encoders/base.py b/pangaea/encoders/base.py
index 0ad3c238..7075f5c8 100644
--- a/pangaea/encoders/base.py
+++ b/pangaea/encoders/base.py
@@ -11,11 +11,26 @@
 
 
 class DownloadProgressBar:
-    def __init__(self, text="Downloading..."):
+    """Download progress bar."""
+
+    def __init__(self, text: str = "Downloading...") -> None:
+        """Initialize the DownloadProgressBar.
+
+        Args:
+            text (str, optional): progress bar text. Defaults to "Downloading...".
+        """
         self.pbar = None
         self.text = text
 
-    def __call__(self, block_num, block_size, total_size):
+    def __call__(self, block_num: int, block_size: int, total_size: int) -> None:
+        """Update the progress bar.
+
+        Args:
+            block_num (int): number of blocks transferred so far.
+            block_size (int): size of each block in bytes.
+            total_size (int): total size of the download in bytes.
+        """
+
         if self.pbar is None:
             self.pbar = tqdm.tqdm(
                 desc=self.text,
@@ -116,21 +131,30 @@ def freeze(self) -> None:
         for param in self.parameters():
             param.requires_grad = False
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Compute the forward pass of the encoder.
+    def forward(self, x: dict[str, torch.Tensor]) -> list[torch.Tensor]:
+        """Forward pass of the encoder.
 
         Args:
-            x (torch.Tensor): input image.
-
+            x (dict[str, torch.Tensor]): encoder's input structured as a dictionary:
+                x = {modality1: tensor1, modality2: tensor2, ...}, e.g. x = {"optical": tensor1, "sar": tensor2}.
+                If the encoder is multi-temporal (self.multi_temporal == True), input tensor shape is (B C T H W) with
+                C the number of bands required by the encoder for the given modality and T the number of time steps.
+                If the encoder is not multi-temporal, input tensor shape is (B C H W) with C the number of bands
+                required by the encoder for the given modality.
         Raises:
             NotImplementedError: raise if the method is not implemented.
 
         Returns:
-            torch.Tensor: embedding generated by the encoder.
+            list[torch.Tensor]: list of the embeddings for each modality. For single-temporal encoders, the list's
+                elements are of shape (B, embed_dim, H', W'). For multi-temporal encoders, the list's elements are of
+                shape (B, C', T, H', W') with T the number of time steps if the encoder does not have any time-merging
+                strategy, else (B, C', H', W') if the encoder has a time-merging strategy (where C' == self.output_dim).
        """
         raise NotImplementedError
 
     def download_model(self) -> None:
+        """Download the model if the weights are not already downloaded.
+        """
         if self.download_url and not os.path.isfile(self.encoder_weights):
             # TODO: change this path
             os.makedirs("pretrained_models", exist_ok=True)