Skip to content

Commit

Permalink
GlobalMaxAvgPoolingClassificationHead
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Mar 1, 2024
1 parent d7dd14f commit 1634a3b
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions pytorch_toolbelt/modules/heads/classification_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import torch
from torch import nn, Tensor

from pytorch_toolbelt.modules import instantiate_activation_block
from pytorch_toolbelt.modules.interfaces import AbstractHead, FeatureMapsSpecification

__all__ = [
"GlobalAveragePoolingClassificationHead",
"GlobalMaxPoolingClassificationHead",
"GenericPoolingClassificationHead",
"FullyConnectedClassificationHead",
"GlobalMaxAvgPoolingClassificationHead",
]


Expand Down Expand Up @@ -74,6 +76,43 @@ def __init__(
)


class GlobalMaxAvgPoolingClassificationHead(AbstractHead):
def __init__(
self,
*,
input_spec: FeatureMapsSpecification,
num_classes: int,
activation: str,
dropout_rate: float = 0.0,
feature_map_index: int = -1,
):
super().__init__(input_spec)
self.max_pooling = nn.AdaptiveMaxPool2d((1, 1))
self.avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.feature_map_index = feature_map_index
num_channels = input_spec.channels[self.feature_map_index]
self.bottleneck = nn.Sequential(
nn.BatchNorm1d(num_channels * 2),
nn.Linear(num_channels * 2, num_channels),
instantiate_activation_block(activation, inplace=True),
nn.Dropout(dropout_rate),
nn.BatchNorm1d(num_channels),
nn.Linear(num_channels, num_channels),
instantiate_activation_block(activation, inplace=True),
nn.Dropout(dropout_rate),
)
self.classifier = nn.Linear(num_channels, num_classes)

def forward(self, feature_maps: List[Tensor]) -> Tensor:
x = feature_maps[self.feature_map_index]
x_max = self.max_pooling(x).flatten(start_dim=1)
x_avg = self.avg_pooling(x).flatten(start_dim=1)
x = torch.cat([x_max, x_avg], dim=1)
x = self.bottleneck(x)
x = self.classifier(x)
return x


class FullyConnectedClassificationHead(AbstractHead):
def __init__(
self,
Expand Down

0 comments on commit 1634a3b

Please sign in to comment.