Skip to content

Commit 4709571

Browse files
author
Tristan Stevens
committed
added proper lpips test
1 parent 670bc55 commit 4709571

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

keras/src/applications/applications_test.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,8 @@
8181
(resnet_v2.ResNet50V2, 2048, resnet_v2),
8282
(resnet_v2.ResNet101V2, 2048, resnet_v2),
8383
(resnet_v2.ResNet152V2, 2048, resnet_v2),
84-
# lpips
85-
(lpips.LPIPS, 512, lpips),
8684
]
87-
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet"]
85+
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet", "LPIPS"]
8886

8987
# Add names for `named_parameters`, and add each data format for each model
9088
test_parameters = [
@@ -267,3 +265,43 @@ def test_application_classifier_activation(self, app, *_):
267265
)
268266
last_layer_act = model.layers[-1].activation.__name__
269267
self.assertEqual(last_layer_act, "softmax")
268+
269+
@parameterized.named_parameters(
270+
[
271+
(
272+
"{}_{}".format(lpips.LPIPS.__name__, image_data_format),
273+
image_data_format,
274+
)
275+
for image_data_format in ["channels_first", "channels_last"]
276+
]
277+
)
278+
def test_application_lpips(self, image_data_format):
279+
self.skip_if_invalid_image_data_format_for_model(
280+
lpips.LPIPS, image_data_format
281+
)
282+
backend.set_image_data_format(image_data_format)
283+
284+
model = lpips.LPIPS()
285+
output_shape = list(model.outputs[0].shape)
286+
287+
# Two images as input
288+
self.assertEqual(len(model.input_shape), 2)
289+
290+
# Single output
291+
self.assertEqual(output_shape, [None])
292+
293+
# Can run a correct inference on a test image
294+
if image_data_format == "channels_first":
295+
shape = model.input_shape[0][2:4]
296+
else:
297+
shape = model.input_shape[0][1:3]
298+
299+
x = _get_elephant(shape)
300+
301+
x = lpips.preprocess_input(x)
302+
y = lpips.preprocess_input(x)
303+
304+
preds = model.predict([x, y])
305+
306+
# same image so lpips should be 0
307+
self.assertEqual(preds, 0.0)

keras/src/applications/lpips.py

+7
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def LPIPS(
8888
f"Got network_type={network_type}"
8989
)
9090

91+
if backend.image_data_format() == "channels_first":
92+
raise ValueError(
93+
"LPIPS does not support the `channels_first` image data "
94+
"format. Switch to `channels_last` by editing your local "
95+
"config file at ~/.keras/keras.json"
96+
)
97+
9198
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
9299
raise ValueError(
93100
"The `weights` argument should be either "

0 commit comments

Comments
 (0)