Updated conv function to use a NumPy-based implementation #21051

Open · wants to merge 6 commits into base: master
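
For reviewers, a quick way to exercise the new path (a minimal sketch, not part of the PR: the shapes are illustrative, and it assumes the NumPy backend is selected before Keras is imported):

# Minimal sketch: route keras.ops.conv through the new NumPy implementation.
import os

os.environ["KERAS_BACKEND"] = "numpy"  # must be set before importing keras

import numpy as np

from keras import ops

x = np.random.rand(2, 16, 3).astype("float32")  # (batch, steps, ch_in)
w = np.random.rand(5, 3, 8).astype("float32")  # (kernel, ch_in, ch_out)
y = ops.conv(x, w, strides=1, padding="same")
print(y.shape)  # (2, 16, 8): "same" padding with stride 1 preserves length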
300 changes: 288 additions & 12 deletions keras/src/backend/numpy/nn.py
@@ -1,6 +1,7 @@
import jax
import numpy as np
from jax import lax
from numpy.lib.stride_tricks import as_strided

from keras.src import backend
from keras.src.backend.common.backend_utils import (
@@ -355,6 +356,260 @@ def _convert_to_lax_conv_dimension_numbers(
)


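# SAME padding helper: the total pad is chosen so that the output has
# ceil(input_size / stride) positions, split left/right with the extra
# element on the right (TF/Keras convention). For example,
# _same_padding(5, 3, 2) == (1, 1), giving (5 + 2 - 3) // 2 + 1 == 3 outputs.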
def _same_padding(input_size, kernel_size, stride):
if input_size % stride == 0:
padding = max(kernel_size - stride, 0)
else:
padding = max(kernel_size - (input_size % stride), 0)
return padding // 2, padding - padding // 2


def np_conv1d(
x,
kernel_weights,
strides,
padding,
data_format,
dilation_rate,
groups,
):
if data_format == "channels_first":
x = x.swapaxes(1, 2)
h_stride = strides[0] if isinstance(strides, (tuple, list)) else strides
dilation_rate = (
dilation_rate[0]
if isinstance(dilation_rate, (tuple, list))
else dilation_rate
)
kernel_size, ch_in, ch_out = kernel_weights.shape

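    # A dilated kernel is materialized by inserting (dilation_rate - 1) zeros
    # between taps, so the main loop only has to handle plain convolution.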
if dilation_rate > 1:
dilated_size = kernel_size + (dilation_rate - 1) * (kernel_size - 1)
new_kernel = np.zeros(
(dilated_size, ch_in, ch_out), dtype=kernel_weights.dtype
)
new_kernel[::dilation_rate] = kernel_weights
kernel_weights, kernel_size = new_kernel, dilated_size

if padding != "valid":
n_batch, h_x, _ = x.shape
h_pad = _same_padding(h_x, kernel_size, h_stride)
npad = [(0, 0)] * x.ndim
npad[1] = (h_pad[0] + h_pad[1], 0) if padding == "causal" else h_pad
x = np.pad(x, pad_width=npad, mode="constant", constant_values=0)

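    # Valid-convolution output length on the (possibly padded) input:
    # h_out = floor((h_x - kernel_size) / h_stride) + 1.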
n_batch, h_x, _ = x.shape
h_out = (h_x - kernel_size) // h_stride + 1
kernel_weights = kernel_weights.reshape(-1, ch_out)

out_grps = []
ch_out_per_grp = ch_out // groups
for grp in range(groups):
x_in = x[..., grp * ch_in : (grp + 1) * ch_in]
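        # Zero-copy sliding-window view over x_in: axis 1 steps between
        # output positions (h_stride elements apart), axis 2 walks the
        # elements inside each window. No data is copied here.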
stride_shape = (n_batch, h_out, kernel_size, ch_in)
        view_strides = (
x_in.strides[0],
h_stride * x_in.strides[1],
x_in.strides[1],
x_in.strides[2],
)

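        # Flatten each window into a row so the whole convolution becomes a
        # single matmul against the (kernel_size * ch_in, ch_out) weights.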
inner_dim = kernel_size * ch_in
        x_strided = as_strided(
            x_in, shape=stride_shape, strides=view_strides, writeable=False
        ).reshape(-1, inner_dim)

kernel_weights_grp = kernel_weights[
..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp
]
result = x_strided @ kernel_weights_grp
out_grps.append(result.reshape(n_batch, h_out, -1))
out = np.concatenate(out_grps, axis=-1)
if data_format == "channels_first":
out = out.swapaxes(1, 2)
return out


def np_conv2d(
x,
kernel_weights,
strides,
padding,
data_format,
dilation_rate,
groups,
):
if data_format == "channels_first":
x = x.transpose((0, 2, 3, 1))

h_stride, w_stride = (
(strides, strides) if isinstance(strides, int) else strides
)
h_dilation, w_dilation = (
(dilation_rate, dilation_rate)
if isinstance(dilation_rate, int)
else dilation_rate
)
h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape

if h_dilation > 1 or w_dilation > 1:
new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)
new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)
new_kernel_size = (new_h_kernel, new_w_kernel)

new_kernel_weights = np.zeros(
(*new_kernel_size, ch_in, ch_out), dtype=kernel_weights.dtype
)
new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights
kernel_weights = new_kernel_weights
h_kernel, w_kernel = kernel_weights.shape[:2]

if padding == "same":
n_batch, h_x, w_x, _ = x.shape
h_pad = _same_padding(h_x, h_kernel, h_stride)
w_pad = _same_padding(w_x, w_kernel, w_stride)
npad = [(0, 0)] * x.ndim
npad[1], npad[2] = h_pad, w_pad
x = np.pad(x, pad_width=npad, mode="constant", constant_values=0)

n_batch, h_x, w_x, _ = x.shape
h_out = (h_x - h_kernel) // h_stride + 1
w_out = (w_x - w_kernel) // w_stride + 1

out_grps = []
ch_out_groups = ch_out // groups
for grp in range(1, groups + 1):
x_in = x[..., (grp - 1) * ch_in : grp * ch_in]

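        # Same zero-copy windowing as np_conv1d, extended to two spatial
        # axes: the first three view axes index output positions, the last
        # three walk each (h_kernel, w_kernel, ch_in) patch.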
stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in)
        view_strides = (
x_in.strides[0],
h_stride * x_in.strides[1],
w_stride * x_in.strides[2],
x_in.strides[1],
x_in.strides[2],
x_in.strides[3],
)

inner_dim = h_kernel * w_kernel * ch_in
        x_strided = as_strided(
            x_in, shape=stride_shape, strides=view_strides, writeable=False
        ).reshape(-1, inner_dim)

kernel_weights_grp = kernel_weights[
..., (grp - 1) * ch_out_groups : grp * ch_out_groups
].reshape(-1, ch_out_groups)
out_grps.append(
(x_strided @ kernel_weights_grp).reshape(n_batch, h_out, w_out, -1)
)
out = np.concatenate(out_grps, axis=-1)

if data_format == "channels_first":
out = out.transpose((0, 3, 1, 2))

return out


def np_conv3d(
x,
kernel_weights,
strides,
padding,
data_format,
dilation_rate,
groups,
):
if data_format == "channels_first":
x = x.transpose((0, 2, 3, 4, 1))
if isinstance(strides, (tuple, list)):
h_stride, w_stride, d_stride = strides
else:
h_stride = strides
w_stride = strides
d_stride = strides
if isinstance(dilation_rate, (tuple, list)):
h_dilation, w_dilation, d_dilation = dilation_rate
else:
h_dilation = dilation_rate
w_dilation = dilation_rate
d_dilation = dilation_rate

h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape

if h_dilation > 1 or w_dilation > 1 or d_dilation > 1:
new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)
new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)
new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1)
new_kernel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel)
new_kernel_weights = np.zeros(
(*new_kernel_size_tuple, ch_in, ch_out),
dtype=kernel_weights.dtype,
)
new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = (
kernel_weights
)
kernel_weights = new_kernel_weights
h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3]

if padding == "same":
n_batch, h_x, w_x, d_x, _ = x.shape
h_pad = _same_padding(h_x, h_kernel, h_stride)
w_pad = _same_padding(w_x, w_kernel, w_stride)
d_pad = _same_padding(d_x, d_kernel, d_stride)
npad = [(0, 0)] * x.ndim
npad[1] = h_pad
npad[2] = w_pad
npad[3] = d_pad
x = np.pad(x, pad_width=npad, mode="constant", constant_values=0)

n_batch, h_x, w_x, d_x, _ = x.shape
    h_out = (h_x - h_kernel) // h_stride + 1
    w_out = (w_x - w_kernel) // w_stride + 1
    d_out = (d_x - d_kernel) // d_stride + 1

out_grps = []
ch_out_groups = ch_out // groups
for grp in range(1, groups + 1):
x_in = x[..., (grp - 1) * ch_in : grp * ch_in]
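        # 3D analogue of the windowing above: four leading axes index
        # (batch, h_out, w_out, d_out), four trailing axes walk each
        # (h_kernel, w_kernel, d_kernel, ch_in) patch.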
stride_shape = (
n_batch,
h_out,
w_out,
d_out,
h_kernel,
w_kernel,
d_kernel,
ch_in,
)
        view_strides = (
x_in.strides[0],
h_stride * x_in.strides[1],
w_stride * x_in.strides[2],
d_stride * x_in.strides[3],
x_in.strides[1],
x_in.strides[2],
x_in.strides[3],
x_in.strides[4],
)
inner_dim = h_kernel * w_kernel * d_kernel * ch_in
        x_strided = as_strided(
            x_in, shape=stride_shape, strides=view_strides, writeable=False
        )
        x_strided = x_strided.reshape(-1, inner_dim)

kernel_weights_grp = kernel_weights[
..., (grp - 1) * ch_out_groups : grp * ch_out_groups
].reshape(-1, ch_out_groups)

result = x_strided @ kernel_weights_grp
out_grps.append(result.reshape(n_batch, h_out, w_out, d_out, -1))
out = np.concatenate(out_grps, axis=-1)

if data_format == "channels_first":
out = out.transpose((0, 4, 1, 2, 3))
return out


def conv(
inputs,
kernel,
@@ -365,11 +620,6 @@ def conv(
):
data_format = backend.standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
@@ -394,17 +644,43 @@
f"kernel in_channels {kernel_in_channels}. "
)
feature_group_count = channels // kernel_in_channels

    if num_spatial_dims == 1:
        return np_conv1d(
            inputs,
            convert_to_tensor(kernel),
            strides,
            padding,
            data_format,
            dilation_rate,
            feature_group_count,
        )
elif num_spatial_dims == 2:
return np_conv2d(
inputs,
convert_to_tensor(kernel),
strides,
padding,
data_format,
dilation_rate,
feature_group_count,
)
elif num_spatial_dims == 3:
return np_conv3d(
inputs,
convert_to_tensor(kernel),
strides,
padding,
data_format,
dilation_rate,
feature_group_count,
)
else:
        raise ValueError(
            "Inputs to conv operation should have ndim=3, 4, or 5, "
            "corresponding to 1D, 2D and 3D inputs. Received input "
            f"shape: {inputs.shape}."
        )


def depthwise_conv(