Skip to content

Commit

Permalink
Merge pull request #74 from yurujaja/73-fix-multi-temporal-output
Browse files Browse the repository at this point in the history
73 fix multi temporal output
  • Loading branch information
VMarsocci authored Sep 27, 2024
2 parents 774d4c4 + b1dd07b commit f0f4676
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
3 changes: 2 additions & 1 deletion configs/encoder/dofa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ wave_list:
ASC_VH: 3.75
DSC_VV: 3.75
DSC_VH: 3.75
VV-VH: 3.75

output_layers:
- 3
- 5
- 7
- 11
- 11
14 changes: 11 additions & 3 deletions pangaea/decoders/upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def forward(self, img, output_shape=None):
feat = self.encoder(img)
else:
feat = self.encoder(img)

# multi_temporal models can return either (B C' T=1 H' W')
# or (B C' H' W'), we need (B C' H' W')
if feat[0].ndim == 5:
feat = [f.squeeze(-3) for f in feat]

else:
# remove the temporal dim
# [B C T=1 H W] -> [B C H W]
Expand Down Expand Up @@ -250,6 +256,11 @@ def forward(
feats = self.encoder(img)
else:
feats = self.encoder(img)
# multi_temporal models can return either (B C' T H' W')
# or (B C' H' W') via internal merging strategy
# if we have (B C' H' W') we need to skip multi_temporal_strategy
if feats[0].ndim == 4:
self.multi_temporal_strategy = None

# If the encoder handles only single temporal data, we apply multi_temporal_strategy
else:
Expand Down Expand Up @@ -350,9 +361,6 @@ def forward(
else:
feat1, feat2 = self.encoder_forward(img)

print("LEN ", len(feat1))
print("SHHAPE ", feat1[0].shape)

if self.strategy == "diff":
feat = [f2 - f1 for f1, f2 in zip(feat1, feat2)]
elif self.strategy == "concat":
Expand Down

0 comments on commit f0f4676

Please sign in to comment.