Skip to content

Commit ed328df

Browse files
committed
code style and docs string improvement for interpreters
1 parent 05dd75c commit ed328df

File tree

6 files changed

+76
-55
lines changed

6 files changed

+76
-55
lines changed

interpretdl/interpreter/_normlime_base.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,24 @@
1111

1212
class NormLIMECVInterpreter(LIMECVInterpreter):
1313
"""
14-
NormLIME Interpreter for CV tasks.
14+
NormLIME Interpreter for CV tasks.
15+
16+
(TODO) Some technical details will be complete soon.
1517
1618
More details regarding the NormLIME method can be found in the original paper:
17-
https://arxiv.org/abs/1909.04200
19+
https://arxiv.org/abs/1909.04200.
20+
21+
Args:
22+
paddle_model (_type_):
23+
A user-defined function that gives access to model predictions.
24+
It takes the following arguments:
25+
- data: Data inputs.
26+
and outputs predictions.
27+
device (str, optional): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
28+
use_cuda (_type_, optional): Would be deprecated soon. Use ``device`` directly.
1829
"""
1930

2031
def __init__(self, paddle_model, device='gpu:0', use_cuda=None):
21-
"""
22-
23-
Args:
24-
paddle_model (_type_):
25-
A user-defined function that gives access to model predictions.
26-
It takes the following arguments:
27-
- data: Data inputs.
28-
and outputs predictions.
29-
device (str, optional): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
30-
use_cuda (_type_, optional): Would be deprecated soon. Use ``device`` directly.
31-
"""
3232

3333
LIMECVInterpreter.__init__(self, paddle_model, use_cuda=use_cuda, device=device)
3434
self.lime_interpret = super().interpret
@@ -58,6 +58,8 @@ def interpret(self,
5858
"""
5959
Main function of the interpreter.
6060
61+
(TODO) Some technical details will be complete soon.
62+
6163
Args:
6264
image_paths (list of strs): A list of image filepaths.
6365
num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more
@@ -182,20 +184,18 @@ class NormLIMENLPInterpreter(LIMENLPInterpreter):
182184
NormLIME Interpreter for NLP tasks.
183185
184186
More details regarding the NormLIME method can be found in the original paper:
185-
https://arxiv.org/abs/1909.04200
186-
"""
187+
https://arxiv.org/abs/1909.04200.
187188
188-
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None):
189-
"""
189+
Args:
190+
paddle_model (callable): A user-defined function that gives access to model predictions.
191+
It takes the following arguments:
190192
191-
Args:
192-
paddle_model (callable): A user-defined function that gives access to model predictions.
193-
It takes the following arguments:
193+
- data: Data inputs.
194+
and outputs predictions. See the example at the end of ``interpret()``.
195+
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
196+
"""
194197

195-
- data: Data inputs.
196-
and outputs predictions. See the example at the end of ``interpret()``.
197-
device (str): The device used for running `paddle_model`, options: ``cpu``, ``gpu:0``, ``gpu:1`` etc.
198-
"""
198+
def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None):
199199
LIMENLPInterpreter.__init__(self, paddle_model, device, use_cuda)
200200
self.lime_interpret = super().interpret
201201

interpretdl/interpreter/gradient_cam.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,16 @@ def interpret(self,
5252
following by a ReLU activation to produce the final explanation.
5353
5454
Args:
55-
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
55+
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
56+
array of read images.
5657
target_layer_name (str): The target layer to calculate gradients.
5758
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze.
58-
The number of labels should be equal to the number of images. If None, the most likely label for each image will be used. Default: None
59-
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`. Defaults to 224.
60-
crop_to (int, optional): [description]. After resize, images will be center cropped to a square image with the size `crop_to`.
61-
If None, no crop will be performed. Defaults to None.
59+
The number of labels should be equal to the number of images. If None, the most likely label for each
60+
image will be used. Default: None
61+
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`.
62+
Defaults to 224.
63+
crop_to (int, optional): [description]. After resize, images will be center cropped to a square image with
64+
the size `crop_to`. If None, no crop will be performed. Defaults to None.
6265
visual (bool, optional): Whether or not to visualize the processed image. Default: True
6366
save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s).
6467
If None, the image will not be saved. Default: None

interpretdl/interpreter/lime.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,18 @@ def interpret(self,
5151
5252
Args:
5353
data (str): The input file path.
54-
interpret_class (int, optional): The index of class to interpret. If None, the most likely label will be used. Default: None
55-
num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate interpretation. Default: 1000
54+
interpret_class (int, optional): The index of class to interpret. If None, the most likely label will be
55+
used. Default: None
56+
num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate
57+
interpretation. Default: 1000
5658
batch_size (int, optional): Number of samples to forward each time. Default: 50
57-
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`. Defaults to 224.
58-
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image with the size `crop_to`.
59-
If None, no crop will be performed. Defaults to None.
59+
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`.
60+
Defaults to 224.
61+
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image
62+
with the size `crop_to`. If None, no crop will be performed. Defaults to None.
6063
visual (bool, optional): Whether or not to visualize the processed image. Default: True
61-
save_path (str, optional): The path to save the processed image. If None, the image will not be saved. Default: None
64+
save_path (str, optional): The path to save the processed image. If None, the image will not be saved.
65+
Default: None
6266
6367
Returns:
6468
[dict]: LIME results: {interpret_label_i: weights on features}

interpretdl/interpreter/lime_prior.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,19 @@ def interpret(self,
9090
9191
Args:
9292
inputs (str): The input file path.
93-
interpret_class (int, optional): The index of class to interpret. If None, the most likely label will be used. Default: None
93+
interpret_class (int, optional): The index of class to interpret. If None, the most likely label will be
94+
used. Default: None
9495
prior_reg_force (float, optional): The regularization force to apply. Default: 1.0
95-
num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate interpretation. Default: 1000
96+
num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate
97+
interpretation. Default: 1000
9698
batch_size (int, optional): Number of samples to forward each time. Default: 50
97-
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`. Defaults to 224.
98-
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image with the size `crop_to`.
99-
If None, no crop will be performed. Defaults to None.
99+
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`.
100+
Defaults to 224.
101+
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image
102+
with the size `crop_to`. If None, no crop will be performed. Defaults to None.
100103
visual (bool, optional): Whether or not to visualize the processed image. Default: True
101-
save_path (str, optional): The path to save the processed image. If None, the image will not be saved. Default: None
104+
save_path (str, optional): The path to save the processed image. If None, the image will not be saved.
105+
Default: None
102106
103107
Returns:
104108
[dict]: LIME results: {interpret_label_i: weights on features}

interpretdl/interpreter/score_cam.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,19 @@ def interpret(self,
4141
(TODO) The technical details will be described later.
4242
4343
Args:
44-
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
44+
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
45+
array of read images.
4546
target_layer_name (str): The target layer to calculate gradients.
46-
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels should be equal to the number of images.
47-
If None, the most likely label for each image will be used. Default: None
48-
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`. Defaults to 224.
49-
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image with the size `crop_to`.
50-
If None, no crop will be performed. Defaults to None.
47+
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels
48+
should be equal to the number of images. If None, the most likely label for each image will be used.
49+
Default: None
50+
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`.
51+
Defaults to 224.
52+
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image
53+
with the size `crop_to`. If None, no crop will be performed. Defaults to None.
5154
visual (bool, optional): Whether or not to visualize the processed image. Default: True
52-
save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None, the image will not be saved. Default: None
55+
save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None,
56+
the image will not be saved. Default: None
5357
5458
Returns:
5559
[numpy.ndarray]: interpretations/heatmap for images

interpretdl/interpreter/smooth_grad.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,22 @@ def __init__(self, paddle_model: callable, device: str = 'gpu:0', use_cuda=None)
3535
the gradients w.r.t. these noised inputs. The final explanation is averaged gradients.
3636
3737
Args:
38-
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
39-
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels should be equal to the number of images. If None, the most likely label for each image will be used. Default: None
38+
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
39+
array of read images.
40+
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels
41+
should be equal to the number of images. If None, the most likely label for each image will be used.
42+
Default: None
4043
noise_amount (float, optional): Noise level of added noise to the image.
41-
The std of Guassian random noise is noise_amount * (x_max - x_min). Default: 0.1
44+
The std of Guassian random noise is noise_amount * (x_max - x_min).
45+
Default: 0.1
4246
n_samples (int, optional): The number of new images generated by adding noise. Default: 50
43-
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`. Defaults to 224.
44-
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image with the size `crop_to`.
45-
If None, no crop will be performed. Defaults to None.
47+
resize_to (int, optional): [description]. Images will be rescaled with the shorter edge being `resize_to`.
48+
Defaults to 224.
49+
crop_to ([type], optional): [description]. After resize, images will be center cropped to a square image
50+
with the size `crop_to`. If None, no crop will be performed. Defaults to None.
4651
visual (bool, optional): Whether or not to visualize the processed image. Default: True
47-
save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None, the image will not be saved. Default: None
52+
save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None,
53+
the image will not be saved. Default: None
4854
4955
:return: interpretations/gradients for each image
5056
:rtype: numpy.ndarray

0 commit comments

Comments
 (0)