Installation
============

Install all the dependencies to make the most out of TorchCAM:

.. code-block:: python

    >>> !pip install torchvision matplotlib


Latest stable release
---------------------

.. code-block:: python

    >>> !pip install torch-cam

From source
-----------

.. code-block:: python

    >>> # Install the most up-to-date version from GitHub
    >>> !pip install -e git+https://github.com/frgfm/torch-cam.git#egg=torchcam


Now go to ``Runtime/Restart runtime`` for your changes to take effect!

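After restarting, you can run a quick sanity check to confirm the freshly installed package is picked up (this snippet is not part of the original tutorial and assumes ``torchcam`` exposes a ``__version__`` attribute):

.. code-block:: python

    >>> # Sanity check: the install is visible after the runtime restart
    >>> import torchcam
    >>> print(torchcam.__version__)  # assumes the package exposes __version__
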
Basic usage
===========

.. code-block:: python

    >>> # Download an image
    >>> !wget https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
    >>> # Set this to your image path if you wish to run it on your own data
    >>> img_path = "border-collie.jpg"


.. code-block:: python

    >>> # Instantiate your model here
    >>> from torchvision.models import resnet18
    >>> model = resnet18(pretrained=True).eval()

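You are not tied to ResNet: any ImageNet-pretrained convolutional classifier can be plugged in the same way. As an illustrative alternative (not part of the original tutorial), a MobileNet from torchvision would look like this:

.. code-block:: python

    >>> # Illustrative alternative: swap in another torchvision classifier
    >>> from torchvision.models import mobilenet_v3_large
    >>> model = mobilenet_v3_large(pretrained=True).eval()
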
Illustrate your classifier capabilities
---------------------------------------

.. code-block:: python

    >>> %matplotlib inline
    >>> # All imports
    >>> from torchvision.io.image import read_image
    >>> from torchvision.transforms.functional import normalize, resize, to_pil_image
    >>> import matplotlib.pyplot as plt
    >>> from torchcam.cams import SmoothGradCAMpp, LayerCAM
    >>> from torchcam.utils import overlay_mask

.. code-block:: python

    >>> cam_extractor = SmoothGradCAMpp(model)
    >>> # Get your input
    >>> img = read_image(img_path)
    >>> # Preprocess it for your chosen model
    >>> input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    >>> # Feed the preprocessed data to the model
    >>> out = model(input_tensor.unsqueeze(0))
    >>> # Retrieve the CAM by passing the class index and the model output
    >>> cams = cam_extractor(out.squeeze(0).argmax().item(), out)

.. code-block:: python

    >>> # Notice that there is one CAM per target layer (here only 1)
    >>> for cam in cams:
    ...     print(cam.shape)
    torch.Size([7, 7])


.. code-block:: python

    >>> # The raw CAM
    >>> for name, cam in zip(cam_extractor.target_names, cams):
    ...     plt.imshow(cam.numpy()); plt.axis('off'); plt.title(name); plt.show()


.. code-block:: python

    >>> # Overlaid on the image
    >>> for name, cam in zip(cam_extractor.target_names, cams):
    ...     result = overlay_mask(to_pil_image(img), to_pil_image(cam, mode='F'), alpha=0.5)
    ...     plt.imshow(result); plt.axis('off'); plt.title(name); plt.show()


.. code-block:: python

    >>> # Once you're finished, clear the hooks on your model
    >>> cam_extractor.clear_hooks()

Advanced tricks
===============

Extract localization cues
-------------------------

.. code-block:: python

    >>> import torch
    >>> from torch.nn.functional import softmax, interpolate

.. code-block:: python

    >>> # Instantiate a LayerCAM extractor (default target layer)
    >>> cam_extractor = LayerCAM(model)
    >>> # Preprocess your data and feed it to the model
    >>> out = model(input_tensor.unsqueeze(0))
    >>> print(softmax(out, dim=1).max())
    tensor(0.9115, grad_fn=<MaxBackward1>)


.. code-block:: python

    >>> cams = cam_extractor(out.squeeze(0).argmax().item(), out)

.. code-block:: python

    >>> # Resize the CAMs and threshold them into binary localization masks
    >>> resized_cams = [resize(to_pil_image(cam), img.shape[-2:]) for cam in cams]
    >>> segmaps = [to_pil_image((resize(cam.unsqueeze(0), img.shape[-2:]).squeeze(0) >= 0.5).to(dtype=torch.float32)) for cam in cams]
    >>> # Plot them
    >>> for name, cam, seg in zip(cam_extractor.target_names, resized_cams, segmaps):
    ...     _, axes = plt.subplots(1, 2)
    ...     axes[0].imshow(cam); axes[0].axis('off'); axes[0].set_title(name)
    ...     axes[1].imshow(seg); axes[1].axis('off'); axes[1].set_title(name)
    ...     plt.show()

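If you need an even coarser localization cue, the same thresholded CAM can be reduced to a bounding box. The following is a minimal sketch, not part of the original tutorial, using only standard PyTorch ops; it assumes the binary mask contains at least one positive pixel:

.. code-block:: python

    >>> # Sketch (assumption: mask is non-empty): bounding box from the thresholded CAM
    >>> binary_masks = [resize(cam.unsqueeze(0), img.shape[-2:]).squeeze(0) >= 0.5 for cam in cams]
    >>> for name, mask in zip(cam_extractor.target_names, binary_masks):
    ...     ys, xs = torch.where(mask)  # row/column indices of positive pixels
    ...     xmin, ymin, xmax, ymax = xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()
    ...     print(f"{name}: box = ({xmin}, {ymin}, {xmax}, {ymax})")
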
Fuse CAMs from multiple layers
------------------------------

.. code-block:: python

    >>> # Retrieve the CAM from several layers at the same time
    >>> cam_extractor = LayerCAM(model, ["layer2", "layer3", "layer4"])
    >>> # Preprocess your data and feed it to the model
    >>> out = model(input_tensor.unsqueeze(0))
    >>> # Retrieve the CAM by passing the class index and the model output
    >>> cams = cam_extractor(out.squeeze(0).argmax().item(), out)

.. code-block:: python

    >>> # This time, there are several CAMs
    >>> for cam in cams:
    ...     print(cam.shape)
    torch.Size([14, 14])
    torch.Size([7, 7])


.. code-block:: python

    >>> # The raw CAMs, one per target layer
    >>> _, axes = plt.subplots(1, len(cam_extractor.target_names))
    >>> for idx, name, cam in zip(range(len(cam_extractor.target_names)), cam_extractor.target_names, cams):
    ...     axes[idx].imshow(cam.numpy()); axes[idx].axis('off'); axes[idx].set_title(name)
    >>> plt.show()


.. code-block:: python

    >>> # Let's fuse them
    >>> fused_cam = cam_extractor.fuse_cams(cams)
    >>> # Plot the raw version
    >>> plt.imshow(fused_cam.numpy()); plt.axis('off'); plt.title(" + ".join(cam_extractor.target_names)); plt.show()
    >>> # Plot the overlaid version
    >>> result = overlay_mask(to_pil_image(img), to_pil_image(fused_cam, mode='F'), alpha=0.5)
    >>> plt.imshow(result); plt.axis('off'); plt.title(" + ".join(cam_extractor.target_names)); plt.show()
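
As in the basic usage section, remove the hooks once you are done with the extractor:

.. code-block:: python

    >>> # Clear the hooks registered on your model, as shown earlier
    >>> cam_extractor.clear_hooks()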