|
81 | 81 | (resnet_v2.ResNet50V2, 2048, resnet_v2),
|
82 | 82 | (resnet_v2.ResNet101V2, 2048, resnet_v2),
|
83 | 83 | (resnet_v2.ResNet152V2, 2048, resnet_v2),
|
84 |
| - # lpips |
85 |
| - (lpips.LPIPS, 512, lpips), |
86 | 84 | ]
|
87 |
| -MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet"] |
| 85 | +MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet", "LPIPS"] |
88 | 86 |
|
89 | 87 | # Add names for `named_parameters`, and add each data format for each model
|
90 | 88 | test_parameters = [
|
@@ -267,3 +265,43 @@ def test_application_classifier_activation(self, app, *_):
|
267 | 265 | )
|
268 | 266 | last_layer_act = model.layers[-1].activation.__name__
|
269 | 267 | self.assertEqual(last_layer_act, "softmax")
|
| 268 | + |
| 269 | + @parameterized.named_parameters( |
| 270 | + [ |
| 271 | + ( |
| 272 | + "{}_{}".format(lpips.LPIPS.__name__, image_data_format), |
| 273 | + image_data_format, |
| 274 | + ) |
| 275 | + for image_data_format in ["channels_first", "channels_last"] |
| 276 | + ] |
| 277 | + ) |
| 278 | + def test_application_lpips(self, image_data_format): |
| 279 | + self.skip_if_invalid_image_data_format_for_model( |
| 280 | + lpips.LPIPS, image_data_format |
| 281 | + ) |
| 282 | + backend.set_image_data_format(image_data_format) |
| 283 | + |
| 284 | + model = lpips.LPIPS() |
| 285 | + output_shape = list(model.outputs[0].shape) |
| 286 | + |
| 287 | + # Two images as input |
| 288 | + self.assertEqual(len(model.input_shape), 2) |
| 289 | + |
| 290 | + # Single output |
| 291 | + self.assertEqual(output_shape, [None]) |
| 292 | + |
| 293 | + # Can run a correct inference on a test image |
| 294 | + if image_data_format == "channels_first": |
| 295 | + shape = model.input_shape[0][2:4] |
| 296 | + else: |
| 297 | + shape = model.input_shape[0][1:3] |
| 298 | + |
| 299 | + x = _get_elephant(shape) |
| 300 | + |
| 301 | + x = lpips.preprocess_input(x) |
| 302 | + y = lpips.preprocess_input(x) |
| 303 | + |
| 304 | + preds = model.predict([x, y]) |
| 305 | + |
| 306 | + # same image so lpips should be 0 |
| 307 | + self.assertEqual(preds, 0.0) |
0 commit comments