diff --git a/README.md b/README.md
index e799f74..68d94bd 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ By [Duo Li](https://duoli.org/), [Jie Hu](https://github.com/hujie-frank), [Chan

-**TL; DR.** `involution` is a general-purpose neural primitive that is versatile for a spectrum of deep learning models on different vision tasks. `involution` bridges `convolution` and `self-attention` in design, while being more efficient and effective than `convolution`, simpler than `self-attention` in form.
+**TL; DR.** `involution` is a general-purpose neural primitive that is versatile for a spectrum of deep learning models on different vision tasks. `involution` bridges `convolution` and `self-attention` in design, while being more efficient and effective than `convolution`, simpler than `self-attention` in form.


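For readers new to the operator, here is a minimal pure-PyTorch sketch of what `involution` computes: a kernel is generated per output position (typically from the input itself) and shared across the channels of each group. The function and tensor names below are illustrative only and are not part of this repository; the compiled extension added in this patch implements the same computation in C++/CUDA.

```python
import torch
import torch.nn.functional as F

def involution2d_reference(x: torch.Tensor, weight: torch.Tensor,
                           kernel_size: int = 7, groups: int = 1,
                           padding: int = 3) -> torch.Tensor:
    """Naive involution (stride 1): `weight` holds one K*K kernel per group
    and per output position, shared by all channels of that group.
    x:      (B, C, H, W)
    weight: (B, groups * K * K, H, W)
    """
    b, c, h, w = x.shape
    k = kernel_size
    # gather K*K neighbourhoods: (B, C*K*K, H*W) -> (B, groups, C//groups, K*K, H, W)
    patches = F.unfold(x, k, padding=padding).view(b, groups, c // groups, k * k, h, w)
    # position-specific kernels, broadcast over the channels of each group
    kernels = weight.view(b, groups, 1, k * k, h, w)
    # multiply-accumulate over the K*K kernel positions
    return (kernels * patches).sum(dim=3).view(b, c, h, w)
```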
diff --git a/docker/run-docker.sh b/docker/run-docker.sh
new file mode 100755
index 0000000..e30c835
--- /dev/null
+++ b/docker/run-docker.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+RUN_DIR=$(dirname "$(readlink -f "$0")")
+
+DOCKER_VOLUME="${DOCKER_VOLUME} -v $(dirname ${RUN_DIR}):/workspace/involution:rw"
+
+docker run \
+ -it \
+ --rm \
+ --gpus '"device=0"' \
+ ${DOCKER_VOLUME} \
+ --name Involution-PyTorch \
+ pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel bash
+ # pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel bash
+ # pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel bash
+ # pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash
+ # nvcr.io/nvidia/pytorch:21.05-py3
+ # nvcr.io/nvidia/pytorch:20.08-py3
diff --git a/include/involution2d_cpu.h b/include/involution2d_cpu.h
new file mode 100644
index 0000000..a1e8851
--- /dev/null
+++ b/include/involution2d_cpu.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace involution {
+namespace cpu {
+
+at::Tensor involution2d_forward(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+at::Tensor involution2d_backward_grad_input(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+at::Tensor involution2d_backward_grad_weight(
+ const at::Tensor& grad,
+ const at::Tensor& input,
+ const std::vector<int64_t>& weight_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+std::vector<at::Tensor> involution2d_backward(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const at::Tensor& input,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+} // namespace cpu
+} // namespace involution
diff --git a/include/involution2d_cuda.cuh b/include/involution2d_cuda.cuh
new file mode 100644
index 0000000..dfc9425
--- /dev/null
+++ b/include/involution2d_cuda.cuh
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+namespace involution {
+namespace cuda {
+
+#define CUDA_MAX_THREADS_PER_BLOCK 1024u
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
+
+at::Tensor involution2d_forward(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+at::Tensor involution2d_backward_grad_input(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+at::Tensor involution2d_backward_grad_weight(
+ const at::Tensor& grad,
+ const at::Tensor& input,
+ const std::vector<int64_t>& weight_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+std::vector<at::Tensor> involution2d_backward(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const at::Tensor& input,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+);
+
+} // namespace cuda
+} // namespace involution
diff --git a/include/involution2d_wrapper.h b/include/involution2d_wrapper.h
new file mode 100644
index 0000000..12f4295
--- /dev/null
+++ b/include/involution2d_wrapper.h
@@ -0,0 +1,234 @@
+#pragma once
+
+#include <torch/extension.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/autocast_mode.h>
+
+#include "involution2d_cpu.h"
+
+#ifdef USE_CUDA
+# include "involution2d_cuda.cuh"
+#endif
+
+namespace involution {
+
+at::Tensor involution2d(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ static auto op = at::Dispatcher::singleton()
+ .findSchemaOrThrow("involution::involution2d", "")
+ .typed<decltype(involution2d)>();
+
+ return op.call(input, weight, kernel_size, stride, padding, dilation, groups);
+}
+
+at::Tensor involution2d_autocast(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+ auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
+ return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), kernel_size, stride, padding, dilation, groups)
+ .to(input.scalar_type());
+}
+
+at::Tensor _involution2d_backward_grad_input(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ static auto op = at::Dispatcher::singleton()
+ .findSchemaOrThrow("involution::_involution2d_backward_grad_input", "")
+ .typed<decltype(_involution2d_backward_grad_input)>();
+
+ return op.call(grad, weight, input_shape, kernel_size, stride, padding, dilation, groups);
+}
+
+at::Tensor _involution2d_backward_grad_weight(
+ const at::Tensor& grad,
+ const at::Tensor& input,
+ const std::vector<int64_t>& weight_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ static auto op = at::Dispatcher::singleton()
+ .findSchemaOrThrow("involution::_involution2d_backward_grad_weight", "")
+ .typed<decltype(_involution2d_backward_grad_weight)>();
+
+ return op.call(grad, input, weight_shape, kernel_size, stride, padding, dilation, groups);
+}
+
+namespace cpu {
+
+class Involution2dFunctionCPU : public torch::autograd::Function<Involution2dFunctionCPU>
+{
+ public:
+
+ static torch::autograd::variable_list forward(
+ torch::autograd::AutogradContext* ctx,
+ const torch::autograd::Variable& input,
+ const torch::autograd::Variable& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+ ) {
+ ctx->saved_data["kernel_size"] = kernel_size;
+ ctx->saved_data["stride"] = stride;
+ ctx->saved_data["padding"] = padding;
+ ctx->saved_data["dilation"] = dilation;
+ ctx->saved_data["groups"] = groups;
+ ctx->save_for_backward({input, weight});
+
+ auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);
+
+ return {output};
+ }
+
+ static torch::autograd::variable_list backward(
+ torch::autograd::AutogradContext* ctx,
+ const torch::autograd::variable_list grad_output
+ ) {
+ torch::autograd::variable_list saved = ctx->get_saved_variables();
+ torch::autograd::Variable input = saved[0];
+ torch::autograd::Variable weight = saved[1];
+
+ auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
+ auto stride = ctx->saved_data["stride"].toIntVector();
+ auto padding = ctx->saved_data["padding"].toIntVector();
+ auto dilation = ctx->saved_data["dilation"].toIntVector();
+ auto groups = ctx->saved_data["groups"].toInt();
+
+ auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);
+
+ return {
+ grads[0],
+ grads[1],
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable()
+ };
+ }
+};
+
+at::Tensor involution2d_autograd(
+ const torch::autograd::Variable& input,
+ const torch::autograd::Variable& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
+}
+
+} // namespace cpu
+
+#ifdef USE_CUDA
+
+namespace cuda {
+
+class Involution2dFunctionCUDA : public torch::autograd::Function<Involution2dFunctionCUDA>
+{
+ public:
+
+ static torch::autograd::variable_list forward(
+ torch::autograd::AutogradContext* ctx,
+ const torch::autograd::Variable& input,
+ const torch::autograd::Variable& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+ ) {
+ ctx->saved_data["kernel_size"] = kernel_size;
+ ctx->saved_data["stride"] = stride;
+ ctx->saved_data["padding"] = padding;
+ ctx->saved_data["dilation"] = dilation;
+ ctx->saved_data["groups"] = groups;
+ ctx->save_for_backward({input, weight});
+
+ auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);
+
+ return {output};
+ }
+
+ static torch::autograd::variable_list backward(
+ torch::autograd::AutogradContext* ctx,
+ const torch::autograd::variable_list grad_output
+ ) {
+ torch::autograd::variable_list saved = ctx->get_saved_variables();
+ torch::autograd::Variable input = saved[0];
+ torch::autograd::Variable weight = saved[1];
+
+ auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
+ auto stride = ctx->saved_data["stride"].toIntVector();
+ auto padding = ctx->saved_data["padding"].toIntVector();
+ auto dilation = ctx->saved_data["dilation"].toIntVector();
+ auto groups = ctx->saved_data["groups"].toInt();
+
+ auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);
+
+ return {
+ grads[0],
+ grads[1],
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable(),
+ torch::autograd::Variable()
+ };
+ }
+};
+
+at::Tensor involution2d_autograd(
+ const torch::autograd::Variable& input,
+ const torch::autograd::Variable& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
+}
+
+at::Tensor involution2d_autocast(
+ const torch::autograd::Variable& input,
+ const torch::autograd::Variable& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+ auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
+ return involution2d_autograd(
+ at::autocast::cached_cast(exec_type, input),
+ at::autocast::cached_cast(exec_type, weight),
+ kernel_size, stride, padding, dilation, groups
+ );
+}
+
+} // namespace cuda
+
+#endif
+
+} // namespace involution
diff --git a/involution/__init__.py b/involution/__init__.py
new file mode 100644
index 0000000..7839be6
--- /dev/null
+++ b/involution/__init__.py
@@ -0,0 +1,9 @@
+from glob import glob
+import os
+
+from torch import ops
+
+_LIB_PATH = glob(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'involution.*.so'))[0]
+ops.load_library(_LIB_PATH)
+
+from .involution2d import Involution2d
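Importing the package runs `ops.load_library`, so the raw operator registered in `src/pytorch_wrapper.cpp` becomes reachable through `torch.ops`. A minimal sketch of a direct call (shapes chosen purely for illustration):

```python
import torch
import involution  # loads the compiled involution.*.so on import

x = torch.randn(1, 8, 16, 16)
w = torch.randn(1, 1 * 3 * 3, 16, 16)  # one 3x3 kernel per position, groups=1

# schema: involution2d(Tensor input, Tensor weight, int[] kernel_size,
#                      int[] stride, int[] padding, int[] dilation, int groups)
y = torch.ops.involution.involution2d(x, w, [3, 3], [1, 1], [1, 1], [1, 1], 1)
print(y.shape)  # torch.Size([1, 8, 16, 16])
```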
diff --git a/involution/involution2d.py b/involution/involution2d.py
new file mode 100644
index 0000000..bdf41ec
--- /dev/null
+++ b/involution/involution2d.py
@@ -0,0 +1,123 @@
+from typing import Optional, Tuple, Union
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair
+from torch import ops
+
+def _involution2d(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ kernel_size: Union[int, Tuple[int, int]] = 7,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ kernel_size_ = _pair(kernel_size)
+ stride_ = _pair(stride)
+ padding_ = _pair(padding)
+ dilation_ = _pair(dilation)
+
+ output: torch.Tensor = ops.involution.involution2d(input, weight, kernel_size_, stride_, padding_, dilation_, groups)
+
+ if bias is not None:
+ output += bias.view(1, -1, 1, 1)
+
+ return output
+
+class Involution2d(nn.Module):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 7,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 3,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = False,
+ sigma_mapping: Optional[nn.Module] = None,
+ reduce_ratio: int = 1,
+ ) -> None:
+ """2D Involution: https://arxiv.org/pdf/2103.06255.pdf
+ Args:
+ in_channels (int): Number of input channels
+ out_channels (int): Number of output channels
+ kernel_size (Union[int, Tuple[int, int]], optional): Kernel size to be used. Defaults to 7.
+ stride (Union[int, Tuple[int, int]], optional): Stride factor to be utilized. Defaults to 1.
+ padding (Union[int, Tuple[int, int]], optional): Padding to be used in unfold operation. Defaults to 3.
+ dilation (Union[int, Tuple[int, int]], optional): Dilation in unfold to be employed. Defaults to 1.
+ groups (int, optional): Number of groups to be employed. Defaults to 1.
+ bias (bool, optional): If true bias is utilized in each convolution layer. Defaults to False.
+ sigma_mapping (Optional[nn.Module], optional): Non-linear mapping as introduced in the paper. If None, BatchNorm2d + ReLU is utilized. Defaults to None.
+ reduce_ratio (int, optional): Reduction ratio of the involution channels. Defaults to 1.
+ """
+ super(Involution2d, self).__init__()
+
+ assert isinstance(in_channels, int) and in_channels > 0, \
+ '"in_channels" must be a positive integer.'
+ assert isinstance(out_channels, int) and out_channels > 0, \
+ '"out_channels" must be a positive integer.'
+ assert isinstance(kernel_size, (int, tuple)), \
+ '"kernel_size" must be an int or a tuple of ints.'
+ assert isinstance(stride, (int, tuple)), \
+ '"stride" must be an int or a tuple of ints.'
+ assert isinstance(padding, (int, tuple)), \
+ '"padding" must be an int or a tuple of ints.'
+ assert isinstance(dilation, (int, tuple)), \
+ '"dilation" must be an int or a tuple of ints.'
+ assert isinstance(groups, int) and groups > 0, \
+ '"groups" must be a positive integer.'
+ assert in_channels % groups == 0, '"in_channels" must be divisible by "groups".'
+ assert out_channels % groups == 0, '"out_channels" must be divisible by "groups".'
+ assert isinstance(bias, bool), '"bias" must be a bool.'
+ assert isinstance(sigma_mapping, nn.Module) or sigma_mapping is None, \
+ '"sigma_mapping" muse be an int or a tuple of ints.'
+ assert isinstance(reduce_ratio, int) and reduce_ratio > 0, \
+ '"reduce_ratio" must be a positive integer.'
+
+ self.in_channels: int = in_channels
+ self.out_channels: int = out_channels
+ self.kernel_size: Tuple[int, int] = _pair(kernel_size)
+ self.stride: Tuple[int, int] = _pair(stride)
+ self.padding: Tuple[int, int] = _pair(padding)
+ self.dilation: Tuple[int, int] = _pair(dilation)
+ self.groups: int = groups
+ self.bias: bool = bias
+ self.reduce_ratio: int = reduce_ratio
+
+ self.sigma_mapping = sigma_mapping if isinstance(sigma_mapping, nn.Module) else nn.Sequential(
+ nn.BatchNorm2d(num_features=self.out_channels //
+ self.reduce_ratio, momentum=0.3),
+ nn.ReLU()
+ )
+ self.initial_mapping = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, bias=bias) \
+ if self.in_channels != self.out_channels else nn.Identity()
+ self.o_mapping = nn.AvgPool2d(
+ kernel_size=self.stride) if self.stride[0] > 1 or self.stride[1] > 1 else nn.Identity()
+ self.reduce_mapping = nn.Conv2d(
+ in_channels=self.in_channels, out_channels=self.out_channels // self.reduce_ratio, kernel_size=1, bias=bias)
+ self.span_mapping = nn.Conv2d(in_channels=self.out_channels // self.reduce_ratio,
+ out_channels=self.kernel_size[0] * self.kernel_size[1] * self.groups, kernel_size=1, bias=bias)
+
+ def __repr__(self) -> str:
+ """Method returns information about the module
+ Returns:
+ str: Info string
+ """
+ return (f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, kernel_size=({self.kernel_size[0]}, {self.kernel_size[1]}), '
+ f'stride=({self.stride[0]}, {self.stride[1]}), padding=({self.padding[0]}, {self.padding[1]}), dilation=({self.dilation[0]}, {self.dilation[1]}), '
+ f'groups={self.groups}, bias={self.bias}, reduce_ratio={self.reduce_ratio}, sigma_mapping={str(self.sigma_mapping)})'
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """Forward pass
+ Args:
+ input (torch.Tensor): Input tensor of the shape [batch size, in channels, height, width]
+ Returns:
+ torch.Tensor: Output tensor of the shape [batch size, out channels, height, width] (w/ same padding)
+ """
+ weight: torch.Tensor = self.span_mapping(self.sigma_mapping(self.reduce_mapping(self.o_mapping(input))))
+ input_init: torch.Tensor = self.initial_mapping(input)
+
+ return _involution2d(input_init, weight, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
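A usage sketch of the module above; with the default 7x7 kernel and padding 3 the spatial size is preserved, and gradients flow through the custom autograd functions registered by the extension:

```python
import torch
from involution import Involution2d

inv = Involution2d(in_channels=16, out_channels=32)  # 7x7 kernel, padding 3

x = torch.randn(4, 16, 64, 64, requires_grad=True)
y = inv(x)              # (4, 32, 64, 64)
y.mean().backward()     # exercises the hand-written backward kernels
print(y.shape, x.grad.shape)
```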
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..5207016
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,70 @@
+import os
+from os.path import abspath, dirname, join
+from setuptools import setup, find_packages
+from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
+
+INCLUDE_DIR = join(dirname(abspath(__file__)), 'include')
+EXTRA_COMPILE_ARGS = ['-O3']
+
+EXTENSION = []
+
+CC = ['52', '53', '60', '61', '62', '70', '72', '75', '80']
+
+if os.getenv('USE_OPENMP', '1') == '1':
+ EXTRA_COMPILE_ARGS.append('-fopenmp')
+
+if os.getenv('USE_CUDA', '1') == '1':
+ EXTRA_COMPILE_ARGS.append('-DUSE_CUDA')
+
+ GENERATE_CODES = []
+
+ for cc in CC:
+ GENERATE_CODES.append('--generate-code')
+ GENERATE_CODES.append(f'arch=compute_{cc},code=compute_{cc}')
+
+ EXTENSION.append(
+ CUDAExtension(
+ name='involution',
+ sources=[
+ 'src/involution2d_cpu.cpp',
+ 'src/involution2d_cuda.cu',
+ 'src/pytorch_wrapper.cpp',
+ ],
+ include_dirs=[
+ INCLUDE_DIR
+ ],
+ extra_compile_args={
+ 'cxx': EXTRA_COMPILE_ARGS,
+ 'nvcc': ['-O3'] + GENERATE_CODES,
+ }
+ )
+ )
+else:
+ EXTENSION.append(
+ CppExtension(
+ name='involution',
+ sources=[
+ 'src/involution2d_cpu.cpp',
+ 'src/pytorch_wrapper.cpp',
+ ],
+ include_dirs=[
+ INCLUDE_DIR
+ ],
+ extra_compile_args=EXTRA_COMPILE_ARGS
+ )
+ )
+
+setup(
+ name='involution-pytorch',
+ version="0.1.0",
+ url="https://github.com/shikishima-TasakiLab/Involution-PyTorch",
+ license="MIT License",
+ author="Junya Shikishima",
+ author_email="160442065@ccalumni.meijo-u.ac.jp",
+ description="PyTorch Involution",
+ packages=find_packages(),
+ ext_modules=EXTENSION,
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
diff --git a/src/involution2d_cpu.cpp b/src/involution2d_cpu.cpp
new file mode 100644
index 0000000..d9af583
--- /dev/null
+++ b/src/involution2d_cpu.cpp
@@ -0,0 +1,359 @@
+#include "involution2d_cpu.h"
+
+namespace involution {
+namespace cpu {
+
+template <typename scalar_t>
+static void involution2d_forward_frame(
+ const at::Tensor& in_data,
+ const at::Tensor& weight_data,
+ at::Tensor& out_data,
+ const at::IntArrayRef& kernel_size,
+ const at::IntArrayRef& padding,
+ const at::IntArrayRef& stride,
+ const at::IntArrayRef& dilation
+) {
+ auto num_elements = out_data.numel();
+ const auto groups = weight_data.size(1);
+ const auto channels = in_data.size(1);
+ const auto in_height = in_data.size(2);
+ const auto in_width = in_data.size(3);
+ const auto out_height = out_data.size(2);
+ const auto out_width = out_data.size(3);
+
+ auto in_data_a = in_data.accessor<scalar_t, 4>();
+ auto weight_data_a = weight_data.accessor<scalar_t, 6>();
+ auto* out_data_p = out_data.data_ptr<scalar_t>();
+
+ #pragma omp parallel for
+ for (int64_t idx = 0l; idx < num_elements; idx++) {
+ const int64_t w = idx % out_width;
+ const int64_t h = (idx / out_width) % out_height;
+ int64_t divisor = out_width * out_height;
+ const int64_t c = (idx / divisor) % channels;
+ divisor *= channels;
+ const int64_t n = idx / divisor;
+ const int64_t g = c / (channels / groups);
+
+ scalar_t value = 0;
+
+ for (int64_t kh = 0l; kh < kernel_size[0]; kh++) {
+ const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0];
+
+ if ((0l <= h_in) && (h_in < in_height)) {
+ for (int64_t kw = 0l; kw < kernel_size[1]; kw++) {
+ const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1];
+
+ if ((0l <= w_in) && (w_in < in_width)) {
+ value += weight_data_a[n][g][kh][kw][h][w] * in_data_a[n][c][h_in][w_in];
+ }
+ }
+ }
+ }
+ out_data_p[idx] = value;
+ }
+}
+
+at::Tensor involution2d_forward(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor.");
+ AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor.");
+
+ at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameType(c, {input_t, weight_t});
+
+ const auto batch_size = input.size(0);
+ const auto channels = input.size(1);
+ const auto in_height = input.size(2);
+ const auto in_width = input.size(3);
+
+ const auto weight_height = weight.size(2);
+ const auto weight_width = weight.size(3);
+
+ const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width});
+
+ const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1;
+ const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1;
+
+ at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options());
+
+ if (output.numel() == 0) {
+ return output;
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kHalf,
+ at::kBFloat16,
+ input.scalar_type(),
+ "involution2d_forward_kernel", [&] {
+ involution2d_forward_frame<scalar_t>(
+ input,
+ weight_,
+ output,
+ kernel_size,
+ padding,
+ stride,
+ dilation
+ );
+ }
+ );
+ return output;
+}
+
+template <typename scalar_t>
+static void involution2d_backward_grad_input_frame(
+ const at::Tensor& out_diff,
+ const at::Tensor& weight_data,
+ at::Tensor& in_diff,
+ const at::IntArrayRef& kernel_size,
+ const at::IntArrayRef& padding,
+ const at::IntArrayRef& stride,
+ const at::IntArrayRef& dilation
+) {
+ auto num_elements = in_diff.numel();
+ const auto groups = weight_data.size(1);
+ const auto channels = in_diff.size(1);
+ const auto in_height = in_diff.size(2);
+ const auto in_width = in_diff.size(3);
+ const auto out_height = out_diff.size(2);
+ const auto out_width = out_diff.size(3);
+
+ auto out_diff_a = out_diff.accessor<scalar_t, 4>();
+ auto weight_data_a = weight_data.accessor<scalar_t, 6>();
+ auto* in_diff_p = in_diff.data_ptr<scalar_t>();
+
+ #pragma omp parallel for
+ for (int64_t idx = 0l; idx < num_elements; idx++) {
+ const int64_t w = idx % in_width;
+ const int64_t h = (idx / in_width) % in_height;
+ int64_t divisor = in_width * in_height;
+ const int64_t c = (idx / divisor) % channels;
+ divisor *= channels;
+ const int64_t n = idx / divisor;
+ const int64_t g = c / (channels / groups);
+
+ scalar_t value = 0;
+
+ for (int64_t kh = 0l; kh < kernel_size[0]; kh++) {
+ const int64_t h_out_s = h + padding[0] - kh * dilation[0];
+
+ for (int64_t kw = 0l; kw < kernel_size[1]; kw++) {
+ const int64_t w_out_s = w + padding[1] - kw * dilation[1];
+
+ if (((h_out_s % stride[0]) == 0) && ((w_out_s % stride[1]) == 0)) {
+ const int64_t h_out = h_out_s / stride[0];
+ const int64_t w_out = w_out_s / stride[1];
+
+ if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) {
+ value += weight_data_a[n][g][kh][kw][h_out][w_out] * out_diff_a[n][c][h_out][w_out];
+ }
+ }
+ }
+ }
+ in_diff_p[idx] = value;
+ }
+}
+
+at::Tensor involution2d_backward_grad_input(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor.");
+ AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor.");
+
+ at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameType(c, {grad_t, weight_t});
+
+ const auto batch_size = input_shape[0];
+
+ const auto weight_height = weight.size(2);
+ const auto weight_width = weight.size(3);
+
+ const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width});
+
+ at::Tensor grad_input = at::zeros(input_shape, grad.options());
+
+ if (grad_input.numel() == 0) {
+ return grad_input;
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, grad.scalar_type(), "involution2d_backward_grad_input_frame", [&] {
+ involution2d_backward_grad_input_frame<scalar_t>(
+ grad,
+ weight_,
+ grad_input,
+ kernel_size,
+ padding,
+ stride,
+ dilation
+ );
+ });
+
+ return grad_input;
+}
+
+template <typename scalar_t>
+static void involution2d_backward_grad_weight_frame(
+ const at::Tensor& out_diff,
+ const at::Tensor& in_data,
+ at::Tensor& weight_diff,
+ const at::IntArrayRef& kernel_size,
+ const at::IntArrayRef& padding,
+ const at::IntArrayRef& stride,
+ const at::IntArrayRef& dilation
+) {
+ auto num_elements = weight_diff.numel();
+ const auto groups = weight_diff.size(1);
+ const auto batch_size = in_data.size(0);
+ const auto channels = in_data.size(1);
+ const auto in_height = in_data.size(2);
+ const auto in_width = in_data.size(3);
+ const auto out_height = out_diff.size(2);
+ const auto out_width = out_diff.size(3);
+ const auto channels_per_group = channels / groups;
+
+ auto out_diff_a = out_diff.accessor<scalar_t, 4>();
+ auto in_data_a = in_data.accessor<scalar_t, 4>();
+ auto* weight_diff_p = weight_diff.data_ptr<scalar_t>();
+
+ #pragma omp parallel for
+ for (int64_t idx = 0l; idx < num_elements; idx++) {
+ const int64_t w = idx % out_width;
+ const int64_t h = (idx / out_width) % out_height;
+ int64_t divisor = out_width * out_height;
+ const int64_t kw = (idx / divisor) % kernel_size[1];
+ divisor *= kernel_size[1];
+ const int64_t kh = (idx / divisor) % kernel_size[0];
+
+ const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0];
+ const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1];
+
+ if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) {
+ divisor *= kernel_size[0];
+ const int64_t g = (idx / divisor) % groups;
+ divisor *= groups;
+ const int64_t n = (idx / divisor) % batch_size;
+
+ scalar_t value = 0;
+
+ for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) {
+ value += out_diff_a[n][c][h][w] * in_data_a[n][c][h_in][w_in];
+ }
+ weight_diff_p[idx] = value;
+ }
+ else {
+ weight_diff_p[idx] = 0;
+ }
+ }
+}
+
+at::Tensor involution2d_backward_grad_weight(
+ const at::Tensor& grad,
+ const at::Tensor& input,
+ const std::vector<int64_t>& weight_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor.");
+ AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor.");
+
+ at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameType(c, {grad_t, input_t});
+
+ const auto batch_size = input.size(0);
+
+ at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options());
+
+ if (grad_weight.numel() == 0) {
+ return grad_weight.view(weight_shape);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kHalf,
+ at::kBFloat16,
+ grad.scalar_type(),
+ "involution2d_backward_grad_weight_kernel", [&] {
+ involution2d_backward_grad_weight_frame<scalar_t>(
+ grad,
+ input,
+ grad_weight,
+ kernel_size,
+ padding,
+ stride,
+ dilation
+ );
+ }
+ );
+ return grad_weight.view(weight_shape);
+}
+
+std::vector<at::Tensor> involution2d_backward(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const at::Tensor& input,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ auto input_sizes = input.sizes();
+ std::vector<int64_t> input_size;
+ std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size));
+
+ auto grad_input = involution2d_backward_grad_input(
+ grad,
+ weight,
+ input_size,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+
+ auto weight_sizes = weight.sizes();
+ std::vector<int64_t> weight_size;
+ std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size));
+
+ auto grad_weight = involution2d_backward_grad_weight(
+ grad,
+ input,
+ weight_size,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+
+ // std::vector output{grad_input, grad_weight};
+
+ // return output;
+ return {grad_input, grad_weight};
+}
+
+} // namespace cpu
+} // namespace involution
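Because both backward passes above are written by hand, it is worth checking them numerically. A sketch using `torch.autograd.gradcheck` against the CPU kernels (double precision and tiny shapes keep the finite-difference check fast); the shapes are illustrative only:

```python
import torch
from torch.autograd import gradcheck
import involution  # loads the compiled extension

def op(x, w):
    return torch.ops.involution.involution2d(x, w, [3, 3], [1, 1], [1, 1], [1, 1], 1)

x = torch.randn(1, 2, 5, 5, dtype=torch.float64, requires_grad=True)
w = torch.randn(1, 9, 5, 5, dtype=torch.float64, requires_grad=True)

# compares the analytic grad_input / grad_weight kernels with numerical gradients
print(gradcheck(op, (x, w), eps=1e-6, atol=1e-4))
```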
diff --git a/src/involution2d_cuda.cu b/src/involution2d_cuda.cu
new file mode 100644
index 0000000..925e6a5
--- /dev/null
+++ b/src/involution2d_cuda.cu
@@ -0,0 +1,407 @@
+#include "involution2d_cuda.cuh"
+
+namespace involution {
+namespace cuda {
+
+static u_int32_t ceildiv(u_int32_t num_elements, u_int32_t threads) {
+ return (num_elements + threads - 1) / threads;
+}
+
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK)
+__global__ static void involution2d_forward_kernel(
+ const at::GenericPackedTensorAccessor<scalar_t, 4, at::RestrictPtrTraits, int64_t> in_data,
+ const at::GenericPackedTensorAccessor<scalar_t, 6, at::RestrictPtrTraits, int64_t> weight_data,
+ scalar_t* const out_data,
+ const int64_t num_elements,
+ const int64_t channels,
+ const int64_t groups,
+ const int64_t in_height, const int64_t in_width,
+ const int64_t out_height, const int64_t out_width,
+ const int64_t kernel_height, const int64_t kernel_width,
+ const int64_t pad_h, const int64_t pad_w,
+ const int64_t stride_h, const int64_t stride_w,
+ const int64_t dilation_h, const int64_t dilation_w
+) {
+ CUDA_KERNEL_LOOP(idx, num_elements) {
+ const int64_t w = idx % out_width;
+ const int64_t h = (idx / out_width) % out_height;
+ int64_t divisor = out_width * out_height;
+ const int64_t c = (idx / divisor) % channels;
+ divisor *= channels;
+ const int64_t n = idx / divisor;
+ const int64_t g = c / (channels / groups);
+
+ scalar_t value = 0;
+
+ for (int64_t kh = 0l; kh < kernel_height; kh++) {
+ const int64_t h_in = h * stride_h + kh * dilation_h - pad_h;
+
+ if ((0l <= h_in) && (h_in < in_height)) {
+ for (int64_t kw = 0l; kw < kernel_width; kw++) {
+ const int64_t w_in = w * stride_w + kw * dilation_w - pad_w;
+
+ if ((0l <= w_in) && (w_in < in_width)) {
+ value += weight_data[n][g][kh][kw][h][w] * in_data[n][c][h_in][w_in];
+ }
+ }
+ }
+ }
+
+ out_data[idx] = value;
+ }
+}
+
+at::Tensor involution2d_forward(
+ const at::Tensor& input,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor.");
+ AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor.");
+
+ at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameGPU(c, {input_t, weight_t});
+ at::checkAllSameType(c, {input_t, weight_t});
+
+ at::cuda::CUDAGuard device_guard(input.device());
+
+ const auto batch_size = input.size(0);
+ const auto channels = input.size(1);
+ const auto in_height = input.size(2);
+ const auto in_width = input.size(3);
+
+ const auto weight_height = weight.size(2);
+ const auto weight_width = weight.size(3);
+
+ const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width});
+
+ const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1;
+ const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1;
+
+ at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options());
+ const auto num_elements = output.numel();
+
+ if (num_elements == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return output;
+ }
+
+ const auto threads = std::min(static_cast<u_int32_t>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK);
+ const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u);
+ const dim3 threads_per_block(threads, 1u, 1u);
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kHalf,
+ at::kBFloat16,
+ input.scalar_type(),
+ "involution2d_forward_kernel", [&] {
+ involution2d_forward_kernel<scalar_t><<<num_blocks, threads_per_block, 0, stream>>>(
+ input.generic_packed_accessor<scalar_t, 4, at::RestrictPtrTraits, int64_t>(),
+ weight_.generic_packed_accessor<scalar_t, 6, at::RestrictPtrTraits, int64_t>(),
+ output.data_ptr<scalar_t>(),
+ num_elements,
+ channels,
+ groups,
+ in_height, in_width,
+ out_height, out_width,
+ kernel_size[0], kernel_size[1],
+ padding[0], padding[1],
+ stride[0], stride[1],
+ dilation[0], dilation[1]
+ );
+ }
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+ return output;
+}
+
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK)
+__global__ static void involution2d_backward_grad_input_kernel(
+ const at::GenericPackedTensorAccessor<scalar_t, 4, at::RestrictPtrTraits, int64_t> out_diff,
+ const at::GenericPackedTensorAccessor<scalar_t, 6, at::RestrictPtrTraits, int64_t> weight_data,
+ scalar_t* const in_diff,
+ const int64_t num_elements,
+ const int64_t channels,
+ const int64_t groups,
+ const int64_t in_height, const int64_t in_width,
+ const int64_t out_height, const int64_t out_width,
+ const int64_t kernel_height, const int64_t kernel_width,
+ const int64_t pad_h, const int64_t pad_w,
+ const int64_t stride_h, const int64_t stride_w,
+ const int64_t dilation_h, const int64_t dilation_w
+) {
+ CUDA_KERNEL_LOOP(idx, num_elements) {
+ const int64_t w = idx % in_width;
+ const int64_t h = (idx / in_width) % in_height;
+ int64_t divisor = in_width * in_height;
+ const int64_t c = (idx / divisor) % channels;
+ divisor *= channels;
+ const int64_t n = idx / divisor;
+ const int64_t g = c / (channels / groups);
+
+ scalar_t value = 0;
+
+ for (int64_t kh = 0l; kh < kernel_height; kh++) {
+ const int64_t h_out_s = h + pad_h - kh * dilation_h;
+
+ for (int64_t kw = 0l; kw < kernel_width; kw++) {
+ const int64_t w_out_s = w + pad_w - kw * dilation_w;
+
+ if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) {
+ const int64_t h_out = h_out_s / stride_h;
+ const int64_t w_out = w_out_s / stride_w;
+
+ if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) {
+ value += weight_data[n][g][kh][kw][h_out][w_out] * out_diff[n][c][h_out][w_out];
+ }
+ }
+ }
+ }
+ in_diff[idx] = value;
+ }
+}
+
+at::Tensor involution2d_backward_grad_input(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor.");
+ AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor.");
+
+ at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameGPU(c, {grad_t, weight_t});
+ at::checkAllSameType(c, {grad_t, weight_t});
+
+ at::cuda::CUDAGuard device_guard(grad.device());
+
+ const auto batch_size = input_shape[0];
+ const auto channels = input_shape[1];
+ const auto in_height = input_shape[2];
+ const auto in_width = input_shape[3];
+
+ const auto weight_height = weight.size(2);
+ const auto weight_width = weight.size(3);
+
+ const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width});
+
+ const auto out_height = grad.size(2);
+ const auto out_width = grad.size(3);
+
+ at::Tensor grad_input = at::zeros(input_shape, grad.options());
+ const auto num_elements = grad_input.numel();
+
+ if (num_elements == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_input;
+ }
+
+ const auto threads = std::min(static_cast<u_int32_t>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK);
+ const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u);
+ const dim3 threads_per_block(threads, 1u, 1u);
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kHalf,
+ at::kBFloat16,
+ grad.scalar_type(),
+ "involution2d_backward_grad_input_kernel", [&] {
+ involution2d_backward_grad_input_kernel<scalar_t><<<num_blocks, threads_per_block, 0, stream>>>(
+ grad.generic_packed_accessor<scalar_t, 4, at::RestrictPtrTraits, int64_t>(),
+ weight_.generic_packed_accessor<scalar_t, 6, at::RestrictPtrTraits, int64_t>(),
+ grad_input.data_ptr<scalar_t>(),
+ num_elements,
+ channels,
+ groups,
+ in_height, in_width,
+ out_height, out_width,
+ kernel_size[0], kernel_size[1],
+ padding[0], padding[1],
+ stride[0], stride[1],
+ dilation[0], dilation[1]
+ );
+ }
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_input;
+}
+
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK)
+__global__ static void involution2d_backward_grad_weight_kernel(
+ const at::GenericPackedTensorAccessor<scalar_t, 4, at::RestrictPtrTraits, int64_t> out_diff,
+ const at::GenericPackedTensorAccessor<scalar_t, 4, at::RestrictPtrTraits, int64_t> in_data,
+ scalar_t* const weight_diff,
+ const int64_t num_elements,
+ const int64_t batch_size,
+ const int64_t channels_per_group,
+ const int64_t groups,
+ const int64_t in_height, const int64_t in_width,
+ const int64_t out_height, const int64_t out_width,
+ const int64_t kernel_height, const int64_t kernel_width,
+ const int64_t pad_h, const int64_t pad_w,
+ const int64_t stride_h, const int64_t stride_w,
+ const int64_t dilation_h, const int64_t dilation_w
+) {
+ CUDA_KERNEL_LOOP(idx, num_elements) {
+ const int64_t w = idx % out_width;
+ const int64_t h = (idx / out_width) % out_height;
+ int64_t divisor = out_width * out_height;
+ const int64_t kw = (idx / divisor) % kernel_width;
+ divisor *= kernel_width;
+ const int64_t kh = (idx / divisor) % kernel_height;
+
+ const int64_t h_in = -pad_h + h * stride_h + kh * dilation_h;
+ const int64_t w_in = -pad_w + w * stride_w + kw * dilation_w;
+
+ if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) {
+ divisor *= kernel_height;
+ const int64_t g = (idx / divisor) % groups;
+ divisor *= groups;
+ const int64_t n = (idx / divisor) % batch_size;
+
+ scalar_t value = 0;
+
+ for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) {
+ value += out_diff[n][c][h][w] * in_data[n][c][h_in][w_in];
+ }
+ weight_diff[idx] = value;
+ }
+ else {
+ weight_diff[idx] = 0;
+ }
+ }
+}
+
+at::Tensor involution2d_backward_grad_weight(
+ const at::Tensor& grad,
+ const at::Tensor& input,
+ const std::vector<int64_t>& weight_shape,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor.");
+ AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor.");
+
+ at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2};
+
+ at::CheckedFrom c = __func__;
+ at::checkAllSameGPU(c, {grad_t, input_t});
+ at::checkAllSameType(c, {grad_t, input_t});
+
+ at::cuda::CUDAGuard device_guard(grad.device());
+
+ const auto batch_size = input.size(0);
+ const auto channels = input.size(1);
+ const auto in_height = input.size(2);
+ const auto in_width = input.size(3);
+
+ const auto out_height = grad.size(2);
+ const auto out_width = grad.size(3);
+
+ at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options());
+ const auto num_elements = grad_weight.numel();
+
+ if (num_elements == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_weight.view(weight_shape);
+ }
+
+ const auto threads = std::min(static_cast<u_int32_t>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK);
+ const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u);
+ const dim3 threads_per_block(threads, 1u, 1u);
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kHalf,
+ at::kBFloat16,
+ grad.scalar_type(),
+ "involution2d_backward_grad_weight_kernel", [&] {
+ involution2d_backward_grad_weight_kernel<scalar_t><<<num_blocks, threads_per_block, 0, stream>>>(
+ grad.generic_packed_accessor<scalar_t, 4, at::RestrictPtrTraits, int64_t>(),
+ input.generic_packed_accessor<scalar_t, 4, at::RestrictPtrTraits, int64_t>(),
+ grad_weight.data_ptr<scalar_t>(),
+ num_elements,
+ batch_size,
+ channels / groups,
+ groups,
+ in_height, in_width,
+ out_height, out_width,
+ kernel_size[0], kernel_size[1],
+ padding[0], padding[1],
+ stride[0], stride[1],
+ dilation[0], dilation[1]
+ );
+ }
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_weight.view(weight_shape);
+}
+
+std::vector<at::Tensor> involution2d_backward(
+ const at::Tensor& grad,
+ const at::Tensor& weight,
+ const at::Tensor& input,
+ const std::vector<int64_t>& kernel_size,
+ const std::vector<int64_t>& stride,
+ const std::vector<int64_t>& padding,
+ const std::vector<int64_t>& dilation,
+ const int64_t groups
+) {
+ auto input_sizes = input.sizes();
+ std::vector<int64_t> input_size;
+ std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size));
+
+ auto grad_input = involution2d_backward_grad_input(
+ grad,
+ weight,
+ input_size,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+
+ auto weight_sizes = weight.sizes();
+ std::vector<int64_t> weight_size;
+ std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size));
+
+ auto grad_weight = involution2d_backward_grad_weight(
+ grad,
+ input,
+ weight_size,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+
+ return {grad_input, grad_weight};
+}
+
+} // namespace cuda
+} // namespace involution
diff --git a/src/pytorch_wrapper.cpp b/src/pytorch_wrapper.cpp
new file mode 100644
index 0000000..af492c6
--- /dev/null
+++ b/src/pytorch_wrapper.cpp
@@ -0,0 +1,39 @@
+#include <torch/library.h>
+#include "involution2d_wrapper.h"
+
+TORCH_LIBRARY(involution, m) {
+ m.def("involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor");
+ m.def("_involution2d_backward_grad_input(Tensor grad, Tensor weight, int[] input_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor");
+ m.def("_involution2d_backward_grad_weight(Tensor grad, Tensor input, int[] weight_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor");
+ m.def("_involution2d_backward(Tensor grad, Tensor weight, Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor[]");
+}
+
+TORCH_LIBRARY_IMPL(involution, CPU, m) {
+ m.impl("involution2d", involution::cpu::involution2d_forward);
+ m.impl("_involution2d_backward_grad_input", involution::cpu::involution2d_backward_grad_input);
+ m.impl("_involution2d_backward_grad_weight", involution::cpu::involution2d_backward_grad_weight);
+ m.impl("_involution2d_backward", involution::cpu::involution2d_backward);
+}
+
+#ifdef USE_CUDA
+TORCH_LIBRARY_IMPL(involution, CUDA, m) {
+ m.impl("involution2d", involution::cuda::involution2d_forward);
+ m.impl("_involution2d_backward_grad_input", involution::cuda::involution2d_backward_grad_input);
+ m.impl("_involution2d_backward_grad_weight", involution::cuda::involution2d_backward_grad_weight);
+ m.impl("_involution2d_backward", involution::cuda::involution2d_backward);
+}
+#endif
+
+TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) {
+ m.impl("involution2d", involution::cpu::involution2d_autograd);
+}
+
+#ifdef USE_CUDA
+TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) {
+ m.impl("involution2d", involution::cuda::involution2d_autograd);
+}
+
+TORCH_LIBRARY_IMPL(involution, Autocast, m) {
+ m.impl("involution2d", involution::cuda::involution2d_autocast);
+}
+#endif
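With both backends registered, a quick parity check between the CPU and CUDA kernels is straightforward. A sketch, assuming the extension was built with `USE_CUDA=1` and a GPU is available:

```python
import torch
import involution  # loads the compiled extension

torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
w = torch.randn(2, 2 * 3 * 3, 16, 16)  # groups=2, 3x3 kernels

def run(device):
    return torch.ops.involution.involution2d(
        x.to(device), w.to(device), [3, 3], [1, 1], [1, 1], [1, 1], 2)

if torch.cuda.is_available():
    # the CPU and CUDA implementations should agree to floating-point tolerance
    print(torch.allclose(run('cpu'), run('cuda').cpu(), atol=1e-5))
```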