Skip to content

Commit 4b17c98

Browse files
committed
update docs
1 parent ed328df commit 4b17c98

File tree

13 files changed

+133
-131
lines changed

13 files changed

+133
-131
lines changed

docs/input_feature_interpreter.rst

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@
33
Input Feature Based Interpreters
44
================================
55

6-
Smooth Gradients
6+
Consensus
77
----------------
88

9-
.. autoclass:: interpretdl.SmoothGradInterpreter
9+
.. autoclass:: interpretdl.ConsensusInterpreter
10+
:members:
11+
12+
Gradient Shap
13+
-------------
14+
15+
.. autoclass:: interpretdl.GradShapCVInterpreter
16+
:members:
17+
18+
.. autoclass:: interpretdl.GradShapNLPInterpreter
1019
:members:
1120

1221
Integrated Gradients
@@ -17,22 +26,7 @@ Integrated Gradients
1726

1827
.. autoclass:: interpretdl.IntGradNLPInterpreter
1928
:members:
20-
21-
Occlusion
22-
---------
2329

24-
.. autoclass:: interpretdl.OcclusionInterpreter
25-
:members:
26-
27-
Gradient Shap
28-
-------------
29-
30-
.. autoclass:: interpretdl.GradShapCVInterpreter
31-
:members:
32-
33-
.. autoclass:: interpretdl.GradShapNLPInterpreter
34-
:members:
35-
3630
LIME
3731
----
3832

@@ -48,6 +42,30 @@ LIME With Global Prior
4842
.. autoclass:: interpretdl.LIMEPriorInterpreter
4943
:members:
5044

45+
LRP
46+
----------------
47+
48+
.. autoclass:: interpretdl.LRPCVInterpreter
49+
:members:
50+
51+
Occlusion
52+
---------
53+
54+
.. autoclass:: interpretdl.OcclusionInterpreter
55+
:members:
56+
57+
Smooth Gradients
58+
----------------
59+
60+
.. autoclass:: interpretdl.SmoothGradInterpreter
61+
:members:
62+
63+
Smooth Gradients V2
64+
-------------------
65+
66+
.. autoclass:: interpretdl.SmoothGradInterpreterV2
67+
:members:
68+
5169
NormLIME
5270
--------
5371

@@ -56,9 +74,3 @@ NormLIME
5674

5775
.. autoclass:: interpretdl.NormLIMENLPInterpreter
5876
:members:
59-
60-
LRP
61-
----------------
62-
63-
.. autoclass:: interpretdl.LRPCVInterpreter
64-
:members:

interpretdl/interpreter/abc_interpreter.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,27 @@
1111

1212
class Interpreter(ABC):
1313
"""
14-
Interpreter is the base abstract class for all interpretation algorithms.
15-
Interpreters should (1) prepare the ``self.predict_fn`` that outputs probability predictions, gradients or other
16-
desired intermediate results of the model, and (2) implement the core function ``interpret`` of the interpretation
17-
algorithm.
14+
Interpreter is the base abstract class for all Interpreters.
15+
The implementation of Interpreters should (1) prepare the ``self.predict_fn`` that outputs probability predictions,
16+
gradients or other desired intermediate results of the model, and (2) implement the core function ``interpret`` of
17+
the interpretation algorithm.
18+
This kind of implementation works for all post-poc interpretation algorithms. While there are other algorithms that
19+
may have different features, this kind of implementation can cover most of them. So we follow this design for all
20+
Interpreters in this library.
1821
1922
Three sub-abstract Interpreters that implement ``self.predict_fn`` are currently provided in this file:
2023
``InputGradientInterpreter``, ``InputOutputInterpreter``, ``IntermediateLayerInterpreter``. For each of them, the
2124
implemented ``predict_fn`` can be used by several different algorithms. Therefore, the further implementations can
22-
focus on the core algorithm.
23-
24-
Args:
25-
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
26-
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
27-
use_cuda (bool): Would be deprecated soon. Use ``device`` directly.
25+
focus on the core algorithm. More sub-abstract Interpreters will be provided if necessary.
2826
"""
2927

3028
def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, **kwargs):
29+
"""
30+
31+
Args:
32+
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
33+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
34+
"""
3135
self.device = device
3236
self.paddle_model = paddle_model
3337
self.predict_fn = None
@@ -54,13 +58,16 @@ def interpret(self, **kwargs):
5458
raise NotImplementedError
5559

5660
def _build_predict_fn(self, **kwargs):
57-
"""Build self.predict_fn for interpreters."""
61+
""" Build self.predict_fn for interpreters. This will be called by interpret(). """
5862
raise NotImplementedError
5963

6064
def _paddle_env_setup(self):
6165
"""Prepare the environment setup. This is not always necessary because the setup can be done within the
62-
function of ``_build_predict_fn``. This function is a simple implementation for disabling gradient computation.
66+
function of ``_build_predict_fn``.
6367
"""
68+
#######################################################################
69+
# This is a simple implementation for disabling gradient computation. #
70+
#######################################################################
6471
import paddle
6572
if not paddle.is_compiled_with_cuda() and self.device[:3] == 'gpu':
6673
print("Paddle is not installed with GPU support. Change to CPU version now.")
@@ -79,11 +86,16 @@ class InputGradientInterpreter(Interpreter):
7986
``InputGradientInterpreter`` are used by input gradient based Interpreters. Interpreters that are derived from
8087
``InputGradientInterpreter``: ``GradShapCVInterpreter``, ``IntGradCVInterpreter``, ``SmoothGradInterpreter``.
8188
82-
The ``predict_fn`` provided by this interpreter will output input gradient given an input.
83-
89+
The ``predict_fn`` in this interpreter will return input gradient given an input.
8490
"""
8591

8692
def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, **kwargs):
93+
"""
94+
95+
Args:
96+
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
97+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
98+
"""
8799
Interpreter.__init__(self, paddle_model, device, use_cuda, **kwargs)
88100
assert hasattr(paddle_model, 'forward'), \
89101
"paddle_model has to be " \
@@ -197,6 +209,12 @@ class InputOutputInterpreter(Interpreter):
197209
"""
198210

199211
def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, **kwargs):
212+
"""
213+
214+
Args:
215+
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
216+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
217+
"""
200218
Interpreter.__init__(self, paddle_model, device, use_cuda, **kwargs)
201219
assert hasattr(paddle_model, 'forward'), \
202220
"paddle_model has to be " \
@@ -259,11 +277,17 @@ class IntermediateLayerInterpreter(Interpreter):
259277
Interpreters that are derived from ``IntermediateLayerInterpreter``:
260278
``RolloutInterpreter``, ``ScoreCAMInterpreter``.
261279
262-
The ``predict_fn`` provided by this interpreter will output the model's intermediate outputs given an input.
263-
280+
The ``predict_fn`` provided by this interpreter will return the model's intermediate outputs given an input.
264281
"""
265282

266283
def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, **kwargs):
284+
"""
285+
286+
Args:
287+
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
288+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
289+
"""
290+
267291
Interpreter.__init__(self, paddle_model, device, use_cuda, **kwargs)
268292
assert hasattr(paddle_model, 'forward'), \
269293
"paddle_model has to be " \
@@ -272,16 +296,15 @@ def __init__(self, paddle_model: callable, device: str, use_cuda: bool = None, *
272296
def _build_predict_fn(self, rebuild: bool = False, target_layer: str = None, target_layer_pattern: str = None):
273297
"""Build self.predict_fn for IntermediateLayer based algorithms.
274298
The model is supposed to be a classification model.
275-
target_layer and target_layer_pattern cannot be set at the same time.
299+
``target_layer`` and ``target_layer_pattern`` cannot be set at the same time. See the arguments below.
276300
277301
Args:
278302
rebuild (bool, optional): forces to rebuild. Defaults to False.
279-
target_layer (str, optional): the name of the desired layer whose features will output.
280-
This is used when there is only one layer to output. Conflict with ``target_layer_pattern``.
281-
Defaults to None.
282-
target_layer_pattern (str, optional): the pattern name of the layers whose features will output.
283-
This is used when there are several layers to output and they share a common pattern name.
284-
Conflict with ``target_layer``. Defaults to None.
303+
target_layer (str, optional): the name of the desired layer whose features will output. This is used when
304+
there is only one layer to output. Conflict with ``target_layer_pattern``. Defaults to None.
305+
target_layer_pattern (str, optional): the pattern name of the layers whose features will output. This is
306+
used when there are several layers to output and they share a common pattern name. Conflict with
307+
``target_layer``. Defaults to None.
285308
"""
286309

287310
if self.predict_fn is not None:

interpretdl/interpreter/consensus.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@
55
class ConsensusInterpreter(object):
66
"""
77
8-
ConsensusInterpreter averages the explanations of a given Interpreter over a list of models.
9-
The averaged result is more like an explanation for the data, instead of specific models.
10-
For visual object recognition tasks, the Consensus explanation would be more aligned with the object than
11-
individual models.
8+
ConsensusInterpreter averages the explanations of a given Interpreter over a list of models. The averaged result
9+
is more like an explanation for the data, instead of specific models. For visual object recognition tasks, the
10+
Consensus explanation would be more aligned with the object than individual models.
1211
1312
More details regarding the Consensus method can be found in the original paper:
1413
https://arxiv.org/abs/2109.00707.
15-
1614
"""
1715

1816
def __init__(self, InterpreterClass, list_of_models: list, device: str = 'gpu:0', use_cuda=None, **kwargs):
1917
"""
20-
18+
2119
Args:
2220
InterpreterClass ([type]): The given Interpreter defined in InterpretDL.
2321
list_of_models (list): a list of trained models. Can be found from paddle.vision.models, or
@@ -35,13 +33,14 @@ def __init__(self, InterpreterClass, list_of_models: list, device: str = 'gpu:0'
3533
def interpret(self, inputs: str or list(str) or np.ndarray, **kwargs) -> np.ndarray:
3634
"""
3735
The technical details are simple to understand for the Consensus method:
38-
Given the ``inputs`` and the interpretation algorithm (one of Interpreters), each model in ``list_of_models``
39-
will produce an explanation, then Consensus will concatenate all the explanations. Subsequent normalization
40-
and average can be done as users' preference. The suggested operation for input gradient based algorithms is
41-
average of the absolute values.
36+
Given the ``inputs`` and the interpretation algorithm (one of the Interpreters), each model in
37+
``list_of_models`` will produce an explanation, then Consensus will concatenate all the explanations.
38+
Subsequent normalization and average can be done as users' preference. The suggested operation for input
39+
gradient based algorithms is average of the absolute values.
4240
4341
We leave the visualization to users.
44-
See https://github.com/PaddlePaddle/InterpretDL/tree/master/tutorials/consensus_tutorial_cv.ipynb for an example.
42+
See https://github.com/PaddlePaddle/InterpretDL/tree/master/tutorials/consensus_tutorial_cv.ipynb for an
43+
example.
4544
4645
.. code-block:: python
4746
@@ -74,7 +73,8 @@ def interpret(self, inputs: str or list(str) or np.ndarray, **kwargs) -> np.ndar
7473
ax[-1].set_title('Consensus')
7574
7675
Args:
77-
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
76+
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
77+
array of read images.
7878
7979
Returns:
8080
np.ndarray: Concatenated raw explanations.

interpretdl/interpreter/forgetting_events.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ class ForgettingEventsInterpreter(Interpreter):
2121
https://arxiv.org/abs/1812.05159.
2222
"""
2323

24-
def __init__(self, paddle_model: callable, device: str, use_cuda=None):
25-
"""Initialize the ForgettingEventsInterpreter.
26-
27-
Args:
28-
paddle_model (callable): A user-defined function that gives access to model predictions.
29-
It takes in data inputs and output predictions.
30-
device (str): Whether or not to use cuda. Default: None.
24+
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None):
3125
"""
26+
27+
Args:
28+
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
29+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
30+
"""
3231
Interpreter.__init__(self, paddle_model, device, use_cuda)
3332

3433
def interpret(self,

interpretdl/interpreter/gradient_cam.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class GradCAMInterpreter(Interpreter):
1111
1212
Given a convolutional network and an image classification task, classification activation map (CAM) can be derived
1313
from the global average pooling and the last fully-connected layer, and show the important regions that affect
14-
model decisions.
14+
model's decisions.
1515
1616
GradCAM further looks at the gradients flowing into one of the convolutional layers to give weight to activation
1717
maps. Note that if there is a global average pooling layer in the network, GradCAM targeting the last layer is
@@ -25,11 +25,11 @@ class GradCAMInterpreter(Interpreter):
2525

2626
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None):
2727
"""
28-
28+
2929
Args:
3030
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
3131
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
32-
"""
32+
"""
3333
Interpreter.__init__(self, paddle_model, device, use_cuda)
3434
self.paddle_prepared = False
3535

@@ -47,9 +47,9 @@ def interpret(self,
4747
save_path: str = None) -> np.ndarray:
4848
"""
4949
The technical details of the GradCAM method are described as follows:
50-
GradCAM computes the feature map and the gradient of the objective function w.r.t. ``target_layer_name``.
51-
With the average of gradients along the spatial dimensions, gradients will be multiplied with feature map,
52-
following by a ReLU activation to produce the final explanation.
50+
GradCAM computes the feature map at the layer of ``target_layer_name`` and the gradient of the objective
51+
function w.r.t. ``target_layer_name``. With the average of gradients along the spatial dimensions, gradients
52+
will be multiplied with feature map, following by a ReLU activation to produce the final explanation.
5353
5454
Args:
5555
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy

interpretdl/interpreter/gradient_shap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ class GradShapCVInterpreter(InputGradientInterpreter):
1616
GradShap uses noised inputs to get input gradients and then average.
1717
1818
More details regarding the GradShap method can be found in the original paper:
19-
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions
19+
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.
2020
"""
2121

2222
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None):
2323
"""
24-
24+
2525
Args:
2626
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
2727
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
@@ -128,16 +128,16 @@ class GradShapNLPInterpreter(Interpreter):
128128
are done for the embeddings.
129129
130130
More details regarding the GradShap method can be found in the original paper:
131-
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions
131+
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.
132132
"""
133133

134134
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda: bool = None) -> None:
135135
"""
136-
136+
137137
Args:
138138
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
139139
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
140-
"""
140+
"""
141141
Interpreter.__init__(self, paddle_model, device, use_cuda)
142142

143143
def interpret(self,

interpretdl/interpreter/integrated_gradients.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ class IntGradCVInterpreter(InputGradientInterpreter):
2222

2323
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda: bool = None):
2424
"""
25-
25+
2626
Args:
2727
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
2828
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
29-
"""
29+
"""
3030
InputGradientInterpreter.__init__(self, paddle_model, device, use_cuda)
3131

3232
def interpret(self,
@@ -138,16 +138,16 @@ class IntGradNLPInterpreter(Interpreter):
138138
are done for the embeddings.
139139
140140
More details regarding the Integrated Gradients method can be found in the original paper:
141-
https://arxiv.org/abs/1703.01365
141+
https://arxiv.org/abs/1703.01365.
142142
"""
143143

144144
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda: bool = None) -> None:
145145
"""
146-
146+
147147
Args:
148148
paddle_model (callable): A model with ``forward`` and possibly ``backward`` functions.
149149
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
150-
"""
150+
"""
151151
Interpreter.__init__(self, paddle_model, device, use_cuda)
152152

153153
def interpret(self,

0 commit comments

Comments
 (0)