Skip to content

Commit 7be0b4f

Browse files
authored
feat: Added original CAM implementation (#2)
* docs: Fixed usage instruction * feat: Added original CAM implementation * chore: Reorganized package * test: Added CAM unittest * refactor: Refactored CAM * refactor: Refactored GradCAMs * refactor: Refactored CAMs * test: Updated unittests accordingly * docs: Updated readme * style: Removed extra blank line * docs: Updated documentation
1 parent 29bb8f3 commit 7be0b4f

File tree

11 files changed

+305
-199
lines changed

11 files changed

+305
-199
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ import matplotlib.pyplot as plt
5050
from torchvision.models import resnet50
5151
from torchvision.transforms import transforms
5252
from torchvision.transforms.functional import to_pil_image
53-
from gradcam import GradCAM, GradCAMpp, overlay_mask
53+
from torchcam.cams import CAM, GradCAM, GradCAMpp
54+
from torchcam.utils import overlay_mask
5455

5556

5657
# Pretrained imagenet model
@@ -81,7 +82,7 @@ classes = {int(key):value for (key, value)
8182
class_idx = 232
8283

8384
# Use the hooked data to compute activation map
84-
activation_maps = gradcam.get_activation_maps(out, class_idx)
85+
activation_maps = gradcam(out, class_idx)
8586
# Convert it to PIL image
8687
# The indexing below means first image in batch
8788
heatmap = to_pil_image(activation_maps[0].cpu().numpy(), mode='F')
@@ -101,7 +102,7 @@ plt.imshow(result); plt.axis('off'); plt.title(classes.get(class_idx)); plt.tigh
101102

102103
The project is currently under development, here are the objectives for the next releases:
103104

104-
- [ ] Parallel CAMs: enable batch processing.
105+
- [x] Parallel CAMs: enable batch processing.
105106
- [ ] Benchmark: compare class activation map computations for different architectures.
106107
- [ ] Signature improvement: retrieve automatically the last convolutional layer.
107108
- [ ] Refine RPN: create a region proposal network using CAM.

docs/source/cams.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
torchcam.cams
2+
=============
3+
4+
5+
.. currentmodule:: torchcam.cams
6+
7+
8+
CAM
9+
--------
10+
Related to activation-based class activation maps.
11+
12+
13+
.. autoclass:: CAM
14+
15+
16+
Grad-CAM
17+
--------
18+
Related to gradient-based class activation maps.
19+
20+
21+
.. autoclass:: GradCAM
22+
23+
24+
.. autoclass:: GradCAMpp

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The :mod:`torchcam` package gives PyTorch users the possibility to visualize the
77
:maxdepth: 1
88
:caption: Package Reference
99

10-
torchcam
10+
cams
1111
utils
1212

1313

docs/source/torchcam.rst

Lines changed: 0 additions & 16 deletions
This file was deleted.

test/test_cams.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import unittest
2+
import requests
3+
from io import BytesIO
4+
from PIL import Image
5+
import torch
6+
from torchvision.models import resnet18, mobilenet_v2
7+
from torchvision.transforms.functional import resize, to_tensor, normalize
8+
9+
from torchcam import cams
10+
11+
12+
class Tester(unittest.TestCase):
13+
14+
def _verify_cam(self, cam):
15+
# Simple verifications
16+
self.assertIsInstance(cam, torch.Tensor)
17+
self.assertEqual(cam.shape, (1, 7, 7))
18+
19+
@staticmethod
20+
def _get_img_tensor():
21+
22+
# Get a dog image
23+
URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg'
24+
response = requests.get(URL)
25+
26+
# Forward an image
27+
pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB')
28+
img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
29+
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30+
31+
return img_tensor
32+
33+
def test_cam(self):
34+
# Get a pretrained model
35+
model = resnet18(pretrained=True).eval()
36+
conv_layer = 'layer4'
37+
fc_layer = 'fc'
38+
# Border collie index in ImageNet
39+
class_idx = 232
40+
41+
# Hook the corresponding layer in the model
42+
extractor = cams.CAM(model, conv_layer, fc_layer)
43+
44+
# Get a dog image
45+
img_tensor = self._get_img_tensor()
46+
# Forward it
47+
with torch.no_grad():
48+
_ = model(img_tensor.unsqueeze(0))
49+
50+
# Use the hooked data to compute activation map
51+
self._verify_cam(extractor(class_idx))
52+
53+
def _test_gradcam(self, name):
54+
55+
# Get a pretrained model
56+
model = mobilenet_v2(pretrained=True)
57+
conv_layer = 'features'
58+
# Border collie index in ImageNet
59+
class_idx = 232
60+
61+
# Hook the corresponding layer in the model
62+
extractor = cams.__dict__[name](model, conv_layer)
63+
64+
# Get a dog image
65+
img_tensor = self._get_img_tensor()
66+
67+
# Forward an image
68+
out = model(img_tensor.unsqueeze(0))
69+
70+
# Use the hooked data to compute activation map
71+
self._verify_cam(extractor(out, class_idx))
72+
73+
74+
for cam_extractor in ['GradCAM', 'GradCAMpp']:
75+
def do_test(self, cam_extractor=cam_extractor):
76+
self._test_gradcam(cam_extractor)
77+
78+
setattr(Tester, "test_" + cam_extractor.lower(), do_test)
79+
80+
81+
if __name__ == '__main__':
82+
unittest.main()

test/test_gradcam.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

torchcam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torchcam.gradcam import *
1+
from torchcam import cams
22
from torchcam import utils
33

44

torchcam/cams/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .cam import *
2+
from .gradcam import *
3+
4+
del cam
5+
del gradcam

torchcam/cams/cam.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
"""
5+
GradCAM
6+
"""
7+
8+
import torch
9+
10+
11+
__all__ = ['CAM']
12+
13+
14+
class _CAM(object):
15+
"""Implements a class activation map extractor
16+
17+
Args:
18+
model (torch.nn.Module): input model
19+
conv_layer (str): name of the last convolutional layer
20+
"""
21+
22+
hook_a = None
23+
24+
def __init__(self, model, conv_layer):
25+
26+
if not hasattr(model, conv_layer):
27+
raise ValueError(f"Unable to find submodule {conv_layer} in the model")
28+
self.model = model
29+
# Forward hook
30+
self.model._modules.get(conv_layer).register_forward_hook(self._hook_a)
31+
32+
def _hook_a(self, module, input, output):
33+
self.hook_a = output.data
34+
35+
@staticmethod
36+
def _normalize(cams):
37+
cams -= cams.flatten(start_dim=1).min().view(-1, 1, 1)
38+
cams /= cams.flatten(start_dim=1).max().view(-1, 1, 1)
39+
40+
return cams
41+
42+
def _get_weights(self, class_idx):
43+
44+
raise NotImplementedError
45+
46+
def __call__(self, class_idx, normalized=True):
47+
48+
# Get map weight
49+
weights = self._get_weights(class_idx)
50+
51+
# Perform the weighted combination to get the CAM
52+
batch_cams = (weights.view(-1, 1, 1) * self.hook_a).sum(dim=1)
53+
54+
# Normalize the CAM
55+
if normalized:
56+
batch_cams = self._normalize(batch_cams)
57+
58+
return batch_cams
59+
60+
61+
class CAM(_CAM):
62+
"""Implements a class activation map extractor as described in https://arxiv.org/abs/1512.04150
63+
64+
Args:
65+
model (torch.nn.Module): input model
66+
conv_layer (str): name of the last convolutional layer
67+
"""
68+
69+
hook_a = None
70+
71+
def __init__(self, model, conv_layer, fc_layer):
72+
73+
super().__init__(model, conv_layer)
74+
# Softmax weight
75+
self._fc_weights = self.model._modules.get(fc_layer).weight.data
76+
77+
def _get_weights(self, class_idx):
78+
79+
# Take the FC weights of the target class
80+
return self._fc_weights[class_idx, :]

0 commit comments

Comments
 (0)