From d34fb9ff1141737c02ea521a40e80241a0cf5955 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Sun, 16 Mar 2025 22:04:31 +0900 Subject: [PATCH 1/5] Update conv method for numpy --- keras/src/backend/numpy/nn.py | 300 ++++++++++++++++++++++++++++++++-- 1 file changed, 288 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index faf3728e530a..1d95cdd864b1 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1,6 +1,8 @@ import jax import numpy as np from jax import lax +from numpy.lib._stride_tricks_impl import as_strided +from scipy.linalg._fblas import dgemm from keras.src import backend from keras.src.backend.common.backend_utils import ( @@ -355,6 +357,259 @@ def _convert_to_lax_conv_dimension_numbers( ) +def _same_padding(input_size, kernel_size, stride): + if input_size % stride == 0: + padding = max(kernel_size - stride, 0) + else: + padding = max(kernel_size - (input_size % stride), 0) + return padding // 2, padding - padding // 2 + + +def np_conv1d( + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + h_stride = strides[0] if isinstance(strides, (tuple, list)) else strides + dilation_rate = ( + dilation_rate[0] + if isinstance(dilation_rate, (tuple, list)) + else dilation_rate + ) + kernel_size, ch_in, ch_out = kernel_weights.shape + + if dilation_rate > 1: + dilated_size = kernel_size + (dilation_rate - 1) * (kernel_size - 1) + new_kernel = np.zeros( + (dilated_size, ch_in, ch_out), dtype=kernel_weights.dtype + ) + new_kernel[::dilation_rate] = kernel_weights + kernel_weights, kernel_size = new_kernel, dilated_size + + if padding != "valid": + n_batch, h_x, _ = x.shape + h_pad = _same_padding(h_x, kernel_size, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (h_pad[0] + h_pad[1], 0) if padding == "causal" else h_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, _ = x.shape + h_out = (h_x - kernel_size) // h_stride + 1 + kernel_weights = kernel_weights.reshape(-1, ch_out) + + out_grps = [] + ch_out_per_grp = ch_out // groups + for grp in range(groups): + x_in = x[..., grp * ch_in : (grp + 1) * ch_in] + x_strided = as_strided( + x_in, + shape=(n_batch, h_out, kernel_size, ch_in), + strides=( + x_in.strides[0], + h_stride * x_in.strides[1], + x_in.strides[1], + x_in.strides[2], + ), + ).reshape(n_batch * h_out, -1) + + result = dgemm( + 1.0, + x_strided, + kernel_weights[ + ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp + ], + ) + out_grps.append(result.reshape(n_batch, h_out, -1)) + + out = np.concatenate(out_grps, axis=-1) + if data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_conv2d( + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + + h_stride, w_stride = ( + (strides, strides) if isinstance(strides, int) else strides + ) + h_dilation, w_dilation = ( + (dilation_rate, dilation_rate) + if isinstance(dilation_rate, int) + else dilation_rate + ) + h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1 or w_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_kernel_size = (new_h_kernel, new_w_kernel) + + new_kernel_weights = np.zeros( + (*new_kernel_size, ch_in, ch_out), dtype=kernel_weights.dtype + 
) + new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel = kernel_weights.shape[:2] + + if padding == "same": + n_batch, h_x, w_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + npad = [(0, 0)] * x.ndim + npad[1], npad[2] = h_pad, w_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, w_x, _ = x.shape + h_out = (h_x - h_kernel) // h_stride + 1 + w_out = (w_x - w_kernel) // w_stride + 1 + + out_grps = [] + ch_out_groups = ch_out // groups + for grp in range(1, groups + 1): + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + + stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + x_in.strides[1], + x_in.strides[2], + x_in.strides[3], + ) + + inner_dim = h_kernel * w_kernel * ch_in + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + + kernel_weights_grp = kernel_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) + out_grps.append( + (x_strided @ kernel_weights_grp).reshape(n_batch, h_out, w_out, -1) + ) + out = np.concatenate(out_grps, axis=-1) + + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + + return out + + +def np_conv3d( + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride, d_stride = strides + else: + h_stride = strides + w_stride = strides + d_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation, d_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + d_dilation = dilation_rate + + h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1 or w_dilation > 1 or d_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_in, ch_out), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = ( + kernel_weights + ) + kernel_weights = new_kernel_weights + h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3] + + if padding == "same": + n_batch, h_x, w_x, d_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + d_pad = _same_padding(d_x, d_kernel, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = h_pad + npad[2] = w_pad + npad[3] = d_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, w_x, d_x, _ = x.shape + h_out = int((h_x - h_kernel) / h_stride) + 1 + w_out = int((w_x - w_kernel) / w_stride) + 1 + d_out = int((d_x - d_kernel) / d_stride) + 1 + + out_grps = [] + for grp in range(1, groups + 1): + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + stride_shape = ( + n_batch, + h_out, + w_out, + d_out, + h_kernel, + w_kernel, + d_kernel, + ch_in, + ) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + d_stride * x_in.strides[3], + x_in.strides[1], + x_in.strides[2], + x_in.strides[3], 
+ x_in.strides[4], + ) + inner_dim = h_kernel * w_kernel * d_kernel * ch_in + x_strided = as_strided(x_in, shape=stride_shape, strides=strides) + x_strided = x_strided.reshape(-1, inner_dim) + ch_out_groups = ch_out // groups + kernel_weights_grp = kernel_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) + + result = x_strided @ kernel_weights_grp + out_grps.append(result.reshape(n_batch, h_out, w_out, d_out, -1)) + out = np.concatenate(out_grps, axis=-1) + + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + def conv( inputs, kernel, @@ -365,11 +620,6 @@ def conv( ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) strides = _convert_to_spatial_operand( strides, num_spatial_dims, @@ -394,17 +644,43 @@ def conv( f"kernel in_channels {kernel_in_channels}. " ) feature_group_count = channels // kernel_in_channels - return np.array( - jax.lax.conv_general_dilated( + + if num_spatial_dims == 1: + return np_conv1d( inputs, - kernel if is_tensor(kernel) else kernel.numpy(), + convert_to_tensor(kernel), strides, padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, + data_format, + dilation_rate, + feature_group_count, + ) + elif num_spatial_dims == 2: + return np_conv2d( + inputs, + convert_to_tensor(kernel), + strides, + padding, + data_format, + dilation_rate, + feature_group_count, + ) + elif num_spatial_dims == 3: + return np_conv3d( + inputs, + convert_to_tensor(kernel), + strides, + padding, + data_format, + dilation_rate, + feature_group_count, + ) + else: + raise ValueError( + "Inputs to conv operation should have ndim=3, 4, or 5," + "corresponding to 1D, 2D and 3D inputs. Received input " + f"shape: {inputs.shape}." 
) - ) def depthwise_conv( From 908800e17e47d9df7a1ea34466f72a1f0e5262a7 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 17 Mar 2025 00:28:07 +0900 Subject: [PATCH 2/5] Update conv1d --- keras/src/backend/numpy/nn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index 1d95cdd864b1..f949603786e1 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -2,7 +2,6 @@ import numpy as np from jax import lax from numpy.lib._stride_tricks_impl import as_strided -from scipy.linalg._fblas import dgemm from keras.src import backend from keras.src.backend.common.backend_utils import ( @@ -416,14 +415,12 @@ def np_conv1d( x_in.strides[1], x_in.strides[2], ), - ).reshape(n_batch * h_out, -1) - - result = dgemm( - 1.0, - x_strided, - kernel_weights[ + ).reshape(n_batch * h_out, kernel_size * ch_in) + result = ( + x_strided + @ kernel_weights[ ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp - ], + ] ) out_grps.append(result.reshape(n_batch, h_out, -1)) From 837df653646af48e85a73b44c962bdeaedaf93be Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 17 Mar 2025 20:52:41 +0900 Subject: [PATCH 3/5] Update conv3d --- keras/src/backend/numpy/nn.py | 134 +++++++++++++++++++++++++++++----- 1 file changed, 114 insertions(+), 20 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index f949603786e1..b08b74550163 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -406,24 +406,24 @@ def np_conv1d( ch_out_per_grp = ch_out // groups for grp in range(groups): x_in = x[..., grp * ch_in : (grp + 1) * ch_in] - x_strided = as_strided( - x_in, - shape=(n_batch, h_out, kernel_size, ch_in), - strides=( - x_in.strides[0], - h_stride * x_in.strides[1], - x_in.strides[1], - x_in.strides[2], - ), - ).reshape(n_batch * h_out, kernel_size * ch_in) - result = ( - x_strided - @ kernel_weights[ - ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp - ] + stride_shape = (n_batch, h_out, kernel_size, ch_in) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + x_in.strides[1], + x_in.strides[2], ) - out_grps.append(result.reshape(n_batch, h_out, -1)) + inner_dim = kernel_size * ch_in + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + + kernel_weights_grp = kernel_weights[ + ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp + ] + result = x_strided @ kernel_weights_grp + out_grps.append(result.reshape(n_batch, h_out, -1)) out = np.concatenate(out_grps, axis=-1) if data_format == "channels_first": out = out.swapaxes(1, 2) @@ -510,6 +510,97 @@ def np_conv2d( return out +from numpy.lib.stride_tricks import as_strided + + +def optimized_np_conv3d( + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) # Convert to channels_last + + h_stride, w_stride, d_stride = ( + strides if isinstance(strides, (tuple, list)) else (strides,) * 3 + ) + h_dilation, w_dilation, d_dilation = ( + dilation_rate + if isinstance(dilation_rate, (tuple, list)) + else (dilation_rate,) * 3 + ) + + h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape + + if padding == "same": + n_batch, h_x, w_x, d_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + d_pad = _same_padding(d_x, d_kernel, d_stride) + x = np.pad( + x, + 
pad_width=[(0, 0), h_pad, w_pad, d_pad, (0, 0)], + mode="constant", + constant_values=0, + ) + + n_batch, h_x, w_x, d_x, _ = x.shape + h_out = (h_x - (h_kernel - 1) * h_dilation - 1) // h_stride + 1 + w_out = (w_x - (w_kernel - 1) * w_dilation - 1) // w_stride + 1 + d_out = (d_x - (d_kernel - 1) * d_dilation - 1) // d_stride + 1 + + # Process groups efficiently + out_grps = [] + ch_out_groups = ch_out // groups + + for grp in range(groups): + x_in = x[..., grp * ch_in : (grp + 1) * ch_in] + + # Efficient strided extraction + stride_shape = ( + n_batch, + h_out, + w_out, + d_out, + h_kernel, + w_kernel, + d_kernel, + ch_in, + ) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + d_stride * x_in.strides[3], + h_dilation * x_in.strides[1], + w_dilation * x_in.strides[2], + d_dilation * x_in.strides[3], + x_in.strides[4], + ) + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides, writeable=False + ) + + # Convolution using einsum (efficient tensor contraction) + kernel_grp = kernel_weights[ + ..., grp * ch_out_groups : (grp + 1) * ch_out_groups + ] + result = np.einsum("bihwdklc, hwdco -> bihwo", x_strided, kernel_grp) + + out_grps.append(result) + + out = np.concatenate(out_grps, axis=-1) + + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) # Convert back to channels_first + + return out + + def np_conv3d( x, kernel_weights, @@ -540,9 +631,9 @@ def np_conv3d( new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) - new_kenel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) + new_kernel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) new_kernel_weights = np.zeros( - (*new_kenel_size_tuple, ch_in, ch_out), + (*new_kernel_size_tuple, ch_in, ch_out), dtype=kernel_weights.dtype, ) new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = ( @@ -568,6 +659,7 @@ def np_conv3d( d_out = int((d_x - d_kernel) / d_stride) + 1 out_grps = [] + ch_out_groups = ch_out // groups for grp in range(1, groups + 1): x_in = x[..., (grp - 1) * ch_in : grp * ch_in] stride_shape = ( @@ -591,9 +683,11 @@ def np_conv3d( x_in.strides[4], ) inner_dim = h_kernel * w_kernel * d_kernel * ch_in - x_strided = as_strided(x_in, shape=stride_shape, strides=strides) + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides, writeable=False + ) x_strided = x_strided.reshape(-1, inner_dim) - ch_out_groups = ch_out // groups + kernel_weights_grp = kernel_weights[ ..., (grp - 1) * ch_out_groups : grp * ch_out_groups ].reshape(-1, ch_out_groups) From ceb13d80c3acdb881adacf88bdd1df6c6b77bdf5 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 17 Mar 2025 20:53:16 +0900 Subject: [PATCH 4/5] Update conv3d --- keras/src/backend/numpy/nn.py | 337 +++++++++++++--------------------- 1 file changed, 123 insertions(+), 214 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index b08b74550163..fdb06f795325 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -79,7 +79,7 @@ def silu(x): def squareplus(x, b=4): x = convert_to_tensor(x) b = convert_to_tensor(b, dtype=x.dtype) - y = x + np.sqrt(x**2 + b) + y = x + np.sqrt(x ** 2 + b) return y / 2 @@ -116,9 +116,9 @@ def elu(x, alpha=1.0): def selu( - x, - alpha=1.6732632423543772848170429916717, - scale=1.0507009873554804934193349852946, + x, + alpha=1.6732632423543772848170429916717, + 
scale=1.0507009873554804934193349852946, ): x = convert_to_tensor(x) return np.array(scale, x.dtype) * elu(x, alpha) @@ -130,19 +130,19 @@ def gelu(x, approximate=True): if approximate: sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) cdf = np.array(0.5, x.dtype) * ( - np.array(1.0, x.dtype) - + np.tanh( - sqrt_2_over_pi - * (x + np.array(0.044715, x.dtype) * (x**3).astype(x.dtype)) - ) + np.array(1.0, x.dtype) + + np.tanh( + sqrt_2_over_pi + * (x + np.array(0.044715, x.dtype) * (x ** 3).astype(x.dtype)) + ) ) return x * cdf else: sqrt_2 = np.sqrt(2).astype(x.dtype) return ( - x - * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype) - / np.array(2, x.dtype) + x + * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype) + / np.array(2, x.dtype) ) @@ -216,10 +216,10 @@ def sparsemax(logits, axis=-1): def _convert_to_spatial_operand( - x, - num_spatial_dims, - data_format="channels_last", - include_batch_and_channels=True, + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, ): # Helper function that converts an operand to a spatial operand. x = (x,) * num_spatial_dims if isinstance(x, int) else x @@ -233,12 +233,12 @@ def _convert_to_spatial_operand( def _pool( - inputs, - initial_value, - reduce_fn, - pool_size, - strides=None, - padding="valid", + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", ): """Helper function to define pooling functions. @@ -273,11 +273,11 @@ def _pool( def max_pool( - inputs, - pool_size, - strides=None, - padding="valid", - data_format=None, + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -292,11 +292,11 @@ def max_pool( def average_pool( - inputs, - pool_size, - strides, - padding, - data_format=None, + inputs, + pool_size, + strides, + padding, + data_format=None, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -332,9 +332,9 @@ def average_pool( def _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format="channels_last", - transpose=False, + num_spatial_dims, + data_format="channels_last", + transpose=False, ): """Create a `lax.ConvDimensionNumbers` for the given inputs.""" num_dims = num_spatial_dims + 2 @@ -365,13 +365,13 @@ def _same_padding(input_size, kernel_size, stride): def np_conv1d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == "channels_first": x = x.swapaxes(1, 2) @@ -405,7 +405,7 @@ def np_conv1d( out_grps = [] ch_out_per_grp = ch_out // groups for grp in range(groups): - x_in = x[..., grp * ch_in : (grp + 1) * ch_in] + x_in = x[..., grp * ch_in: (grp + 1) * ch_in] stride_shape = (n_batch, h_out, kernel_size, ch_in) strides = ( x_in.strides[0], @@ -420,8 +420,8 @@ def np_conv1d( ).reshape(-1, inner_dim) kernel_weights_grp = kernel_weights[ - ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp - ] + ..., grp * ch_out_per_grp: (grp + 1) * ch_out_per_grp + ] result = x_strided @ kernel_weights_grp out_grps.append(result.reshape(n_batch, h_out, -1)) out = np.concatenate(out_grps, axis=-1) @@ -431,13 +431,13 @@ def np_conv1d( def np_conv2d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == 
"channels_first": x = x.transpose((0, 2, 3, 1)) @@ -479,7 +479,7 @@ def np_conv2d( out_grps = [] ch_out_groups = ch_out // groups for grp in range(1, groups + 1): - x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + x_in = x[..., (grp - 1) * ch_in: grp * ch_in] stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in) strides = ( @@ -497,8 +497,8 @@ def np_conv2d( ).reshape(-1, inner_dim) kernel_weights_grp = kernel_weights[ - ..., (grp - 1) * ch_out_groups : grp * ch_out_groups - ].reshape(-1, ch_out_groups) + ..., (grp - 1) * ch_out_groups: grp * ch_out_groups + ].reshape(-1, ch_out_groups) out_grps.append( (x_strided @ kernel_weights_grp).reshape(n_batch, h_out, w_out, -1) ) @@ -510,105 +510,14 @@ def np_conv2d( return out -from numpy.lib.stride_tricks import as_strided - - -def optimized_np_conv3d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, -): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 4, 1)) # Convert to channels_last - - h_stride, w_stride, d_stride = ( - strides if isinstance(strides, (tuple, list)) else (strides,) * 3 - ) - h_dilation, w_dilation, d_dilation = ( - dilation_rate - if isinstance(dilation_rate, (tuple, list)) - else (dilation_rate,) * 3 - ) - - h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape - - if padding == "same": - n_batch, h_x, w_x, d_x, _ = x.shape - h_pad = _same_padding(h_x, h_kernel, h_stride) - w_pad = _same_padding(w_x, w_kernel, w_stride) - d_pad = _same_padding(d_x, d_kernel, d_stride) - x = np.pad( - x, - pad_width=[(0, 0), h_pad, w_pad, d_pad, (0, 0)], - mode="constant", - constant_values=0, - ) - - n_batch, h_x, w_x, d_x, _ = x.shape - h_out = (h_x - (h_kernel - 1) * h_dilation - 1) // h_stride + 1 - w_out = (w_x - (w_kernel - 1) * w_dilation - 1) // w_stride + 1 - d_out = (d_x - (d_kernel - 1) * d_dilation - 1) // d_stride + 1 - - # Process groups efficiently - out_grps = [] - ch_out_groups = ch_out // groups - - for grp in range(groups): - x_in = x[..., grp * ch_in : (grp + 1) * ch_in] - - # Efficient strided extraction - stride_shape = ( - n_batch, - h_out, - w_out, - d_out, - h_kernel, - w_kernel, - d_kernel, - ch_in, - ) - strides = ( - x_in.strides[0], - h_stride * x_in.strides[1], - w_stride * x_in.strides[2], - d_stride * x_in.strides[3], - h_dilation * x_in.strides[1], - w_dilation * x_in.strides[2], - d_dilation * x_in.strides[3], - x_in.strides[4], - ) - x_strided = as_strided( - x_in, shape=stride_shape, strides=strides, writeable=False - ) - - # Convolution using einsum (efficient tensor contraction) - kernel_grp = kernel_weights[ - ..., grp * ch_out_groups : (grp + 1) * ch_out_groups - ] - result = np.einsum("bihwdklc, hwdco -> bihwo", x_strided, kernel_grp) - - out_grps.append(result) - - out = np.concatenate(out_grps, axis=-1) - - if data_format == "channels_first": - out = out.transpose((0, 4, 1, 2, 3)) # Convert back to channels_first - - return out - - def np_conv3d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == "channels_first": x = x.transpose((0, 2, 3, 4, 1)) @@ -661,7 +570,7 @@ def np_conv3d( out_grps = [] ch_out_groups = ch_out // groups for grp in range(1, groups + 1): - x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + x_in = x[..., (grp - 1) * ch_in: grp * ch_in] stride_shape = ( n_batch, h_out, @@ -689,8 +598,8 @@ def np_conv3d( x_strided = x_strided.reshape(-1, inner_dim) 
kernel_weights_grp = kernel_weights[ - ..., (grp - 1) * ch_out_groups : grp * ch_out_groups - ].reshape(-1, ch_out_groups) + ..., (grp - 1) * ch_out_groups: grp * ch_out_groups + ].reshape(-1, ch_out_groups) result = x_strided @ kernel_weights_grp out_grps.append(result.reshape(n_batch, h_out, w_out, d_out, -1)) @@ -702,12 +611,12 @@ def np_conv3d( def conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -775,12 +684,12 @@ def conv( def depthwise_conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -822,13 +731,13 @@ def depthwise_conv( def separable_conv( - inputs, - depthwise_kernel, - pointwise_kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) depthwise_conv_output = depthwise_conv( @@ -850,13 +759,13 @@ def separable_conv( def conv_transpose( - inputs, - kernel, - strides=1, - padding="valid", - output_padding=None, - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -1047,7 +956,7 @@ def moments(x, axes, keepdims=False, synchronized=False): def batch_normalization( - x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 ): shape = [1] * len(x.shape) shape[axis] = mean.shape[0] @@ -1104,7 +1013,7 @@ def _lengths_to_paddings(lengths, max_length): repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) repeat = np.pad(repeat, ((0, 0), (0, 1))) - logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] + logprobs_phi = logprobs[:, :, mask_index: mask_index + 1] # [B, T, 1] logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] @@ -1113,13 +1022,13 @@ def _lengths_to_paddings(lengths, max_length): # [B, N] logalpha_phi_init = ( - np.ones((batch_size, max_label_length + 1), dtype=output.dtype) - * log_epsilon + np.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon ) logalpha_phi_init[:, 0] = 0.0 logalpha_emit_init = ( - np.ones((batch_size, max_label_length), dtype=output.dtype) - * log_epsilon + np.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon ) def update_phi_score(phi, added_score): @@ -1181,10 +1090,10 @@ def np_scan(f, init, xs): def _ctc_greedy_decode( - inputs, - sequence_lengths, - merge_repeated=True, - mask_index=None, + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, ): inputs = convert_to_tensor(inputs) sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") @@ -1224,11 +1133,11 @@ def _ctc_greedy_decode( def _ctc_beam_search_decode( - inputs, - sequence_lengths, - beam_width=100, - top_paths=1, - mask_index=None, + inputs, + sequence_lengths, + beam_width=100, 
+ top_paths=1, + mask_index=None, ): inputs = convert_to_tensor(inputs) sequence_lengths = convert_to_tensor(sequence_lengths) @@ -1344,7 +1253,7 @@ def _step(prev, x): return (paths, scores, masked), None def _decode_batch( - init_paths, init_scores, init_masked, inputs, seqlen_mask + init_paths, init_scores, init_masked, inputs, seqlen_mask ): def np_scan_only_carry(f, init, xs): carry = init @@ -1389,13 +1298,13 @@ def np_scan_only_carry(f, init, xs): def ctc_decode( - inputs, - sequence_lengths, - strategy="greedy", - beam_width=100, - top_paths=1, - merge_repeated=True, - mask_index=0, + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, ): inputs = convert_to_tensor(inputs) dtype = backend.result_type(inputs.dtype, "float32") @@ -1492,15 +1401,15 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - flash_attention=None, - attn_logits_soft_cap=None, + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, ): if flash_attention is None: flash_attention = False From 31055784fae23a0e0b77aafeee18eea25aab123f Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 17 Mar 2025 21:09:55 +0900 Subject: [PATCH 5/5] change conv method on numpy --- keras/src/backend/numpy/nn.py | 246 +++++++++++++++++----------------- 1 file changed, 123 insertions(+), 123 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index fdb06f795325..65cb5d5d6852 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -79,7 +79,7 @@ def silu(x): def squareplus(x, b=4): x = convert_to_tensor(x) b = convert_to_tensor(b, dtype=x.dtype) - y = x + np.sqrt(x ** 2 + b) + y = x + np.sqrt(x**2 + b) return y / 2 @@ -116,9 +116,9 @@ def elu(x, alpha=1.0): def selu( - x, - alpha=1.6732632423543772848170429916717, - scale=1.0507009873554804934193349852946, + x, + alpha=1.6732632423543772848170429916717, + scale=1.0507009873554804934193349852946, ): x = convert_to_tensor(x) return np.array(scale, x.dtype) * elu(x, alpha) @@ -130,19 +130,19 @@ def gelu(x, approximate=True): if approximate: sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) cdf = np.array(0.5, x.dtype) * ( - np.array(1.0, x.dtype) - + np.tanh( - sqrt_2_over_pi - * (x + np.array(0.044715, x.dtype) * (x ** 3).astype(x.dtype)) - ) + np.array(1.0, x.dtype) + + np.tanh( + sqrt_2_over_pi + * (x + np.array(0.044715, x.dtype) * (x**3).astype(x.dtype)) + ) ) return x * cdf else: sqrt_2 = np.sqrt(2).astype(x.dtype) return ( - x - * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype) - / np.array(2, x.dtype) + x + * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype) + / np.array(2, x.dtype) ) @@ -216,10 +216,10 @@ def sparsemax(logits, axis=-1): def _convert_to_spatial_operand( - x, - num_spatial_dims, - data_format="channels_last", - include_batch_and_channels=True, + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, ): # Helper function that converts an operand to a spatial operand. 
x = (x,) * num_spatial_dims if isinstance(x, int) else x @@ -233,12 +233,12 @@ def _convert_to_spatial_operand( def _pool( - inputs, - initial_value, - reduce_fn, - pool_size, - strides=None, - padding="valid", + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", ): """Helper function to define pooling functions. @@ -273,11 +273,11 @@ def _pool( def max_pool( - inputs, - pool_size, - strides=None, - padding="valid", - data_format=None, + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -292,11 +292,11 @@ def max_pool( def average_pool( - inputs, - pool_size, - strides, - padding, - data_format=None, + inputs, + pool_size, + strides, + padding, + data_format=None, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -332,9 +332,9 @@ def average_pool( def _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format="channels_last", - transpose=False, + num_spatial_dims, + data_format="channels_last", + transpose=False, ): """Create a `lax.ConvDimensionNumbers` for the given inputs.""" num_dims = num_spatial_dims + 2 @@ -365,13 +365,13 @@ def _same_padding(input_size, kernel_size, stride): def np_conv1d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == "channels_first": x = x.swapaxes(1, 2) @@ -405,7 +405,7 @@ def np_conv1d( out_grps = [] ch_out_per_grp = ch_out // groups for grp in range(groups): - x_in = x[..., grp * ch_in: (grp + 1) * ch_in] + x_in = x[..., grp * ch_in : (grp + 1) * ch_in] stride_shape = (n_batch, h_out, kernel_size, ch_in) strides = ( x_in.strides[0], @@ -420,8 +420,8 @@ def np_conv1d( ).reshape(-1, inner_dim) kernel_weights_grp = kernel_weights[ - ..., grp * ch_out_per_grp: (grp + 1) * ch_out_per_grp - ] + ..., grp * ch_out_per_grp : (grp + 1) * ch_out_per_grp + ] result = x_strided @ kernel_weights_grp out_grps.append(result.reshape(n_batch, h_out, -1)) out = np.concatenate(out_grps, axis=-1) @@ -431,13 +431,13 @@ def np_conv1d( def np_conv2d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == "channels_first": x = x.transpose((0, 2, 3, 1)) @@ -479,7 +479,7 @@ def np_conv2d( out_grps = [] ch_out_groups = ch_out // groups for grp in range(1, groups + 1): - x_in = x[..., (grp - 1) * ch_in: grp * ch_in] + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in) strides = ( @@ -497,8 +497,8 @@ def np_conv2d( ).reshape(-1, inner_dim) kernel_weights_grp = kernel_weights[ - ..., (grp - 1) * ch_out_groups: grp * ch_out_groups - ].reshape(-1, ch_out_groups) + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) out_grps.append( (x_strided @ kernel_weights_grp).reshape(n_batch, h_out, w_out, -1) ) @@ -511,13 +511,13 @@ def np_conv2d( def np_conv3d( - x, - kernel_weights, - strides, - padding, - data_format, - dilation_rate, - groups, + x, + kernel_weights, + strides, + padding, + data_format, + dilation_rate, + groups, ): if data_format == "channels_first": x = x.transpose((0, 2, 3, 4, 1)) @@ -570,7 +570,7 @@ def np_conv3d( out_grps = [] ch_out_groups = ch_out // groups for grp in range(1, 
groups + 1): - x_in = x[..., (grp - 1) * ch_in: grp * ch_in] + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] stride_shape = ( n_batch, h_out, @@ -598,8 +598,8 @@ def np_conv3d( x_strided = x_strided.reshape(-1, inner_dim) kernel_weights_grp = kernel_weights[ - ..., (grp - 1) * ch_out_groups: grp * ch_out_groups - ].reshape(-1, ch_out_groups) + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) result = x_strided @ kernel_weights_grp out_grps.append(result.reshape(n_batch, h_out, w_out, d_out, -1)) @@ -611,12 +611,12 @@ def np_conv3d( def conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -684,12 +684,12 @@ def conv( def depthwise_conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -731,13 +731,13 @@ def depthwise_conv( def separable_conv( - inputs, - depthwise_kernel, - pointwise_kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) depthwise_conv_output = depthwise_conv( @@ -759,13 +759,13 @@ def separable_conv( def conv_transpose( - inputs, - kernel, - strides=1, - padding="valid", - output_padding=None, - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, ): data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 @@ -956,7 +956,7 @@ def moments(x, axes, keepdims=False, synchronized=False): def batch_normalization( - x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 ): shape = [1] * len(x.shape) shape[axis] = mean.shape[0] @@ -1013,7 +1013,7 @@ def _lengths_to_paddings(lengths, max_length): repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) repeat = np.pad(repeat, ((0, 0), (0, 1))) - logprobs_phi = logprobs[:, :, mask_index: mask_index + 1] # [B, T, 1] + logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] @@ -1022,13 +1022,13 @@ def _lengths_to_paddings(lengths, max_length): # [B, N] logalpha_phi_init = ( - np.ones((batch_size, max_label_length + 1), dtype=output.dtype) - * log_epsilon + np.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon ) logalpha_phi_init[:, 0] = 0.0 logalpha_emit_init = ( - np.ones((batch_size, max_label_length), dtype=output.dtype) - * log_epsilon + np.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon ) def update_phi_score(phi, added_score): @@ -1090,10 +1090,10 @@ def np_scan(f, init, xs): def _ctc_greedy_decode( - inputs, - sequence_lengths, - merge_repeated=True, - mask_index=None, + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, ): inputs = convert_to_tensor(inputs) sequence_lengths = convert_to_tensor(sequence_lengths, 
dtype="int32") @@ -1133,11 +1133,11 @@ def _ctc_greedy_decode( def _ctc_beam_search_decode( - inputs, - sequence_lengths, - beam_width=100, - top_paths=1, - mask_index=None, + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, ): inputs = convert_to_tensor(inputs) sequence_lengths = convert_to_tensor(sequence_lengths) @@ -1253,7 +1253,7 @@ def _step(prev, x): return (paths, scores, masked), None def _decode_batch( - init_paths, init_scores, init_masked, inputs, seqlen_mask + init_paths, init_scores, init_masked, inputs, seqlen_mask ): def np_scan_only_carry(f, init, xs): carry = init @@ -1298,13 +1298,13 @@ def np_scan_only_carry(f, init, xs): def ctc_decode( - inputs, - sequence_lengths, - strategy="greedy", - beam_width=100, - top_paths=1, - merge_repeated=True, - mask_index=0, + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, ): inputs = convert_to_tensor(inputs) dtype = backend.result_type(inputs.dtype, "float32") @@ -1401,15 +1401,15 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): def dot_product_attention( - query, - key, - value, - bias=None, - mask=None, - scale=None, - is_causal=False, - flash_attention=None, - attn_logits_soft_cap=None, + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, ): if flash_attention is None: flash_attention = False