Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perceptual Similarity loss #20844

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from keras.api.applications import imagenet_utils
from keras.api.applications import inception_resnet_v2
from keras.api.applications import inception_v3
from keras.api.applications import lpips
from keras.api.applications import mobilenet
from keras.api.applications import mobilenet_v2
from keras.api.applications import mobilenet_v3
Expand Down Expand Up @@ -46,6 +47,7 @@
from keras.src.applications.efficientnet_v2 import EfficientNetV2S
from keras.src.applications.inception_resnet_v2 import InceptionResNetV2
from keras.src.applications.inception_v3 import InceptionV3
from keras.src.applications.lpips import LPIPS
from keras.src.applications.mobilenet import MobileNet
from keras.src.applications.mobilenet_v2 import MobileNetV2
from keras.src.applications.mobilenet_v3 import MobileNetV3Large
Expand Down
8 changes: 8 additions & 0 deletions keras/api/_tf_keras/keras/applications/lpips/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.applications.lpips import LPIPS
from keras.src.applications.lpips import preprocess_input
2 changes: 2 additions & 0 deletions keras/api/applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from keras.api.applications import imagenet_utils
from keras.api.applications import inception_resnet_v2
from keras.api.applications import inception_v3
from keras.api.applications import lpips
from keras.api.applications import mobilenet
from keras.api.applications import mobilenet_v2
from keras.api.applications import mobilenet_v3
Expand Down Expand Up @@ -46,6 +47,7 @@
from keras.src.applications.efficientnet_v2 import EfficientNetV2S
from keras.src.applications.inception_resnet_v2 import InceptionResNetV2
from keras.src.applications.inception_v3 import InceptionV3
from keras.src.applications.lpips import LPIPS
from keras.src.applications.mobilenet import MobileNet
from keras.src.applications.mobilenet_v2 import MobileNetV2
from keras.src.applications.mobilenet_v3 import MobileNetV3Large
Expand Down
8 changes: 8 additions & 0 deletions keras/api/applications/lpips/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.applications.lpips import LPIPS
from keras.src.applications.lpips import preprocess_input
43 changes: 42 additions & 1 deletion keras/src/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.src.applications import efficientnet_v2
from keras.src.applications import inception_resnet_v2
from keras.src.applications import inception_v3
from keras.src.applications import lpips
from keras.src.applications import mobilenet
from keras.src.applications import mobilenet_v2
from keras.src.applications import mobilenet_v3
Expand Down Expand Up @@ -81,7 +82,7 @@
(resnet_v2.ResNet101V2, 2048, resnet_v2),
(resnet_v2.ResNet152V2, 2048, resnet_v2),
]
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet"]
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet", "LPIPS"]

# Add names for `named_parameters`, and add each data format for each model
test_parameters = [
Expand Down Expand Up @@ -264,3 +265,43 @@ def test_application_classifier_activation(self, app, *_):
)
last_layer_act = model.layers[-1].activation.__name__
self.assertEqual(last_layer_act, "softmax")

@parameterized.named_parameters(
[
(
"{}_{}".format(lpips.LPIPS.__name__, image_data_format),
image_data_format,
)
for image_data_format in ["channels_first", "channels_last"]
]
)
def test_application_lpips(self, image_data_format):
self.skip_if_invalid_image_data_format_for_model(
lpips.LPIPS, image_data_format
)
backend.set_image_data_format(image_data_format)

model = lpips.LPIPS()
output_shape = list(model.outputs[0].shape)

# Two images as input
self.assertEqual(len(model.input_shape), 2)

# Single output
self.assertEqual(output_shape, [None])

# Can run a correct inference on a test image
if image_data_format == "channels_first":
shape = model.input_shape[0][2:4]
else:
shape = model.input_shape[0][1:3]

x = _get_elephant(shape)

x = lpips.preprocess_input(x)
y = lpips.preprocess_input(x)

preds = model.predict([x, y])

# same image so lpips should be 0
self.assertEqual(preds, 0.0)
189 changes: 189 additions & 0 deletions keras/src/applications/lpips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.applications import imagenet_utils
from keras.src.applications import vgg16
from keras.src.models import Functional
from keras.src.utils import file_utils

# Pretrained LPIPS (VGG16 backbone) weights.
# TODO: upload the weights to this canonical location. Until then this
# URL is a placeholder (the author's copy currently lives on Hugging
# Face), so `weights="imagenet"` downloads will fail; tests can be run
# locally via `lpips.LPIPS(weights=<path_to_local_copy>)`.
WEIGHTS_PATH = (
    "https://storage.googleapis.com/tensorflow/keras-applications/"
    "lpips/lpips_vgg16_weights.h5"
)


def vgg_backbone(layer_names):
    """VGG16 feature-extraction backbone for LPIPS.

    Args:
        layer_names: list of VGG16 layer names to extract features from.

    Returns:
        Functional model mapping an image tensor to a list of feature
        maps, one per entry of `layer_names`, in the given order.

    Raises:
        ValueError: if a name in `layer_names` is not a VGG16 layer.
    """
    vgg = vgg16.VGG16(include_top=False, weights=None)
    # Use `get_layer` so outputs follow the order requested by the
    # caller, and an unknown layer name fails loudly instead of being
    # silently dropped (the original filtered `vgg.layers`, which both
    # ignored typos and imposed the model's internal layer order).
    outputs = [vgg.get_layer(name).output for name in layer_names]
    return Functional(vgg.input, outputs)


def linear_model(channels):
    """Get the linear head model for LPIPS.
    Combines feature differences from VGG backbone.

    Args:
        channels: list of channel sizes for feature differences

    Returns:
        Functional model
    """
    inputs = []
    outputs = []
    for index, num_channels in enumerate(channels):
        # One 1x1 linear projection per feature scale; dropout matches
        # the reference LPIPS training setup.
        feature_input = layers.Input(shape=(None, None, num_channels))
        dropped = layers.Dropout(rate=0.5)(feature_input)
        projected = layers.Conv2D(
            filters=1,
            kernel_size=1,
            use_bias=False,
            name=f"linear_{index}",
        )(dropped)
        inputs.append(feature_input)
        outputs.append(projected)

    return Functional(inputs=inputs, outputs=outputs, name="linear_model")


@keras_export(["keras.applications.lpips.LPIPS", "keras.applications.LPIPS"])
def LPIPS(
    weights="imagenet",
    input_tensor=None,
    input_shape=None,
    network_type="vgg",
    name="lpips",
):
    """Instantiates the LPIPS model.

    LPIPS (Learned Perceptual Image Patch Similarity) takes two images
    and returns a scalar perceptual distance per batch element: VGG16
    features of both images are unit-normalized per channel, squared
    differences are projected by learned 1x1 convolutions, spatially
    averaged, and summed over the five feature scales.

    Reference:
    - [The Unreasonable Effectiveness of Deep Features as a Perceptual Metric](
    https://arxiv.org/abs/1801.03924)

    Args:
        weights: one of `None` (random initialization),
            `"imagenet"` (pre-training on ImageNet),
            or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor for model input
        input_shape: optional shape tuple, defaults to (None, None, 3)
        network_type: backbone network type (currently only 'vgg' supported)
        name: model name string

    Returns:
        A `Model` instance.

    Raises:
        ValueError: for an unsupported `network_type`, the
            `channels_first` data format, or an invalid `weights` value.
    """
    if network_type != "vgg":
        raise ValueError(
            "Currently only VGG backbone is supported. "
            f"Got network_type={network_type}"
        )

    # Feature maps below are indexed with channels last (`axis=-1`
    # normalization, `[1, 2]` spatial mean), so channels_first is
    # rejected up front.
    if backend.image_data_format() == "channels_first":
        raise ValueError(
            "LPIPS does not support the `channels_first` image data "
            "format. Switch to `channels_last` by editing your local "
            "config file at ~/.keras/keras.json"
        )

    if not (weights in {"imagenet", None} or file_utils.exists(weights)):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` (random initialization), 'imagenet' "
            "(pre-training on ImageNet), "
            "or the path to the weights file to be loaded."
        )

    # Define inputs: the model compares two images, so it has two inputs.
    if input_tensor is None:
        img_input1 = layers.Input(
            shape=input_shape or (None, None, 3), name="input1"
        )
        img_input2 = layers.Input(
            shape=input_shape or (None, None, 3), name="input2"
        )
    else:
        # NOTE(review): both inputs are built from the same
        # `input_tensor` here, which would make the distance trivially
        # constant — confirm this branch is intended.
        if not backend.is_keras_tensor(input_tensor):
            img_input1 = layers.Input(tensor=input_tensor, shape=input_shape)
            img_input2 = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input1 = input_tensor
            img_input2 = input_tensor

    # VGG feature extraction at the five scales used by the reference
    # LPIPS implementation. The same backbone instance (shared weights)
    # is applied to both inputs.
    vgg_layers = [
        "block1_conv2",
        "block2_conv2",
        "block3_conv3",
        "block4_conv3",
        "block5_conv3",
    ]
    vgg_net = vgg_backbone(vgg_layers)

    feat1 = vgg_net(img_input1)
    feat2 = vgg_net(img_input2)

    # Unit-normalize each spatial position across channels; `eps` guards
    # against division by zero for all-zero feature vectors.
    def normalize(x, eps: float = 1e-8):
        return x * ops.rsqrt(
            eps + ops.sum(ops.square(x), axis=-1, keepdims=True)
        )

    norm1 = [normalize(f) for f in feat1]
    norm2 = [normalize(f) for f in feat2]

    # Squared difference of normalized features, per scale.
    diffs = [ops.square(t1 - t2) for t1, t2 in zip(norm1, norm2)]

    channels = [f.shape[-1] for f in feat1]

    # Learned 1x1 linear projections collapse each difference map to a
    # single channel.
    linear_net = linear_model(channels)

    lin_out = linear_net(diffs)

    # Spatial average over H and W leaves one value per scale per batch
    # element.
    spatial_average = [
        ops.mean(t, axis=[1, 2], keepdims=False) for t in lin_out
    ]

    # need a layer to convert list to tensor (stacks the per-scale
    # averages along a new leading axis)
    output = layers.Lambda(lambda x: ops.convert_to_tensor(x))(spatial_average)

    # Sum over scales, then drop the trailing singleton channel axis,
    # yielding shape (batch,).
    output = ops.squeeze(ops.sum(output, axis=0), axis=-1)

    # Create model
    model = Functional([img_input1, img_input2], output, name=name)

    # Load weights
    if weights == "imagenet":
        weights_path = file_utils.get_file(
            "lpips_vgg16_weights.h5",
            WEIGHTS_PATH,
            cache_subdir="models",
            file_hash=None,  # TODO: add hash
        )
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model


@keras_export("keras.applications.lpips.preprocess_input")
def preprocess_input(x, data_format=None):
    # Torch-style preprocessing: scale to [0, 1], then normalize with
    # ImageNet mean/std. (Docstring is attached below from the shared
    # imagenet_utils template.)
    return imagenet_utils.preprocess_input(
        x, data_format=data_format, mode="torch"
    )


preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode="",
    # Bug fix: preprocessing above uses mode="torch", so document the
    # torch return convention, not the caffe one.
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
)
Loading