@@ -1,13 +1,15 @@
-from typing import List
+from typing import List, Union, Tuple, Mapping
 
+import torch
 from torch import nn, Tensor
 
 from pytorch_toolbelt.modules.interfaces import AbstractHead, FeatureMapsSpecification
 
 __all__ = [
     "GlobalAveragePoolingClassificationHead",
     "GlobalMaxPoolingClassificationHead",
     "GenericPoolingClassificationHead",
+    "FullyConnectedClassificationHead",
 ]
 
 
@@ -37,7 +39,11 @@ def forward(self, feature_maps: List[Tensor]) -> Tensor:
 
 class GlobalMaxPoolingClassificationHead(GenericPoolingClassificationHead):
     def __init__(
-        self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
+        self,
+        input_spec: FeatureMapsSpecification,
+        num_classes: int,
+        dropout_rate: float = 0.0,
+        feature_map_index: int = -1,
     ):
         pooling = nn.AdaptiveMaxPool2d((1, 1))
         super().__init__(
@@ -51,7 +57,11 @@ def __init__(
 
 class GlobalAveragePoolingClassificationHead(GenericPoolingClassificationHead):
     def __init__(
-        self, input_spec: FeatureMapsSpecification, num_classes: int, dropout_rate: float = 0.0, feature_map_index: int = -1
+        self,
+        input_spec: FeatureMapsSpecification,
+        num_classes: int,
+        dropout_rate: float = 0.0,
+        feature_map_index: int = -1,
     ):
         pooling = nn.AdaptiveAvgPool2d((1, 1))
         super().__init__(
@@ -62,3 +72,24 @@ def __init__(
             dropout_rate=dropout_rate,
             feature_map_index=feature_map_index,
         )
+
+
+class FullyConnectedClassificationHead(AbstractHead):
+    def __init__(
+        self,
+        input_spec: FeatureMapsSpecification,
+        num_classes: int,
+        dropout_rate: float = 0.0,
+        feature_map_index: int = -1,
+    ):
+        super().__init__(input_spec)
+        self.feature_map_index = feature_map_index
+        self.dropout = nn.Dropout(dropout_rate)
+        self.classifier = nn.LazyLinear(num_classes)
+
+    def forward(self, feature_maps: List[Tensor]) -> Tensor:
+        x = feature_maps[self.feature_map_index]
+        x = torch.flatten(x, start_dim=1)
+        x = self.dropout(x)
+        x = self.classifier(x)
+        return x
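
A minimal usage sketch for the new `FullyConnectedClassificationHead`. The `channels`/`strides` keywords passed to `FeatureMapsSpecification` are an assumption about its constructor (it is only imported, not defined, in this diff), and the import path for the new head is left out because the file path is not part of the diff:

```python
import torch
from pytorch_toolbelt.modules.interfaces import FeatureMapsSpecification

# Assumed constructor: a spec describing one 512-channel, stride-32 feature map.
spec = FeatureMapsSpecification(channels=(512,), strides=(32,))

# FullyConnectedClassificationHead is the class added in this diff.
head = FullyConnectedClassificationHead(input_spec=spec, num_classes=10, dropout_rate=0.2)

# A batch of 2 feature maps of shape (512, 7, 7); flattening gives 512 * 7 * 7 = 25088
# features, so nn.LazyLinear materializes a (10, 25088) weight on this first forward.
feature_maps = [torch.randn(2, 512, 7, 7)]
logits = head(feature_maps)
print(logits.shape)  # torch.Size([2, 10])
```

Unlike the pooling heads, the lazy linear layer ties this head to a fixed spatial resolution: once its weight is materialized on the first forward pass, feature maps with a different height or width will fail with a shape mismatch.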