Skip to content

Commit

Permalink
FullyConnectedClassificationHead
Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Feb 12, 2024
1 parent cfbdf1f commit 7b1bfbd
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions pytorch_toolbelt/modules/heads/classification_heads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from typing import List, Union, Tuple, Mapping

import torch
from torch import nn, Tensor

from pytorch_toolbelt.modules.interfaces import AbstractHead, FeatureMapsSpecification
Expand All @@ -8,6 +9,7 @@
"GlobalAveragePoolingClassificationHead",
"GlobalMaxPoolingClassificationHead",
"GenericPoolingClassificationHead",
"FullyConnectedClassificationHead",
]


Expand Down Expand Up @@ -37,7 +39,11 @@ def forward(self, feature_maps: List[Tensor]) -> Tensor:

class GlobalMaxPoolingClassificationHead(GenericPoolingClassificationHead):
def __init__(
self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
self,
input_spec: FeatureMapsSpecification,
num_classes: int,
dropout_rate: float = 0.0,
feature_map_index: int = -1,
):
pooling = nn.AdaptiveMaxPool2d((1, 1))
super().__init__(
Expand All @@ -51,7 +57,11 @@ def __init__(

class GlobalAveragePoolingClassificationHead(GenericPoolingClassificationHead):
def __init__(
self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
self,
input_spec: FeatureMapsSpecification,
num_classes: int,
dropout_rate: float = 0.0,
feature_map_index: int = -1,
):
pooling = nn.AdaptiveAvgPool2d((1, 1))

Expand All @@ -62,3 +72,24 @@ def __init__(
dropout_rate=dropout_rate,
feature_map_index=feature_map_index,
)


class FullyConnectedClassificationHead(AbstractHead):
    """Classification head that flattens one feature map and projects it to class logits.

    Picks a single feature map from the incoming list, flattens everything
    except the batch dimension, applies dropout, and feeds the result through
    a single linear layer. ``nn.LazyLinear`` is used because the flattened
    feature size depends on the spatial dimensions of the input, which are
    not known at construction time.
    """

    def __init__(
        self,
        input_spec: FeatureMapsSpecification,
        num_classes: int,
        dropout_rate: float = 0.0,
        feature_map_index: int = -1,
    ):
        """
        :param input_spec: Specification of the incoming feature maps.
        :param num_classes: Number of output classes (size of the logits vector).
        :param dropout_rate: Dropout probability applied to the flattened features.
        :param feature_map_index: Which feature map from the input list to classify
            (default -1, i.e. the last / deepest one).
        """
        super().__init__(input_spec)
        self.feature_map_index = feature_map_index
        self.dropout = nn.Dropout(dropout_rate)
        # Lazily infers in_features from the first forward pass.
        self.classifier = nn.LazyLinear(num_classes)

    def forward(self, feature_maps: List[Tensor]) -> Tensor:
        """Return class logits of shape [batch_size, num_classes]."""
        selected = feature_maps[self.feature_map_index]
        flattened = torch.flatten(selected, start_dim=1)
        return self.classifier(self.dropout(flattened))

0 comments on commit 7b1bfbd

Please sign in to comment.