99import torch
1010import torch .nn .functional as F
1111
12- __all__ = ['CAM' , 'ScoreCAM' ]
12+ __all__ = ['CAM' , 'ScoreCAM' , 'SSCAM' ]
1313
1414
1515class _CAM (object ):
@@ -128,7 +128,8 @@ class CAM(_CAM):
128128
129129 where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
130130 position :math:`(x, y)`,
131- and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k`.
131+ and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k` in the fully
132+ connected layer..
132133
133134 Example::
134135 >>> from torchvision.models import resnet18
@@ -172,18 +173,18 @@ class ScoreCAM(_CAM):
172173 with the coefficient :math:`w_k^{(c)}` being defined as:
173174
174175 .. math::
175- w_k^{(c)} = softmax(Y^{(c)}(M ) - Y^{(c)}(X_b))
176+ w_k^{(c)} = softmax(Y^{(c)}(M_k ) - Y^{(c)}(X_b))
176177
177178 where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
178179 position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
179180 for input :math:`X`, :math:`X_b` is a baseline image,
180- and :math:`M ` is defined as follows:
181+ and :math:`M_k ` is defined as follows:
181182
182183 .. math::
183- M = \\ Big(\\ frac{M^{(d)} - \\ min M^{(d)}}{\\ max M^{(d)} - \\ min M^{(d)}} \\ odot X \\ Big)_{1 \\ leq d \\ leq D}
184+ M_k = \\ frac{U(A_k) - \\ min\\ limits_m U(A_m)}{\\ max\\ limits_m U(A_m) - \\ min\\ limits_m U(A_m)})
185+ \\ odot X
184186
185- where :math:`\\ odot` refers to the element-wise multiplication, :math:`M^{(d)}` is the upsampled version of
186- :math:`A_d` on node :math:`d`, and :math:`D` is the number of channels on the target convolutional layer.
187+ where :math:`\\ odot` refers to the element-wise multiplication and :math:`U` is the upsampling operation.
187188
188189 Example::
189190 >>> from torchvision.models import resnet18
@@ -222,12 +223,12 @@ def _store_input(self, module, input):
222223 def _get_weights (self , class_idx , scores = None ):
223224 """Computes the weight coefficients of the hooked activation maps"""
224225
225- # Upsample activation to input_size
226- # 1 * O * M * N
227- upsampled_a = F .interpolate (self .hook_a , self ._input .shape [- 2 :], mode = 'bilinear' , align_corners = False )
226+ # Normalize the activation
227+ upsampled_a = self ._normalize (self .hook_a )
228228
229- # Normalize it
230- upsampled_a = self ._normalize (upsampled_a )
229+ # Upsample it to input_size
230+ # 1 * O * M * N
231+ upsampled_a = F .interpolate (upsampled_a , self ._input .shape [- 2 :], mode = 'bilinear' , align_corners = False )
231232
232233 # Use it as a mask
233234 # O * I * H * W
@@ -253,3 +254,102 @@ def _get_weights(self, class_idx, scores=None):
253254
254255 def __repr__ (self ):
255256 return f"{ self .__class__ .__name__ } (batch_size={ self .bs } )"
257+
258+
259+ class SSCAM (ScoreCAM ):
260+ """Implements a class activation map extractor as described in `"SS-CAM: Smoothed Score-CAM for
261+ Sharper Visual Feature Localization" <https://arxiv.org/pdf/2006.14255.pdf>`_.
262+
263+ The localization map is computed as follows:
264+
265+ .. math::
266+ L^{(c)}_{SS-CAM}(x, y) = ReLU\\ Big(\\ sum\\ limits_k w_k^{(c)} A_k(x, y)\\ Big)
267+
268+ with the coefficient :math:`w_k^{(c)}` being defined as:
269+
270+ .. math::
271+ w_k^{(c)} = \\ frac{1}{N} \\ sum\\ limits_1^N softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b))
272+
273+ where :math:`N` is the number of samples used to smooth the weights,
274+ :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
275+ position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
276+ for input :math:`X`, :math:`X_b` is a baseline image,
277+ and :math:`M_k` is defined as follows:
278+
279+ .. math::
280+ M_k = \\ Bigg(\\ frac{U(A_k) - \\ min\\ limits_m U(A_m)}{\\ max\\ limits_m U(A_m) - \\ min\\ limits_m U(A_m)} +
281+ \\ delta\\ Bigg) \\ odot X
282+
283+ where :math:`\\ odot` refers to the element-wise multiplication, :math:`U` is the upsampling operation,
284+ :math:`\\ delta \\ sim \\ mathcal{N}(0, \\ sigma^2)` is the random noise that follows a 0-mean gaussian distribution
285+ with a standard deviation of :math:`\\ sigma`.
286+
287+ Example::
288+ >>> from torchvision.models import resnet18
289+ >>> from torchcam.cams import SSCAM
290+ >>> model = resnet18(pretrained=True).eval()
291+ >>> cam = SSCAM(model, 'layer4', 'conv1')
292+ >>> with torch.no_grad(): out = model(input_tensor)
293+ >>> cam(class_idx=100)
294+
295+ Args:
296+ model (torch.nn.Module): input model
297+ conv_layer (str): name of the last convolutional layer
298+ input_layer (str): name of the first layer
299+ batch_size (int, optional): batch size used to forward masked inputs
300+ num_samples (int, optional): number of noisy samples used for weight computation
301+ std (float, optional): standard deviation of the noise added to the normalized activation
302+ """
303+
304+ hook_a = None
305+ hook_handles = []
306+
307+ def __init__ (self , model , conv_layer , input_layer , batch_size = 32 , num_samples = 35 , std = 2.0 ):
308+
309+ super ().__init__ (model , conv_layer , input_layer , batch_size )
310+
311+ self .num_samples = num_samples
312+ self .std = std
313+ self ._distrib = torch .distributions .normal .Normal (0 , self .std )
314+
315+ def _get_weights (self , class_idx , scores = None ):
316+ """Computes the weight coefficients of the hooked activation maps"""
317+
318+ # Normalize the activation
319+ upsampled_a = self ._normalize (self .hook_a )
320+
321+ # Upsample it to input_size
322+ # 1 * O * M * N
323+ upsampled_a = F .interpolate (upsampled_a , self ._input .shape [- 2 :], mode = 'bilinear' , align_corners = False )
324+
325+ # Use it as a mask
326+ # O * I * H * W
327+ upsampled_a = upsampled_a .squeeze (0 ).unsqueeze (1 )
328+
329+ # Initialize weights
330+ weights = torch .zeros (upsampled_a .shape [0 ], dtype = upsampled_a .dtype ).to (device = upsampled_a .device )
331+
332+ # Disable hook updates
333+ self ._hooks_enabled = False
334+
335+ for _idx in range (self .num_samples ):
336+ noisy_m = self ._input * (upsampled_a +
337+ self ._distrib .sample (self ._input .size ()).to (device = self ._input .device ))
338+
339+ # Process by chunk (GPU RAM limitation)
340+ for idx in range (math .ceil (weights .shape [0 ] / self .bs )):
341+
342+ selection_slice = slice (idx * self .bs , min ((idx + 1 ) * self .bs , weights .shape [0 ]))
343+ with torch .no_grad ():
344+ # Get the softmax probabilities of the target class
345+ weights [selection_slice ] += F .softmax (self .model (noisy_m [selection_slice ]), dim = 1 )[:, class_idx ]
346+
347+ weights /= self .num_samples
348+
349+ # Reenable hook updates
350+ self ._hooks_enabled = True
351+
352+ return weights
353+
354+ def __repr__ (self ):
355+ return f"{ self .__class__ .__name__ } (batch_size={ self .bs } , num_samples={ self .num_samples } , std={ self .std } )"
0 commit comments