-
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added original CAM implementation (#2)
* docs: Fixed usage instruction * feat: Added original CAM implementation * chore: Reorganized package * test: Added CAM unittest * refactor: Refactored CAM * refactor: Refactored GradCAMs * refactor: Refactored CAMs * test: Updated unittests accordingly * docs: Updated readme * style: Removed extra blank line * docs: Updated documentation
- Loading branch information
Showing
11 changed files
with
305 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
torchcam.cams | ||
============= | ||
|
||
|
||
.. currentmodule:: torchcam.cams | ||
|
||
|
||
CAM | ||
-------- | ||
Related to activation-based class activation maps. | ||
|
||
|
||
.. autoclass:: CAM | ||
|
||
|
||
Grad-CAM | ||
-------- | ||
Related to gradient-based class activation maps. | ||
|
||
|
||
.. autoclass:: GradCAM | ||
|
||
|
||
.. autoclass:: GradCAMpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import unittest | ||
import requests | ||
from io import BytesIO | ||
from PIL import Image | ||
import torch | ||
from torchvision.models import resnet18, mobilenet_v2 | ||
from torchvision.transforms.functional import resize, to_tensor, normalize | ||
|
||
from torchcam import cams | ||
|
||
|
||
class Tester(unittest.TestCase):
    """Smoke tests for the class activation map extractors exposed by ``torchcam.cams``."""

    def _verify_cam(self, cam):
        """Check that ``cam`` is a ``torch.Tensor`` of shape ``(1, 7, 7)``."""
        # Simple verifications
        self.assertIsInstance(cam, torch.Tensor)
        self.assertEqual(cam.shape, (1, 7, 7))

    @staticmethod
    def _get_img_tensor():
        """Download the sample picture and turn it into a normalized image tensor.

        Returns:
            torch.Tensor: the image resized to 224x224, converted with ``to_tensor``
            and normalized with the per-channel mean/std constants below.

        Raises:
            requests.RequestException: if the download times out or the server
                answers with an HTTP error status.
        """
        # Get a dog image
        URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg'
        # Bound the wait and fail on HTTP errors right away, instead of hanging forever
        # or handing an error page to PIL
        response = requests.get(URL, timeout=10)
        response.raise_for_status()

        # Decode the payload and preprocess it
        pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB')
        img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                               [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return img_tensor

    def test_cam(self):
        """Run ``cams.CAM`` on a pretrained ResNet-18 and check the resulting map."""
        # Get a pretrained model
        model = resnet18(pretrained=True).eval()
        conv_layer = 'layer4'
        fc_layer = 'fc'
        # Border collie index in ImageNet
        class_idx = 232

        # Hook the corresponding layer in the model
        extractor = cams.CAM(model, conv_layer, fc_layer)

        # Get a dog image
        img_tensor = self._get_img_tensor()
        # Forward it
        with torch.no_grad():
            _ = model(img_tensor.unsqueeze(0))

        # Use the hooked data to compute activation map
        self._verify_cam(extractor(class_idx))

    def _test_gradcam(self, name):
        """Run the extractor called ``name`` in ``torchcam.cams`` on a pretrained MobileNetV2.

        Args:
            name (str): attribute name of the extractor class in ``torchcam.cams``
        """
        # Get a pretrained model (inference mode, consistent with test_cam; gradients
        # are still tracked since no torch.no_grad() block is used here)
        model = mobilenet_v2(pretrained=True).eval()
        conv_layer = 'features'
        # Border collie index in ImageNet
        class_idx = 232

        # Hook the corresponding layer in the model
        extractor = getattr(cams, name)(model, conv_layer)

        # Get a dog image
        img_tensor = self._get_img_tensor()

        # Forward an image
        out = model(img_tensor.unsqueeze(0))

        # Use the hooked data to compute activation map
        self._verify_cam(extractor(out, class_idx))
|
||
|
||
# Attach one ``test_<extractor name>`` method to Tester for each gradient-based extractor
for cam_extractor in ('GradCAM', 'GradCAMpp'):
    def do_test(self, cam_extractor=cam_extractor):
        # The default argument freezes the current loop value for this function
        self._test_gradcam(cam_extractor)

    setattr(Tester, "test_{}".format(cam_extractor.lower()), do_test)


if __name__ == '__main__':
    unittest.main()
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from torchcam.gradcam import * | ||
from torchcam import cams | ||
from torchcam import utils | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .cam import * | ||
from .gradcam import * | ||
|
||
# Remove the `cam` and `gradcam` names from this namespace
# (presumably so that only the star-imported names stay visible; confirm intent)
del cam, gradcam
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
CAM | ||
""" | ||
|
||
import torch | ||
|
||
|
||
__all__ = ['CAM'] | ||
|
||
|
||
class _CAM(object): | ||
"""Implements a class activation map extractor | ||
Args: | ||
model (torch.nn.Module): input model | ||
conv_layer (str): name of the last convolutional layer | ||
""" | ||
|
||
hook_a = None | ||
|
||
def __init__(self, model, conv_layer): | ||
|
||
if not hasattr(model, conv_layer): | ||
raise ValueError(f"Unable to find submodule {conv_layer} in the model") | ||
self.model = model | ||
# Forward hook | ||
self.model._modules.get(conv_layer).register_forward_hook(self._hook_a) | ||
|
||
def _hook_a(self, module, input, output): | ||
self.hook_a = output.data | ||
|
||
@staticmethod | ||
def _normalize(cams): | ||
cams -= cams.flatten(start_dim=1).min().view(-1, 1, 1) | ||
cams /= cams.flatten(start_dim=1).max().view(-1, 1, 1) | ||
|
||
return cams | ||
|
||
def _get_weights(self, class_idx): | ||
|
||
raise NotImplementedError | ||
|
||
def __call__(self, class_idx, normalized=True): | ||
|
||
# Get map weight | ||
weights = self._get_weights(class_idx) | ||
|
||
# Perform the weighted combination to get the CAM | ||
batch_cams = (weights.view(-1, 1, 1) * self.hook_a).sum(dim=1) | ||
|
||
# Normalize the CAM | ||
if normalized: | ||
batch_cams = self._normalize(batch_cams) | ||
|
||
return batch_cams | ||
|
||
|
||
class CAM(_CAM):
    """Implements a class activation map extractor as described in https://arxiv.org/abs/1512.04150

    Args:
        model (torch.nn.Module): input model
        conv_layer (str): name of the last convolutional layer
        fc_layer (str): name of the fully connected layer whose weight rows are used
            as per-class channel weights

    Raises:
        ValueError: if ``fc_layer`` is not a direct submodule of ``model``
            (``conv_layer`` is validated by the parent constructor)
    """

    def __init__(self, model, conv_layer, fc_layer):

        super().__init__(model, conv_layer)
        # Fail with an explicit message rather than an AttributeError on None
        fc_module = self.model._modules.get(fc_layer)
        if fc_module is None:
            raise ValueError(f"Unable to find submodule {fc_layer} in the model")
        # Softmax weight
        self._fc_weights = fc_module.weight.data

    def _get_weights(self, class_idx):
        """Return the row of the fully connected weights matching ``class_idx``."""

        # Take the FC weights of the target class
        return self._fc_weights[class_idx, :]
Oops, something went wrong.