1111from torchcam import cams
1212
1313
14- def _forward (model , input_tensor ):
15- if model .training :
16- scores = model (input_tensor )
17- else :
18- with torch .no_grad ():
19- scores = model (input_tensor )
20-
21- return scores
22-
23-
2414class CAMCoreTester (unittest .TestCase ):
2515 def _verify_cam (self , cam ):
2616 # Simple verifications
@@ -50,11 +40,11 @@ def _test_extractor(self, extractor, model):
5040 img_tensor = self ._get_img_tensor ()
5141
5242 # Check that a batch of 2 cannot be accepted
53- _ = _forward ( model , torch .stack ((img_tensor , img_tensor )))
43+ _ = model ( torch .stack ((img_tensor , img_tensor )))
5444 self .assertRaises (ValueError , extractor , 0 )
5545
5646 # Correct forward
57- scores = _forward ( model , img_tensor .unsqueeze (0 ))
47+ scores = model ( img_tensor .unsqueeze (0 ))
5848
5949 # Check incorrect class index
6050 self .assertRaises (ValueError , extractor , - 1 )
@@ -68,20 +58,24 @@ def _test_extractor(self, extractor, model):
6858
6959 def _test_cam (self , name ):
7060 # Get a pretrained model
71- model = resnet18 (pretrained = False ).eval ()
72- conv_layer = 'layer4.1.relu '
61+ model = mobilenet_v2 (pretrained = False ).eval ()
62+ conv_layer = None if name == "CAM" else 'features.16.conv.3 '
7363
64+ kwargs = {}
65+ # Speed up testing by reducing the number of samples
66+ if name in ['SSCAM' , 'ISCAM' ]:
67+ kwargs ['num_samples' ] = 4
7468 # Hook the corresponding layer in the model
75- extractor = cams .__dict__ [name ](model , conv_layer )
69+ extractor = cams .__dict__ [name ](model , conv_layer , ** kwargs )
7670
7771 with torch .no_grad ():
7872 self ._test_extractor (extractor , model )
7973
8074 def _test_gradcam (self , name ):
8175
8276 # Get a pretrained model
83- model = mobilenet_v2 (pretrained = False )
84- conv_layer = 'features.17.conv.3 '
77+ model = mobilenet_v2 (pretrained = False ). eval ()
78+ conv_layer = 'features.18.0 '
8579
8680 # Hook the corresponding layer in the model
8781 extractor = cams .__dict__ [name ](model , conv_layer )
0 commit comments