Skip to content

Commit a95d680

Browse files
authored
feat: Added implementation of SS-CAM (#11)
* feat: Added implementation of Smoothed Score-CAM * test: Updated unittests * docs: Updated documentation * feat: Updated example script * docs: Updated readme * refactor: Refactored SS-CAM
1 parent fb3be81 commit a95d680

File tree

6 files changed

+121
-18
lines changed

6 files changed

+121
-18
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ This project is developed and maintained by the repo owner, but the implementati
9393
- [Grad-CAM++](https://arxiv.org/abs/1710.11063): improvement of GradCAM++ for more accurate pixel-level contribution to the activation.
9494
- [Smooth Grad-CAM++](https://arxiv.org/abs/1908.01224): SmoothGrad mechanism coupled with GradCAM.
9595
- [Score-CAM](https://arxiv.org/abs/1910.01279): score-weighting of class activation for better interpretability.
96+
- [SS-CAM](https://arxiv.org/abs/2006.14255): SmoothGrad mechanism coupled with Score-CAM.
9697

9798

9899

docs/source/cams.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Methods related to activation-based class activation maps.
1414

1515
.. autoclass:: ScoreCAM
1616

17+
.. autoclass:: SSCAM
18+
1719

1820
Grad-CAM
1921
--------

scripts/cam_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchvision import models
1616
from torchvision.transforms.functional import normalize, resize, to_tensor, to_pil_image
1717

18-
from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM
18+
from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM
1919
from torchcam.utils import overlay_mask
2020

2121
VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features')
@@ -58,9 +58,9 @@ def main(args):
5858
# Hook the corresponding layer in the model
5959
cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
6060
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
61-
ScoreCAM(model, conv_layer, input_layer)]
61+
ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer)]
6262

63-
fig, axes = plt.subplots(1, len(cam_extractors))
63+
fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
6464
for idx, extractor in enumerate(cam_extractors):
6565
model.zero_grad()
6666
scores = model(img_tensor.unsqueeze(0))
@@ -80,7 +80,7 @@ def main(args):
8080

8181
axes[idx].imshow(result)
8282
axes[idx].axis('off')
83-
axes[idx].set_title(extractor.__class__.__name__, size=10)
83+
axes[idx].set_title(extractor.__class__.__name__, size=8)
8484

8585
plt.tight_layout()
8686
if args.savefig:
@@ -95,7 +95,7 @@ def main(args):
9595
parser.add_argument("--img", type=str,
9696
default='https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg',
9797
help="The image to extract CAM from")
98-
parser.add_argument("--class-idx", type=int, default=None, help='Index of the class to inspect')
98+
parser.add_argument("--class-idx", type=int, default=232, help='Index of the class to inspect')
9999
parser.add_argument("--device", type=str, default=None, help='Default device to perform computation on')
100100
parser.add_argument("--savefig", type=str, default=None, help="Path to save figure")
101101
args = parser.parse_args()

static/images/cam_example.png

31.4 KB
Loading

test/test_cams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_smooth_gradcampp(self):
101101
self._test_extractor(extractor, model)
102102

103103

104-
for cam_extractor in ['CAM', 'ScoreCAM']:
104+
for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM']:
105105
def do_test(self, cam_extractor=cam_extractor):
106106
self._test_cam(cam_extractor)
107107

torchcam/cams/cam.py

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn.functional as F
1111

12-
__all__ = ['CAM', 'ScoreCAM']
12+
__all__ = ['CAM', 'ScoreCAM', 'SSCAM']
1313

1414

1515
class _CAM(object):
@@ -128,7 +128,8 @@ class CAM(_CAM):
128128
129129
where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
130130
position :math:`(x, y)`,
131-
and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k`.
131+
and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k` in the fully
132+
connected layer..
132133
133134
Example::
134135
>>> from torchvision.models import resnet18
@@ -172,18 +173,18 @@ class ScoreCAM(_CAM):
172173
with the coefficient :math:`w_k^{(c)}` being defined as:
173174
174175
.. math::
175-
w_k^{(c)} = softmax(Y^{(c)}(M) - Y^{(c)}(X_b))
176+
w_k^{(c)} = softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b))
176177
177178
where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
178179
position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
179180
for input :math:`X`, :math:`X_b` is a baseline image,
180-
and :math:`M` is defined as follows:
181+
and :math:`M_k` is defined as follows:
181182
182183
.. math::
183-
M = \\Big(\\frac{M^{(d)} - \\min M^{(d)}}{\\max M^{(d)} - \\min M^{(d)}} \\odot X \\Big)_{1 \\leq d \\leq D}
184+
M_k = \\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)})
185+
\\odot X
184186
185-
where :math:`\\odot` refers to the element-wise multiplication, :math:`M^{(d)}` is the upsampled version of
186-
:math:`A_d` on node :math:`d`, and :math:`D` is the number of channels on the target convolutional layer.
187+
where :math:`\\odot` refers to the element-wise multiplication and :math:`U` is the upsampling operation.
187188
188189
Example::
189190
>>> from torchvision.models import resnet18
@@ -222,12 +223,12 @@ def _store_input(self, module, input):
222223
def _get_weights(self, class_idx, scores=None):
223224
"""Computes the weight coefficients of the hooked activation maps"""
224225

225-
# Upsample activation to input_size
226-
# 1 * O * M * N
227-
upsampled_a = F.interpolate(self.hook_a, self._input.shape[-2:], mode='bilinear', align_corners=False)
226+
# Normalize the activation
227+
upsampled_a = self._normalize(self.hook_a)
228228

229-
# Normalize it
230-
upsampled_a = self._normalize(upsampled_a)
229+
# Upsample it to input_size
230+
# 1 * O * M * N
231+
upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False)
231232

232233
# Use it as a mask
233234
# O * I * H * W
@@ -253,3 +254,102 @@ def _get_weights(self, class_idx, scores=None):
253254

254255
def __repr__(self):
255256
return f"{self.__class__.__name__}(batch_size={self.bs})"
257+
258+
259+
class SSCAM(ScoreCAM):
260+
"""Implements a class activation map extractor as described in `"SS-CAM: Smoothed Score-CAM for
261+
Sharper Visual Feature Localization" <https://arxiv.org/pdf/2006.14255.pdf>`_.
262+
263+
The localization map is computed as follows:
264+
265+
.. math::
266+
L^{(c)}_{SS-CAM}(x, y) = ReLU\\Big(\\sum\\limits_k w_k^{(c)} A_k(x, y)\\Big)
267+
268+
with the coefficient :math:`w_k^{(c)}` being defined as:
269+
270+
.. math::
271+
w_k^{(c)} = \\frac{1}{N} \\sum\\limits_1^N softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b))
272+
273+
where :math:`N` is the number of samples used to smooth the weights,
274+
:math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
275+
position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
276+
for input :math:`X`, :math:`X_b` is a baseline image,
277+
and :math:`M_k` is defined as follows:
278+
279+
.. math::
280+
M_k = \\Bigg(\\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)} +
281+
\\delta\\Bigg) \\odot X
282+
283+
where :math:`\\odot` refers to the element-wise multiplication, :math:`U` is the upsampling operation,
284+
:math:`\\delta \\sim \\mathcal{N}(0, \\sigma^2)` is the random noise that follows a 0-mean gaussian distribution
285+
with a standard deviation of :math:`\\sigma`.
286+
287+
Example::
288+
>>> from torchvision.models import resnet18
289+
>>> from torchcam.cams import SSCAM
290+
>>> model = resnet18(pretrained=True).eval()
291+
>>> cam = SSCAM(model, 'layer4', 'conv1')
292+
>>> with torch.no_grad(): out = model(input_tensor)
293+
>>> cam(class_idx=100)
294+
295+
Args:
296+
model (torch.nn.Module): input model
297+
conv_layer (str): name of the last convolutional layer
298+
input_layer (str): name of the first layer
299+
batch_size (int, optional): batch size used to forward masked inputs
300+
num_samples (int, optional): number of noisy samples used for weight computation
301+
std (float, optional): standard deviation of the noise added to the normalized activation
302+
"""
303+
304+
hook_a = None
305+
hook_handles = []
306+
307+
def __init__(self, model, conv_layer, input_layer, batch_size=32, num_samples=35, std=2.0):
308+
309+
super().__init__(model, conv_layer, input_layer, batch_size)
310+
311+
self.num_samples = num_samples
312+
self.std = std
313+
self._distrib = torch.distributions.normal.Normal(0, self.std)
314+
315+
def _get_weights(self, class_idx, scores=None):
316+
"""Computes the weight coefficients of the hooked activation maps"""
317+
318+
# Normalize the activation
319+
upsampled_a = self._normalize(self.hook_a)
320+
321+
# Upsample it to input_size
322+
# 1 * O * M * N
323+
upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False)
324+
325+
# Use it as a mask
326+
# O * I * H * W
327+
upsampled_a = upsampled_a.squeeze(0).unsqueeze(1)
328+
329+
# Initialize weights
330+
weights = torch.zeros(upsampled_a.shape[0], dtype=upsampled_a.dtype).to(device=upsampled_a.device)
331+
332+
# Disable hook updates
333+
self._hooks_enabled = False
334+
335+
for _idx in range(self.num_samples):
336+
noisy_m = self._input * (upsampled_a +
337+
self._distrib.sample(self._input.size()).to(device=self._input.device))
338+
339+
# Process by chunk (GPU RAM limitation)
340+
for idx in range(math.ceil(weights.shape[0] / self.bs)):
341+
342+
selection_slice = slice(idx * self.bs, min((idx + 1) * self.bs, weights.shape[0]))
343+
with torch.no_grad():
344+
# Get the softmax probabilities of the target class
345+
weights[selection_slice] += F.softmax(self.model(noisy_m[selection_slice]), dim=1)[:, class_idx]
346+
347+
weights /= self.num_samples
348+
349+
# Reenable hook updates
350+
self._hooks_enabled = True
351+
352+
return weights
353+
354+
def __repr__(self):
355+
return f"{self.__class__.__name__}(batch_size={self.bs}, num_samples={self.num_samples}, std={self.std})"

0 commit comments

Comments
 (0)