Skip to content

Commit 7b1bfbd

Browse files
committed
FullyConnectedClassificationHead
1 parent cfbdf1f commit 7b1bfbd

File tree

1 file changed

+34
-3
lines changed

1 file changed

+34
-3
lines changed

pytorch_toolbelt/modules/heads/classification_heads.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List
1+
from typing import List, Union, Tuple, Mapping
22

3+
import torch
34
from torch import nn, Tensor
45

56
from pytorch_toolbelt.modules.interfaces import AbstractHead, FeatureMapsSpecification
@@ -8,6 +9,7 @@
89
"GlobalAveragePoolingClassificationHead",
910
"GlobalMaxPoolingClassificationHead",
1011
"GenericPoolingClassificationHead",
12+
"FullyConnectedClassificationHead",
1113
]
1214

1315

@@ -37,7 +39,11 @@ def forward(self, feature_maps: List[Tensor]) -> Tensor:
3739

3840
class GlobalMaxPoolingClassificationHead(GenericPoolingClassificationHead):
3941
def __init__(
40-
self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
42+
self,
43+
input_spec: FeatureMapsSpecification,
44+
num_classes: int,
45+
dropout_rate: float = 0.0,
46+
feature_map_index: int = -1,
4147
):
4248
pooling = nn.AdaptiveMaxPool2d((1, 1))
4349
super().__init__(
@@ -51,7 +57,11 @@ def __init__(
5157

5258
class GlobalAveragePoolingClassificationHead(GenericPoolingClassificationHead):
5359
def __init__(
54-
self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
60+
self,
61+
input_spec: FeatureMapsSpecification,
62+
num_classes: int,
63+
dropout_rate: float = 0.0,
64+
feature_map_index: int = -1,
5565
):
5666
pooling = nn.AdaptiveAvgPool2d((1, 1))
5767

@@ -62,3 +72,24 @@ def __init__(
6272
dropout_rate=dropout_rate,
6373
feature_map_index=feature_map_index,
6474
)
75+
76+
77+
class FullyConnectedClassificationHead(AbstractHead):
    """Classification head that flattens a feature map and applies one linear layer.

    Unlike the pooling-based heads in this module, no spatial pooling is done:
    the selected feature map is flattened to ``(batch, C * H * W)`` and projected
    to ``num_classes`` logits. ``nn.LazyLinear`` defers weight allocation until
    the first forward pass, so the flattened feature size does not need to be
    known in advance.
    """

    def __init__(
        self,
        input_spec: FeatureMapsSpecification,
        num_classes: int,
        dropout_rate: float = 0.0,
        feature_map_index: int = -1,
    ):
        """
        :param input_spec: Specification of the incoming feature maps.
        :param num_classes: Number of output classes (size of the logits vector).
        :param dropout_rate: Dropout probability applied to the flattened features.
        :param feature_map_index: Which feature map to classify (default: the last one).
        """
        super().__init__(input_spec)
        self.feature_map_index = feature_map_index
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.LazyLinear(num_classes)

    def forward(self, feature_maps: List[Tensor]) -> Tensor:
        """Return ``(batch, num_classes)`` logits for the selected feature map."""
        selected = feature_maps[self.feature_map_index]
        flat = torch.flatten(selected, start_dim=1)
        return self.classifier(self.dropout(flat))

0 commit comments

Comments
 (0)