-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add interpolate_like for cpu #10544
Open
woaixiaoxiao
wants to merge
18
commits into
master
Choose a base branch
from
add_interpolate_like
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add interpolate_like for cpu #10544
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
82c7543
success version
woaixiaoxiao 326cde9
clean
woaixiaoxiao 0615dd3
auto format by CI
oneflow-ci-bot 73f8a6c
clean
woaixiaoxiao a615764
Merge branch 'add_interpolate_like' of github.com:Oneflow-Inc/oneflow…
woaixiaoxiao 3569303
clean
woaixiaoxiao daefdb2
clean
woaixiaoxiao 169829b
add test
woaixiaoxiao c745ab6
auto format by CI
oneflow-ci-bot fc3633c
clean
woaixiaoxiao 1bc2807
merge
woaixiaoxiao 630707d
add rst
woaixiaoxiao 7212508
Merge branch 'master' into add_interpolate_like
ShawnXuan f01d542
merge origin
woaixiaoxiao 7abcbcb
change docs
woaixiaoxiao d025ab6
merge
woaixiaoxiao a575d22
fix (#10549)
ShawnXuan edc22fa
Merge branch 'master' into add_interpolate_like
ShawnXuan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,6 +166,7 @@ Vision functions | |
deform_conv2d | ||
pad | ||
interpolate | ||
interpolate_like | ||
upsample | ||
grid_sample | ||
affine_grid | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import math | ||
import warnings | ||
from typing import Optional, Tuple, Union | ||
|
||
import oneflow as flow | ||
from oneflow.framework.tensor import register_tensor_op | ||
from oneflow.nn.modules.module import Module | ||
|
||
|
||
class InterpolateLike: | ||
def __init__( | ||
self, mode: str = "nearest", align_corners: Optional[bool] = None, | ||
): | ||
if mode in ("nearest", "area") and align_corners is not None: | ||
raise ValueError( | ||
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" | ||
) | ||
self.mode = mode | ||
if align_corners == None: | ||
align_corners = False | ||
self.align_corners = align_corners | ||
if self.mode not in ( | ||
"nearest", | ||
"bilinear", | ||
"linear", | ||
"area", | ||
"bicubic", | ||
"trilinear", | ||
): | ||
raise ValueError( | ||
'interpolation must be "nearest" or "bilinear" or "linear" or "area" or "bicubic" or "trilinear".' | ||
) | ||
if self.mode == "nearest" and self.align_corners: | ||
raise ValueError('interpolation "nearest" does not support align_corners.') | ||
|
||
def forward(self, x, like): | ||
if len(x.shape) == 3 and self.mode == "bilinear": | ||
raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") | ||
if len(x.shape) == 3 and self.mode == "trilinear": | ||
raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") | ||
if len(x.shape) == 4 and self.mode == "linear": | ||
raise NotImplementedError("Got 4D input, but linear mode needs 3D input") | ||
if len(x.shape) == 4 and self.mode == "trilinear": | ||
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") | ||
if len(x.shape) == 5 and self.mode == "linear": | ||
raise NotImplementedError("Got 5D input, but linear mode needs 3D input") | ||
if len(x.shape) == 5 and self.mode == "bilinear": | ||
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") | ||
|
||
dim = len(x.shape) - 2 | ||
if len(x.shape) == 3 and self.mode == "nearest": | ||
return flow._C.upsample_nearest_1d(x, like, data_format="channels_first",) | ||
if len(x.shape) == 4 and self.mode == "nearest": | ||
return flow._C.upsample_nearest_2d(x, like, data_format="channels_first",) | ||
if len(x.shape) == 5 and self.mode == "nearest": | ||
return flow._C.upsample_nearest_3d(x, like, data_format="channels_first",) | ||
|
||
raise NotImplementedError( | ||
"Input Error: Only 3D, 4D and 5D input Tensors supported" | ||
" (got {}D) for the modes: nearest" | ||
" (got {})".format(len(x.shape), self.mode) | ||
) | ||
|
||
|
||
def interpolate_like( | ||
input, like, mode="nearest", align_corners=None, | ||
): | ||
"""The interface is consistent with PyTorch. | ||
|
||
The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate. | ||
|
||
|
||
Down/up samples the input to the same shape as the `like` tensor. | ||
|
||
The algorithm used for interpolation is determined by :attr:`mode`. | ||
|
||
Currently temporal, spatial and volumetric sampling are supported, i.e. | ||
expected inputs are 3-D, 4-D or 5-D in shape. | ||
|
||
The input dimensions are interpreted in the form: | ||
`mini-batch x channels x [optional depth] x [optional height] x width`. | ||
|
||
The modes available for resizing are: `nearest`, `linear` (3D-only), | ||
`bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area` | ||
|
||
Args: | ||
input (Tensor): the input tensor | ||
like (Tensor): the like tensor | ||
mode (str): algorithm used for upsampling: | ||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | | ||
``'trilinear'`` | ``'area'``. Default: ``'nearest'`` | ||
align_corners (bool, optional): Geometrically, we consider the pixels of the | ||
input and output as squares rather than points. | ||
If set to ``True``, the input and output tensors are aligned by the | ||
center points of their corner pixels, preserving the values at the corner pixels. | ||
If set to ``False``, the input and output tensors are aligned by the corner | ||
points of their corner pixels, and the interpolation uses edge value padding | ||
for out-of-boundary values. This only has an effect when :attr:`mode` | ||
is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. | ||
Default: ``False`` | ||
|
||
.. note:: | ||
With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce | ||
negative values or values greater than 255 for images. | ||
Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot | ||
when displaying the image. | ||
|
||
.. warning:: | ||
With ``align_corners = True``, the linearly interpolating modes | ||
(`linear`, `bilinear`, and `trilinear`) don't proportionally align the | ||
output and input pixels, and thus the output values can depend on the | ||
input size. This was the default behavior for these modes up to version | ||
0.3.1. Since then, the default behavior is ``align_corners = False``. | ||
See :class:`~torch.nn.Upsample` for concrete examples on how this | ||
affects the outputs. | ||
|
||
For example: | ||
|
||
.. code-block:: python | ||
|
||
>>> import oneflow as flow | ||
>>> import numpy as np | ||
|
||
>>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32) | ||
>>> like = flow.randn(1, 1, 4, 4) | ||
>>> output = flow.nn.functional.interpolate_like(input, like, mode="nearest") | ||
>>> output | ||
tensor([[[[1., 1., 2., 2.], | ||
[1., 1., 2., 2.], | ||
[3., 3., 4., 4.], | ||
[3., 3., 4., 4.]]]], dtype=oneflow.float32) | ||
|
||
""" | ||
return InterpolateLike(mode=mode, align_corners=align_corners,).forward(input, like) | ||
|
||
|
||
if __name__ == "__main__": | ||
import doctest | ||
|
||
doctest.testmod(raise_on_error=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
from collections import OrderedDict | ||
|
||
import numpy as np | ||
from oneflow.test_utils.test_util import GenArgList | ||
from oneflow.test_utils.automated_test_util import * | ||
|
||
import oneflow as flow | ||
import oneflow.unittest | ||
|
||
|
||
def _test_upsample_nearest_2d_like(test_case, shape_scale): | ||
input_shape, out_like_shape = shape_scale | ||
# init data by shape | ||
inputs = np.random.randn(*input_shape) | ||
out_like = np.random.randn(*out_like_shape) | ||
|
||
# get numpy function | ||
def nearest_upsample_by_np(inputs, out_like): | ||
in_height, in_width = inputs.shape[-2:] | ||
out_height, out_width = out_like.shape[-2:] | ||
scale_h = out_height / in_height | ||
scale_w = out_width / in_width | ||
output = np.zeros(out_like.shape) | ||
for i in range(out_height): | ||
for j in range(out_width): | ||
src_i = int(min(i / scale_h, in_height - 1)) | ||
src_j = int(min(j / scale_w, in_width - 1)) | ||
output[..., i, j] = inputs[..., src_i, src_j] | ||
return output | ||
|
||
# oneflow | ||
cpu_input = flow.tensor(inputs, dtype=flow.float32) | ||
cpu_out_like = flow.tensor(out_like, dtype=flow.float32) | ||
cpu_output = flow.nn.functional.interpolate_like( | ||
cpu_input, like=cpu_out_like, mode="nearest" | ||
) | ||
# numpy | ||
np_output = nearest_upsample_by_np(inputs, out_like) | ||
# compare result between oneflow and numpy | ||
test_case.assertTrue(np.allclose(np_output, cpu_output.numpy(), 0.001, 0.001)) | ||
|
||
|
||
@flow.unittest.skip_unless_1n1d() | ||
class TestUpsample2dLike(flow.unittest.TestCase): | ||
def test_upsample2d_like(test_case): | ||
arg_dict = OrderedDict() | ||
arg_dict["test_fun"] = [ | ||
_test_upsample_nearest_2d_like, | ||
] | ||
arg_dict["shape_scale"] = [ | ||
((1, 1, 2, 2), (1, 1, 3, 3)), | ||
((5, 3, 6, 4), (5, 3, 9, 6)), | ||
((2, 3, 2, 4), (2, 3, 3, 5)), | ||
] | ||
for arg in GenArgList(arg_dict): | ||
arg[0](test_case, *arg[1:]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考interpolate,需要在
docs/source/nn.functional.rst
里加一下interpolate_likeThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充