
add interpolate_like for cpu #10544

Open · wants to merge 18 commits into master
1 change: 1 addition & 0 deletions docs/source/nn.functional.rst
@@ -166,6 +166,7 @@ Vision functions
deform_conv2d
pad
interpolate
interpolate_like
upsample
grid_sample
affine_grid
8 changes: 5 additions & 3 deletions oneflow/core/functional/functional_api.yaml
@@ -1789,9 +1789,11 @@
bind_python: False

- name: "upsample_nearest_2d"
signature:
'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2D'
signature: [
'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2D',
'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest2D'
]
bind_python: True

- name: "upsample_nearest_2d_grad"
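Note: the second signature lets the op take its target spatial size from a `like` tensor instead of explicit scale factors. A minimal sketch of how the two overloads would be called from Python; the keyword names are assumed to follow the YAML parameters above, and only the `like` form is exercised by this PR's `interpolate_like`:

```python
import oneflow as flow

x = flow.randn(1, 1, 2, 2)
like = flow.randn(1, 1, 4, 4)

# Existing overload: explicit scale factors (keywords assumed from the YAML signature).
y1 = flow._C.upsample_nearest_2d(x, height_scale=2.0, width_scale=2.0,
                                 data_format="channels_first")

# Overload added here: target height/width are read from `like`.
y2 = flow._C.upsample_nearest_2d(x, like, data_format="channels_first")

assert tuple(y1.shape) == tuple(y2.shape) == (1, 1, 4, 4)
```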
21 changes: 20 additions & 1 deletion oneflow/core/functional/impl/array_functor.cpp
@@ -1991,6 +1991,24 @@ class UpsampleNearest2DFunctor {
std::shared_ptr<OpExpr> op_;
};

class UpsampleNearestLike2DFunctor {
public:
UpsampleNearestLike2DFunctor() {
op_ = CHECK_JUST(
one::OpBuilder("upsample_nearest_2d").Input("x").Input("like").Output("y").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& like,
const std::string& data_format) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format");
attrs.SetAllAttrs(data_format);
return OpInterpUtil::Dispatch<Tensor>(*op_, {x, like}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class UpsampleNearest2DGradFunctor {
public:
UpsampleNearest2DGradFunctor() {
@@ -4118,7 +4136,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::UnfoldTensorFunctor>("UnfoldTensor");
m.add_functor<impl::UnfoldTensorGradFunctor>("UnfoldTensorGrad");
m.add_functor<impl::UpsampleGradFunctor>("UpsampleGrad");
m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
m.add_functor<impl::UpsampleNearest2DFunctor, impl::UpsampleNearestLike2DFunctor>(
"UpsampleNearest2D");
m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
m.add_functor<impl::UpsampleBilinear2DFunctor>("UpsampleBilinear2D");
m.add_functor<impl::UpsampleBilinear2DGradFunctor>("UpsampleBilinear2DGrad");
2 changes: 1 addition & 1 deletion oneflow/user/kernels/upsample_nearest_kernel.cpp
@@ -227,7 +227,7 @@ class UpsampleNearest2DCPUKernel final : public user_op::OpKernel {
const int64_t out_height = y_tensor->shape_view().At(2);
const int64_t out_width = y_tensor->shape_view().At(3);
const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();
if (!output_size.empty()) {
if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) {
height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);
width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);
}
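When a `like` input is present there are no explicit scale attributes, so the kernel re-derives the scales from the already-inferred output shape. A small worked example mirroring the two lines changed above (illustrative values only, not the C++ kernel itself):

```python
# Hypothetical 2x2 input upsampled to a 4x4 `like` shape.
in_height, in_width = 2, 2      # x spatial dims
out_height, out_width = 4, 4    # y spatial dims, taken from `like`

height_scale = out_height / in_height   # 2.0
width_scale = out_width / in_width      # 2.0
```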
12 changes: 12 additions & 0 deletions oneflow/user/ops/upsample_op.cpp
@@ -23,6 +23,18 @@ typename std::enable_if<(N <= 3), Maybe<void>>::type UpsamplingInferLogicalDesc(
user_op::InferContext* ctx, const std::string& func_name) {
const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0);
if (ctx->has_input("like", 0)) {
const user_op::TensorDesc& like_desc = ctx->InputTensorDesc("like", 0);
int64_t like_num_axes = like_desc.shape().NumAxes();
CHECK_GT_OR_RETURN(like_num_axes, N)
<< "like shape size should > " << N << ", but got " << like_desc.shape().ToString();
Shape output_shape = x_desc.shape();
for (int i = 0; i < N; ++i) {
output_shape[i + 2] = like_desc.shape().At(like_num_axes - N + i);
}
y_desc->set_shape(output_shape);
return Maybe<void>::Ok();
}
if (N == 1) {
CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
&& x_desc.shape().NumAxes() == (N + 2))
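For reference, a Python mirror of the new branch above: batch and channel dims are kept from `x`, while the trailing `N` spatial dims are copied from `like`. The helper name is mine, for illustration only, and is not part of the OneFlow API:

```python
def infer_upsample_like_shape(x_shape, like_shape, n_spatial):
    """Illustrative Python mirror of the C++ shape inference above."""
    assert len(like_shape) > n_spatial, "like must have more than n_spatial axes"
    out = list(x_shape)  # keep batch and channel dims from x
    for i in range(n_spatial):
        # take the trailing n_spatial dims of `like` as the output spatial dims
        out[i + 2] = like_shape[len(like_shape) - n_spatial + i]
    return tuple(out)

# e.g. 2D nearest upsampling:
assert infer_upsample_like_shape((1, 1, 2, 2), (1, 1, 4, 4), 2) == (1, 1, 4, 4)
```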
1 change: 1 addition & 0 deletions python/oneflow/nn/functional/__init__.py
@@ -14,6 +14,7 @@
limitations under the License.
"""
from oneflow.nn.modules.interpolate import interpolate
from oneflow.nn.modules.interpolate_like import interpolate_like
from oneflow.nn.modules.affine_grid import affine_grid
from oneflow.nn.modules.grid_sample import grid_sample
from oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy
156 changes: 156 additions & 0 deletions python/oneflow/nn/modules/interpolate_like.py
@@ -0,0 +1,156 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
import warnings
from typing import Optional, Tuple, Union

import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.modules.module import Module


class InterpolateLike:
def __init__(
self, mode: str = "nearest", align_corners: Optional[bool] = None,
):
if mode in ("nearest", "area") and align_corners is not None:
raise ValueError(
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
)
self.mode = mode
if align_corners is None:
align_corners = False
self.align_corners = align_corners
if self.mode not in (
"nearest",
"bilinear",
"linear",
"area",
"bicubic",
"trilinear",
):
raise ValueError(
'interpolation must be "nearest" or "bilinear" or "linear" or "area" or "bicubic" or "trilinear".'
)
if self.mode == "nearest" and self.align_corners:
raise ValueError('interpolation "nearest" does not support align_corners.')

def forward(self, x, like):
if len(x.shape) == 3 and self.mode == "bilinear":
raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
if len(x.shape) == 3 and self.mode == "trilinear":
raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
if len(x.shape) == 4 and self.mode == "linear":
raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
if len(x.shape) == 4 and self.mode == "trilinear":
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
if len(x.shape) == 5 and self.mode == "linear":
raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
if len(x.shape) == 5 and self.mode == "bilinear":
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")

dim = len(x.shape) - 2
if len(x.shape) == 3 and self.mode == "nearest":
return flow._C.upsample_nearest_1d(x, like, data_format="channels_first",)
if len(x.shape) == 4 and self.mode == "nearest":
return flow._C.upsample_nearest_2d(x, like, data_format="channels_first",)
if len(x.shape) == 5 and self.mode == "nearest":
return flow._C.upsample_nearest_3d(x, like, data_format="channels_first",)

raise NotImplementedError(
"Input Error: Only 3D, 4D and 5D input Tensors supported"
" (got {}D) for the modes: nearest"
" (got {})".format(len(x.shape), self.mode)
)


def interpolate_like(
input, like, mode="nearest", align_corners=None,
):
"""The interface is consistent with PyTorch.
Review comment (Contributor): Following `interpolate`, `interpolate_like` also needs to be added to docs/source/nn.functional.rst.

Reply (Contributor Author): Done, added.


The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate.


Down/up samples the input to the same spatial shape as the `like` tensor.

The algorithm used for interpolation is determined by :attr:`mode`.

Currently temporal, spatial and volumetric sampling are supported, i.e.
expected inputs are 3-D, 4-D or 5-D in shape.

The input dimensions are interpreted in the form:
`mini-batch x channels x [optional depth] x [optional height] x width`.

The modes available for resizing are: `nearest`, `linear` (3D-only),
`bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`

Args:
input (Tensor): the input tensor
like (Tensor): the like tensor
mode (str): algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'`` | ``'area'``. Default: ``'nearest'``
align_corners (bool, optional): Geometrically, we consider the pixels of the
input and output as squares rather than points.
If set to ``True``, the input and output tensors are aligned by the
center points of their corner pixels, preserving the values at the corner pixels.
If set to ``False``, the input and output tensors are aligned by the corner
points of their corner pixels, and the interpolation uses edge value padding
for out-of-boundary values. This only has an effect when :attr:`mode`
is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
Default: ``False``

.. note::
With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
negative values or values greater than 255 for images.
Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
when displaying the image.

.. warning::
With ``align_corners = True``, the linearly interpolating modes
(`linear`, `bilinear`, and `trilinear`) don't proportionally align the
output and input pixels, and thus the output values can depend on the
input size. This was the default behavior for these modes up to version
0.3.1. Since then, the default behavior is ``align_corners = False``.
See :class:`~oneflow.nn.Upsample` for concrete examples on how this
affects the outputs.

For example:

.. code-block:: python

>>> import oneflow as flow
>>> import numpy as np

>>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32)
>>> like = flow.randn(1, 1, 4, 4)
>>> output = flow.nn.functional.interpolate_like(input, like, mode="nearest")
>>> output
tensor([[[[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]]]], dtype=oneflow.float32)

"""
return InterpolateLike(mode=mode, align_corners=align_corners,).forward(input, like)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
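As a quick sanity check of the user-facing API, `interpolate_like` should match `interpolate` whenever the explicit target size equals the spatial shape of `like`. A sketch, assuming `flow.nn.functional.interpolate` accepts a `size` argument as documented elsewhere in OneFlow:

```python
import numpy as np
import oneflow as flow

x = flow.tensor(np.arange(1, 5).reshape(1, 1, 2, 2), dtype=flow.float32)
like = flow.randn(1, 1, 4, 4)

a = flow.nn.functional.interpolate(x, size=(4, 4), mode="nearest")
b = flow.nn.functional.interpolate_like(x, like, mode="nearest")

assert np.array_equal(a.numpy(), b.numpy())
```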
77 changes: 77 additions & 0 deletions python/oneflow/test/modules/test_upsample_like.py
@@ -0,0 +1,77 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList
from oneflow.test_utils.automated_test_util import *

import oneflow as flow
import oneflow.unittest


def _test_upsample_nearest_2d_like(test_case, shape_scale):
input_shape, out_like_shape = shape_scale
# init data by shape
inputs = np.random.randn(*input_shape)
out_like = np.random.randn(*out_like_shape)

# get numpy function
def nearest_upsample_by_np(inputs, out_like):
in_height, in_width = inputs.shape[-2:]
out_height, out_width = out_like.shape[-2:]
scale_h = out_height / in_height
scale_w = out_width / in_width
output = np.zeros(out_like.shape)
for i in range(out_height):
for j in range(out_width):
src_i = int(min(i / scale_h, in_height - 1))
src_j = int(min(j / scale_w, in_width - 1))
output[..., i, j] = inputs[..., src_i, src_j]
return output

# oneflow
cpu_input = flow.tensor(inputs, dtype=flow.float32)
cpu_out_like = flow.tensor(out_like, dtype=flow.float32)
cpu_output = flow.nn.functional.interpolate_like(
cpu_input, like=cpu_out_like, mode="nearest"
)
# numpy
np_output = nearest_upsample_by_np(inputs, out_like)
# compare result between oneflow and numpy
test_case.assertTrue(np.allclose(np_output, cpu_output.numpy(), 0.001, 0.001))


@flow.unittest.skip_unless_1n1d()
class TestUpsample2dLike(flow.unittest.TestCase):
def test_upsample2d_like(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_upsample_nearest_2d_like,
]
arg_dict["shape_scale"] = [
((1, 1, 2, 2), (1, 1, 3, 3)),
((5, 3, 6, 4), (5, 3, 9, 6)),
((2, 3, 2, 4), (2, 3, 3, 5)),
]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])


if __name__ == "__main__":
unittest.main()