Skip to content

Commit 71c2756

Browse files
authored
feat: Added automatic layer name resolution (#32)
* test: Renamed testers
* feat: Added layer resolution utils
* feat: Added layer resolution to CAM
* test: Added unittests for utils and simplified existing ones
* docs: Updated README
* refactor: Reflected changes on CAM interface
* feat: Updated visualization example script
* style: Fixed lint
* test: Fixed unittests
* feat: Forced eval switch when possible
1 parent 16e7641 commit 71c2756

File tree

8 files changed

+246
-115
lines changed

8 files changed

+246
-115
lines changed

README.md

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ Simple way to leverage the class-specific activation of convolutional layers in
1515
* [Prerequisites](#prerequisites)
1616
* [Installation](#installation)
1717
* [Usage](#usage)
18-
* [Technical Roadmap](#technical-roadmap)
1918
* [Documentation](#documentation)
2019
* [Contributing](#contributing)
2120
* [Credits](#credits)
@@ -58,20 +57,6 @@ python scripts/cam_example.py --model resnet50 --class-idx 232
5857

5958

6059

61-
62-
63-
## Technical roadmap
64-
65-
The project is currently under development. Here are the objectives for the next releases:
66-
67-
- [x] Parallel CAMs: enable batch processing.
68-
- [x] Benchmark: compare class activation map computations for different architectures.
69-
- [ ] Signature improvement: retrieve automatically the specific required layer names.
70-
- [ ] Refined RPN: create a region proposal network using CAM.
71-
- [ ] Task transfer: turn a well-trained classifier into an object detector.
72-
73-
74-
7560
## Documentation
7661

7762
The full package documentation is available [here](https://frgfm.github.io/torch-cam/) for detailed specifications. The documentation was built with [Sphinx](https://sphinx-doc.org) using a [theme](https://github.com/readthedocs/sphinx_rtd_theme) provided by [Read the Docs](https://readthedocs.org).

scripts/cam_example.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#!usr/bin/python
2-
# -*- coding: utf-8 -*-
32

43
"""
54
CAM visualization
65
"""
76

7+
import math
88
import argparse
99
from io import BytesIO
1010

@@ -18,18 +18,18 @@
1818
from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM, ISCAM
1919
from torchcam.utils import overlay_mask
2020

21-
VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features')
21+
VGG_CONFIG = {_vgg: dict(conv_layer='features')
2222
for _vgg in models.vgg.__dict__.keys()}
2323

24-
RESNET_CONFIG = {_resnet: dict(input_layer='conv1', conv_layer='layer4', fc_layer='fc')
24+
RESNET_CONFIG = {_resnet: dict(conv_layer='layer4', fc_layer='fc')
2525
for _resnet in models.resnet.__dict__.keys()}
2626

27-
DENSENET_CONFIG = {_densenet: dict(input_layer='features', conv_layer='features', fc_layer='classifier')
27+
DENSENET_CONFIG = {_densenet: dict(conv_layer='features', fc_layer='classifier')
2828
for _densenet in models.densenet.__dict__.keys()}
2929

3030
MODEL_CONFIG = {
3131
**VGG_CONFIG, **RESNET_CONFIG, **DENSENET_CONFIG,
32-
'mobilenet_v2': dict(input_layer='features', conv_layer='features')
32+
'mobilenet_v2': dict(conv_layer='features')
3333
}
3434

3535

@@ -43,7 +43,6 @@ def main(args):
4343
# Pretrained imagenet model
4444
model = models.__dict__[args.model](pretrained=True).eval().to(device=device)
4545
conv_layer = MODEL_CONFIG[args.model]['conv_layer']
46-
input_layer = MODEL_CONFIG[args.model]['input_layer']
4746
fc_layer = MODEL_CONFIG[args.model]['fc_layer']
4847

4948
# Image
@@ -57,15 +56,17 @@ def main(args):
5756

5857
# Hook the corresponding layer in the model
5958
cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
60-
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
61-
ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer),
62-
ISCAM(model, conv_layer, input_layer)]
59+
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer),
60+
ScoreCAM(model, conv_layer), SSCAM(model, conv_layer),
61+
ISCAM(model, conv_layer)]
6362

6463
# Don't trigger all hooks
6564
for extractor in cam_extractors:
6665
extractor._hooks_enabled = False
6766

68-
fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
67+
num_rows = 2
68+
num_cols = math.ceil(len(cam_extractors) / num_rows)
69+
_, axes = plt.subplots(num_rows, num_cols, figsize=(6, 4))
6970
for idx, extractor in enumerate(cam_extractors):
7071
extractor._hooks_enabled = True
7172
model.zero_grad()
@@ -76,6 +77,7 @@ def main(args):
7677

7778
# Use the hooked data to compute activation map
7879
activation_map = extractor(class_idx, scores).cpu()
80+
7981
# Clean data
8082
extractor.clear_hooks()
8183
extractor._hooks_enabled = False
@@ -85,9 +87,13 @@ def main(args):
8587
# Plot the result
8688
result = overlay_mask(pil_img, heatmap)
8789

88-
axes[idx].imshow(result)
89-
axes[idx].axis('off')
90-
axes[idx].set_title(extractor.__class__.__name__, size=8)
90+
axes[idx // num_cols][idx % num_cols].imshow(result)
91+
axes[idx // num_cols][idx % num_cols].set_title(extractor.__class__.__name__, size=8)
92+
93+
# Clear axes
94+
for row in axes:
95+
for ax in row:
96+
ax.axis('off')
9197

9298
plt.tight_layout()
9399
if args.savefig:

test/test_cams.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import requests
55
import torch
66
from PIL import Image
7+
from torch import nn
78
from torchvision.models import mobilenet_v2, resnet18
89
from torchvision.transforms.functional import normalize, resize, to_tensor
910

@@ -20,7 +21,7 @@ def _forward(model, input_tensor):
2021
return scores
2122

2223

23-
class Tester(unittest.TestCase):
24+
class CAMCoreTester(unittest.TestCase):
2425
def _verify_cam(self, cam):
2526
# Simple verifications
2627
self.assertIsInstance(cam, torch.Tensor)
@@ -67,76 +68,91 @@ def _test_extractor(self, extractor, model):
6768

6869
def _test_cam(self, name):
6970
# Get a model (randomly initialized, pretrained=False)
70-
model = resnet18(pretrained=False).eval()
71-
conv_layer = 'layer4'
72-
input_layer = 'conv1'
73-
fc_layer = 'fc'
74-
75-
# Hook the corresponding layer in the model
76-
extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)
77-
78-
self._test_extractor(extractor, model)
79-
80-
def _test_cam_arbitrary_layer(self, name):
81-
8271
model = resnet18(pretrained=False).eval()
8372
conv_layer = 'layer4.1.relu'
84-
input_layer = 'conv1'
85-
fc_layer = 'fc'
8673

8774
# Hook the corresponding layer in the model
88-
extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)
75+
extractor = cams.__dict__[name](model, conv_layer)
8976

90-
self._test_extractor(extractor, model)
77+
with torch.no_grad():
78+
self._test_extractor(extractor, model)
9179

9280
def _test_gradcam(self, name):
9381

9482
# Get a model (randomly initialized, pretrained=False)
9583
model = mobilenet_v2(pretrained=False)
96-
conv_layer = 'features'
84+
conv_layer = 'features.17.conv.3'
9785

9886
# Hook the corresponding layer in the model
9987
extractor = cams.__dict__[name](model, conv_layer)
10088

10189
self._test_extractor(extractor, model)
10290

103-
def _test_gradcam_arbitrary_layer(self, name):
91+
def test_smooth_gradcampp(self):
10492

105-
model = mobilenet_v2(pretrained=False)
106-
conv_layer = 'features.17.conv.3'
93+
# Get a model (randomly initialized, pretrained=False)
94+
model = mobilenet_v2(pretrained=False).eval()
10795

10896
# Hook the corresponding layer in the model
109-
extractor = cams.__dict__[name](model, conv_layer)
97+
extractor = cams.SmoothGradCAMpp(model)
11098

11199
self._test_extractor(extractor, model)
112100

113-
def test_smooth_gradcampp(self):
114101

115-
# Get a pretrained model
116-
model = mobilenet_v2(pretrained=False)
117-
conv_layer = 'features'
118-
input_layer = 'features'
102+
class CAMUtilsTester(unittest.TestCase):
119103

120-
# Hook the corresponding layer in the model
121-
extractor = cams.SmoothGradCAMpp(model, conv_layer, input_layer)
104+
@staticmethod
105+
def _get_custom_module():
122106

123-
self._test_extractor(extractor, model)
107+
mod = nn.Sequential(
108+
nn.Sequential(
109+
nn.Conv2d(3, 8, 3, 1),
110+
nn.ReLU(),
111+
nn.Conv2d(8, 16, 3, 1),
112+
nn.ReLU(),
113+
nn.AdaptiveAvgPool2d((1, 1))
114+
),
115+
nn.Flatten(1),
116+
nn.Linear(16, 1)
117+
)
118+
return mod
119+
120+
def test_locate_candidate_layer(self):
121+
122+
# ResNet-18
123+
mod = resnet18().eval()
124+
self.assertEqual(cams.utils.locate_candidate_layer(mod), 'layer4')
125+
126+
# Custom model
127+
mod = self._get_custom_module()
128+
129+
self.assertEqual(cams.utils.locate_candidate_layer(mod), '0.3')
130+
# Check that the model is switched back to its original mode afterwards
131+
self.assertTrue(mod.training)
132+
133+
def test_locate_linear_layer(self):
134+
135+
# ResNet-18
136+
mod = resnet18().eval()
137+
self.assertEqual(cams.utils.locate_linear_layer(mod), 'fc')
138+
139+
# Custom model
140+
mod = self._get_custom_module()
141+
self.assertEqual(cams.utils.locate_linear_layer(mod), '2')
124142

125143

126144
for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM', 'ISCAM']:
127145
def do_test(self, cam_extractor=cam_extractor):
128146
self._test_cam(cam_extractor)
129-
self._test_cam_arbitrary_layer(cam_extractor)
130147

131-
setattr(Tester, "test_" + cam_extractor.lower(), do_test)
148+
setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test)
132149

133150

134151
for cam_extractor in ['GradCAM', 'GradCAMpp']:
135152
def do_test(self, cam_extractor=cam_extractor):
136153
self._test_gradcam(cam_extractor)
137-
self._test_gradcam_arbitrary_layer(cam_extractor)
138154

139-
setattr(Tester, "test_" + cam_extractor.lower(), do_test)
155+
setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test)
140156

141157

142158
if __name__ == '__main__':

test/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torchcam import utils
77

88

9-
class Tester(unittest.TestCase):
9+
class UtilsTester(unittest.TestCase):
1010
def test_overlay_mask(self):
1111

1212
img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8))

torchcam/cams/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .cam import *
22
from .gradcam import *
3+
from .utils import *

0 commit comments

Comments
 (0)