Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed MinMax normalization #112

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/preprocessing/reg_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ train:
preprocessor_cfg:
- _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder
- _target_: pangaea.engine.data_preprocessor.BandFilter
- _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd
- _target_: pangaea.engine.data_preprocessor.NormalizeMinMax
- _target_: pangaea.engine.data_preprocessor.BandPadding

val:
_target_: pangaea.engine.data_preprocessor.Preprocessor
preprocessor_cfg:
- _target_: pangaea.engine.data_preprocessor.BandFilter
- _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd
- _target_: pangaea.engine.data_preprocessor.NormalizeMinMax
- _target_: pangaea.engine.data_preprocessor.BandPadding

test:
_target_: pangaea.engine.data_preprocessor.Preprocessor
preprocessor_cfg:
- _target_: pangaea.engine.data_preprocessor.BandFilter
- _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd
- _target_: pangaea.engine.data_preprocessor.NormalizeMinMax
- _target_: pangaea.engine.data_preprocessor.BandPadding
13 changes: 6 additions & 7 deletions pangaea/engine/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,11 @@ def __init__(
self,
**meta,
) -> None:
"""Initialize the NormalizeMeanStd.
"""Initialize the NormalizeMinMax.
Args:
meta: statistics/info of the input data and target encoder
data_min: global maximum value of incoming data
data_sax: global minimum value of incoming data
data_min: global minimum value of incoming data
data_max: global maximum value of incoming data
"""
super().__init__()

Expand All @@ -365,7 +365,7 @@ def __init__(
def __call__(
self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]
) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Apply Mean/Std Normalization to the data.
"""Apply Min/Max Normalization to the data.
Args:
data (dict): input data.
Returns:
Expand All @@ -381,9 +381,8 @@ def __call__(
"""

for k in self.data_min.keys():
size = (-1,) + data["image"][k].shape[1:]
data["image"][k].sub_(self.data_min[k].view(size)).div_(
(self.data_max[k] - self.data_min[k]).view(size)
data["image"][k].sub_(self.data_min[k].view(-1, 1, 1, 1)).div_(
(self.data_max[k] - self.data_min[k]).view(-1, 1, 1, 1)
)
return data

Expand Down