-
Notifications
You must be signed in to change notification settings - Fork 19.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perceptual Similarity loss #20844
Open
tristan-deep
wants to merge
7
commits into
keras-team:master
Choose a base branch
from
tristan-deep:lpips
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Perceptual Similarity loss #20844
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
f282310
LPIPS loss
9e38046
store lpips weights TODO
adaac26
add api gen changes
07771f1
simplify network calls
eb4819d
removed lpips from losses
670bc55
include lpips in applications test
4709571
added proper lpips test
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""DO NOT EDIT. | ||
|
||
This file was autogenerated. Do not edit it by hand, | ||
since your modifications would be overwritten. | ||
""" | ||
|
||
from keras.src.applications.lpips import LPIPS | ||
from keras.src.applications.lpips import preprocess_input |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""DO NOT EDIT. | ||
|
||
This file was autogenerated. Do not edit it by hand, | ||
since your modifications would be overwritten. | ||
""" | ||
|
||
from keras.src.applications.lpips import LPIPS | ||
from keras.src.applications.lpips import preprocess_input |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import ops | ||
from keras.src.api_export import keras_export | ||
from keras.src.applications import imagenet_utils | ||
from keras.src.applications import vgg16 | ||
from keras.src.models import Functional | ||
from keras.src.utils import file_utils | ||
|
||
WEIGHTS_PATH = ( | ||
"https://storage.googleapis.com/tensorflow/keras-applications/" | ||
"lpips/lpips_vgg16_weights.h5" | ||
) # TODO: store weights at this location | ||
|
||
|
||
def vgg_backbone(layer_names): | ||
"""VGG backbone for LPIPS. | ||
|
||
Args: | ||
layer_names: list of layer names to extract features from | ||
|
||
Returns: | ||
Functional model with outputs at specified layers | ||
""" | ||
vgg = vgg16.VGG16(include_top=False, weights=None) | ||
outputs = [ | ||
layer.output for layer in vgg.layers if layer.name in layer_names | ||
] | ||
return Functional(vgg.input, outputs) | ||
|
||
|
||
def linear_model(channels): | ||
"""Get the linear head model for LPIPS. | ||
Combines feature differences from VGG backbone. | ||
|
||
Args: | ||
channels: list of channel sizes for feature differences | ||
|
||
Returns: | ||
Functional model | ||
""" | ||
inputs, outputs = [], [] | ||
for ii, channel in enumerate(channels): | ||
x = layers.Input(shape=(None, None, channel)) | ||
y = layers.Dropout(rate=0.5)(x) | ||
y = layers.Conv2D( | ||
filters=1, | ||
kernel_size=1, | ||
use_bias=False, | ||
name=f"linear_{ii}", | ||
)(y) | ||
inputs.append(x) | ||
outputs.append(y) | ||
|
||
model = Functional(inputs=inputs, outputs=outputs, name="linear_model") | ||
return model | ||
|
||
|
||
@keras_export(["keras.applications.lpips.LPIPS", "keras.applications.LPIPS"]) | ||
def LPIPS( | ||
weights="imagenet", | ||
input_tensor=None, | ||
input_shape=None, | ||
network_type="vgg", | ||
name="lpips", | ||
): | ||
"""Instantiates the LPIPS model. | ||
|
||
Reference: | ||
- [The Unreasonable Effectiveness of Deep Features as a Perceptual Metric]( | ||
https://arxiv.org/abs/1801.03924) | ||
|
||
Args: | ||
weights: one of `None` (random initialization), | ||
`"imagenet"` (pre-training on ImageNet), | ||
or the path to the weights file to be loaded. | ||
input_tensor: optional Keras tensor for model input | ||
input_shape: optional shape tuple, defaults to (None, None, 3) | ||
network_type: backbone network type (currently only 'vgg' supported) | ||
name: model name string | ||
|
||
Returns: | ||
A `Model` instance. | ||
""" | ||
if network_type != "vgg": | ||
raise ValueError( | ||
"Currently only VGG backbone is supported. " | ||
f"Got network_type={network_type}" | ||
) | ||
|
||
if backend.image_data_format() == "channels_first": | ||
raise ValueError( | ||
"LPIPS does not support the `channels_first` image data " | ||
"format. Switch to `channels_last` by editing your local " | ||
"config file at ~/.keras/keras.json" | ||
) | ||
|
||
if not (weights in {"imagenet", None} or file_utils.exists(weights)): | ||
raise ValueError( | ||
"The `weights` argument should be either " | ||
"`None` (random initialization), 'imagenet' " | ||
"(pre-training on ImageNet), " | ||
"or the path to the weights file to be loaded." | ||
) | ||
|
||
# Define inputs | ||
if input_tensor is None: | ||
img_input1 = layers.Input( | ||
shape=input_shape or (None, None, 3), name="input1" | ||
) | ||
img_input2 = layers.Input( | ||
shape=input_shape or (None, None, 3), name="input2" | ||
) | ||
else: | ||
if not backend.is_keras_tensor(input_tensor): | ||
img_input1 = layers.Input(tensor=input_tensor, shape=input_shape) | ||
img_input2 = layers.Input(tensor=input_tensor, shape=input_shape) | ||
else: | ||
img_input1 = input_tensor | ||
img_input2 = input_tensor | ||
|
||
# VGG feature extraction | ||
vgg_layers = [ | ||
"block1_conv2", | ||
"block2_conv2", | ||
"block3_conv3", | ||
"block4_conv3", | ||
"block5_conv3", | ||
] | ||
vgg_net = vgg_backbone(vgg_layers) | ||
|
||
feat1 = vgg_net(img_input1) | ||
feat2 = vgg_net(img_input2) | ||
|
||
def normalize(x, eps: float = 1e-8): | ||
return x * ops.rsqrt( | ||
eps + ops.sum(ops.square(x), axis=-1, keepdims=True) | ||
) | ||
|
||
norm1 = [normalize(f) for f in feat1] | ||
norm2 = [normalize(f) for f in feat2] | ||
|
||
diffs = [ops.square(t1 - t2) for t1, t2 in zip(norm1, norm2)] | ||
|
||
channels = [f.shape[-1] for f in feat1] | ||
|
||
linear_net = linear_model(channels) | ||
|
||
lin_out = linear_net(diffs) | ||
|
||
spatial_average = [ | ||
ops.mean(t, axis=[1, 2], keepdims=False) for t in lin_out | ||
] | ||
|
||
# need a layer to convert list to tensor | ||
output = layers.Lambda(lambda x: ops.convert_to_tensor(x))(spatial_average) | ||
|
||
output = ops.squeeze(ops.sum(output, axis=0), axis=-1) | ||
|
||
# Create model | ||
model = Functional([img_input1, img_input2], output, name=name) | ||
|
||
# Load weights | ||
if weights == "imagenet": | ||
weights_path = file_utils.get_file( | ||
"lpips_vgg16_weights.h5", | ||
WEIGHTS_PATH, | ||
cache_subdir="models", | ||
file_hash=None, # TODO: add hash | ||
) | ||
model.load_weights(weights_path) | ||
elif weights is not None: | ||
model.load_weights(weights) | ||
|
||
return model | ||
|
||
|
||
@keras_export("keras.applications.lpips.preprocess_input") | ||
def preprocess_input(x, data_format=None): | ||
return imagenet_utils.preprocess_input( | ||
x, data_format=data_format, mode="torch" | ||
) | ||
|
||
|
||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( | ||
mode="", | ||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE, | ||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct weights path should still be added here. Currently I've put a placeholder. I uploaded the LPIPS weights to Hugging Face, as I'm not sure how to upload to
storage.googleapis.com
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently added test will fail because it cannot find the weights at that location. Easy way to try out the tests locally is to change this line:
keras/keras/src/applications/applications_test.py
Line 284 in 4709571
with