Skip to content

Commit 521b4f9

Browse files
authored
docs: Fixed example docstring and unittests (#33)
* docs: Fixed example of SmoothGradCAMpp * test: Fixed unittests * test: Speeded up unittests * test: Fixed base Cam unittest * test: Optimized speed for Score CAM family * test: Optimized testing speed for SSCAM and ISCAM
1 parent 71c2756 commit 521b4f9

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

test/test_cams.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,6 @@
1111
from torchcam import cams
1212

1313

14-
def _forward(model, input_tensor):
15-
if model.training:
16-
scores = model(input_tensor)
17-
else:
18-
with torch.no_grad():
19-
scores = model(input_tensor)
20-
21-
return scores
22-
23-
2414
class CAMCoreTester(unittest.TestCase):
2515
def _verify_cam(self, cam):
2616
# Simple verifications
@@ -50,11 +40,11 @@ def _test_extractor(self, extractor, model):
5040
img_tensor = self._get_img_tensor()
5141

5242
# Check that a batch of 2 cannot be accepted
53-
_ = _forward(model, torch.stack((img_tensor, img_tensor)))
43+
_ = model(torch.stack((img_tensor, img_tensor)))
5444
self.assertRaises(ValueError, extractor, 0)
5545

5646
# Correct forward
57-
scores = _forward(model, img_tensor.unsqueeze(0))
47+
scores = model(img_tensor.unsqueeze(0))
5848

5949
# Check incorrect class index
6050
self.assertRaises(ValueError, extractor, -1)
@@ -68,20 +58,24 @@ def _test_extractor(self, extractor, model):
6858

6959
def _test_cam(self, name):
7060
# Get a pretrained model
71-
model = resnet18(pretrained=False).eval()
72-
conv_layer = 'layer4.1.relu'
61+
model = mobilenet_v2(pretrained=False).eval()
62+
conv_layer = None if name == "CAM" else 'features.16.conv.3'
7363

64+
kwargs = {}
65+
# Speed up testing by reducing the number of samples
66+
if name in ['SSCAM', 'ISCAM']:
67+
kwargs['num_samples'] = 4
7468
# Hook the corresponding layer in the model
75-
extractor = cams.__dict__[name](model, conv_layer)
69+
extractor = cams.__dict__[name](model, conv_layer, **kwargs)
7670

7771
with torch.no_grad():
7872
self._test_extractor(extractor, model)
7973

8074
def _test_gradcam(self, name):
8175

8276
# Get a pretrained model
83-
model = mobilenet_v2(pretrained=False)
84-
conv_layer = 'features.17.conv.3'
77+
model = mobilenet_v2(pretrained=False).eval()
78+
conv_layer = 'features.18.0'
8579

8680
# Hook the corresponding layer in the model
8781
extractor = cams.__dict__[name](model, conv_layer)

torchcam/cams/gradcam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class SmoothGradCAMpp(_GradCAM):
195195
>>> from torchvision.models import resnet18
196196
>>> from torchcam.cams import SmoothGradCAMpp
197197
>>> model = resnet18(pretrained=True).eval()
198-
>>> cam = SmoothGradCAMpp(model, 'layer4', 'conv1')
198+
>>> cam = SmoothGradCAMpp(model, 'layer4')
199199
>>> scores = model(input_tensor)
200200
>>> cam(class_idx=100)
201201

0 commit comments

Comments
 (0)