 import torch.nn.functional as tnn
 
 from keras.src import backend
-from keras.src import tree
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_torch,
 )
@@ -204,17 +203,27 @@ def sparsemax(logits, axis=-1):
 def _compute_padding_length(
     input_length, kernel_length, stride, dilation_rate=1
 ):
-    """Compute padding length along one dimension."""
-    total_padding_length = (
-        dilation_rate * (kernel_length - 1) - (input_length - 1) % stride
-    )
-    left_padding = total_padding_length // 2
-    right_padding = (total_padding_length + 1) // 2
206+ """Compute padding length along one dimension with support
207+ for asymmetric padding."""
208+ effective_k_size = (kernel_length - 1 ) * dilation_rate + 1
209+ if stride == 1 :
210+ # total padding is kernel_size - 1
211+ total_padding = effective_k_size - 1
212+ else :
213+ # calc. needed padding for case with stride involved
214+ output_size = (input_length + stride - 1 ) // stride
215+ total_padding = max (
216+ 0 , (output_size - 1 ) * stride + effective_k_size - input_length
217+ )
218+
219+ # divide padding evenly, with extra pixel going at the end if needed
220+ left_padding = total_padding // 2
221+ right_padding = total_padding - left_padding
     return (left_padding, right_padding)
 
 
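A quick worked example of the padding rule above (a standalone sketch, not
part of the patch). Note that the `max(0, ...)` clamp also fixes the case
where the old expression went negative, e.g. input_length=10, kernel_length=2,
stride=5 previously gave a total of 1 - (9 % 5) = -3 and now yields (0, 0):

    # input_length=10, kernel_length=3, stride=2, dilation_rate=1:
    #   effective_k_size = (3 - 1) * 1 + 1 = 3
    #   output_size = ceil(10 / 2) = 5
    #   total_padding = max(0, (5 - 1) * 2 + 3 - 10) = 1
    assert _compute_padding_length(10, 3, 2) == (0, 1)  # asymmetric
    assert _compute_padding_length(5, 3, 1) == (1, 1)   # symmetric
    assert _compute_padding_length(10, 2, 5) == (0, 0)  # clamped at zero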
 def _apply_same_padding(
-    inputs, kernel_size, strides, operation_type, dilation_rate=1
+    inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1
 ):
     """Apply same padding to the input tensor.
 
@@ -231,50 +240,49 @@ def _apply_same_padding(
231240 """
232241 spatial_shape = inputs .shape [2 :]
233242 num_spatial_dims = len (spatial_shape )
234- padding = ()
243+ padding = []
244+
245+ if operation_type != "pooling" :
246+ dilation_rate = standardize_tuple (
247+ dilation_rate , num_spatial_dims , "dilation_rate"
248+ )
235249
236250 for i in range (num_spatial_dims ):
237- if operation_type == "pooling" :
238- padding_size = _compute_padding_length (
239- spatial_shape [i ], kernel_size [i ], strides [i ]
240- )
241- mode = "replicate"
242- else :
243- dilation_rate = standardize_tuple (
244- dilation_rate , num_spatial_dims , "dilation_rate"
245- )
246- padding_size = _compute_padding_length (
247- spatial_shape [i ], kernel_size [i ], strides [i ], dilation_rate [i ]
248- )
249- mode = "constant"
250- padding = (padding_size ,) + padding
251+ dil = 1 if operation_type == "pooling" else dilation_rate [i ]
252+ pad = _compute_padding_length (
253+ spatial_shape [i ], kernel_size [i ], strides [i ], dil
254+ )
255+ padding .append (pad )
251256
-    if all([left == right for left, right in padding]):
+    # If the padding is symmetric, torch can apply it natively.
+    if all(left == right for left, right in padding):
         return inputs, [left for left, _ in padding]
 
-    flattened_padding = tuple(
-        value for left_and_right in padding for value in left_and_right
-    )
-    return tnn.pad(inputs, pad=flattened_padding, mode=mode), 0
+    # Otherwise, pad manually; `tnn.pad` takes pairs from the last dim first.
+    flattened_padding = []
+    for pad in reversed(padding):
+        flattened_padding.extend(pad)
+
+    mode = "replicate" if operation_type == "pooling" else "constant"
+    return tnn.pad(inputs, pad=tuple(flattened_padding), mode=mode), 0
 
 
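The reversal above matters because `torch.nn.functional.pad` takes (left,
right) pairs starting from the last dimension. A minimal standalone check of
that convention (not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.zeros(1, 1, 4, 6)  # (N, C, H, W)
    # Pairs are given last-dim first: (W_left, W_right, H_left, H_right).
    y = F.pad(x, pad=(1, 2, 0, 1))
    assert y.shape == (1, 1, 5, 9)  # H: 4 + 0 + 1, W: 6 + 1 + 2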
 def _transpose_spatial_inputs(inputs):
-    num_spatial_dims = inputs.ndim - 2
+    """Transpose inputs from channels_last to channels_first format."""
     # Torch pooling does not support `channels_last` format, so
     # we need to transpose to `channels_first` format.
-    if num_spatial_dims == 1:
-        inputs = torch.permute(inputs, (0, 2, 1))
-    elif num_spatial_dims == 2:
-        inputs = torch.permute(inputs, (0, 3, 1, 2))
-    elif num_spatial_dims == 3:
-        inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
-    else:
-        raise ValueError(
-            "Inputs must have ndim=3, 4 or 5, "
-            "corresponding to 1D, 2D and 3D inputs. "
-            f"Received input shape: {inputs.shape}."
-        )
-    return inputs
+    ndim = inputs.ndim - 2
+    if ndim == 1:  # 1D case
+        return torch.permute(inputs, (0, 2, 1))
+    elif ndim == 2:  # 2D case
+        return torch.permute(inputs, (0, 3, 1, 2))
+    elif ndim == 3:  # 3D case
+        return torch.permute(inputs, (0, 4, 1, 2, 3))
+    raise ValueError(
+        "Inputs must have ndim=3, 4 or 5, "
+        "corresponding to 1D, 2D and 3D inputs. "
+        f"Received input shape: {inputs.shape}."
+    )
 
 
 def _transpose_spatial_outputs(outputs):
@@ -309,6 +317,7 @@ def max_pool(
309317 padding = "valid" ,
310318 data_format = None ,
311319):
320+ """Fixed max pooling implementation."""
     inputs = convert_to_tensor(inputs)
     num_spatial_dims = inputs.ndim - 2
     pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
@@ -325,7 +334,7 @@ def max_pool(
         # Torch does not natively support `"same"` padding, we need to manually
         # apply the right amount of padding to `inputs`.
         inputs, padding = _apply_same_padding(
-            inputs, pool_size, strides, operation_type="pooling"
+            inputs, pool_size, strides, data_format, "pooling"
         )
     else:
         padding = 0
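The `"same"` contract here is an output length of ceil(input / stride) per
spatial dim. A standalone sketch of that behavior in plain torch, mirroring
the manual replicate-padding used for pooling (hypothetical sizes, not from
the patch):

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 10)  # (N, C, L), channels_first
    # "same" with pool_size=3, stride=2 on L=10 needs total padding 1 -> (0, 1).
    y = F.max_pool1d(F.pad(x, (0, 1), mode="replicate"), kernel_size=3, stride=2)
    assert y.shape[-1] == math.ceil(10 / 2)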
@@ -370,26 +379,36 @@ def average_pool(
370379 padding = "valid" ,
371380 data_format = None ,
372381):
382+ """Fixed average pooling with correct padding calculation."""
     inputs = convert_to_tensor(inputs)
     num_spatial_dims = inputs.ndim - 2
     pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
-    if strides is None:
-        strides = pool_size
-    else:
-        strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    strides = (
+        pool_size
+        if strides is None
+        else standardize_tuple(strides, num_spatial_dims, "strides")
+    )
 
     data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+
     if data_format == "channels_last":
         inputs = _transpose_spatial_inputs(inputs)
+
     if padding == "same":
         # Torch does not natively support `"same"` padding, we need to manually
         # apply the right amount of padding to `inputs`.
         inputs, padding = _apply_same_padding(
-            inputs, pool_size, strides, operation_type="pooling"
+            inputs,
+            pool_size,
+            strides,
+            "channels_first",  # inputs are already channels_first here
+            "pooling",
         )
     else:
         padding = 0
 
+    # Apply the pooling op.
     if num_spatial_dims == 1:
         outputs = tnn.avg_pool1d(
             inputs,
@@ -420,8 +439,10 @@ def average_pool(
420439 "corresponding to 1D, 2D and 3D inputs. "
421440 f"Received input shape: { inputs .shape } ."
422441 )
423- if data_format == "channels_last" :
442+
443+ if orig_format == "channels_last" :
424444 outputs = _transpose_spatial_outputs (outputs )
445+
425446 return outputs
426447
427448
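The `orig_format` bookkeeping exists so that the transpose back at the end
keys off the format the caller passed in, not an intermediate value. A
condensed standalone sketch of the pattern (hypothetical helper, not the
actual backend code):

    import torch
    import torch.nn.functional as F

    def pool_nhwc(x, data_format):
        orig_format = data_format
        if data_format == "channels_last":
            x = torch.permute(x, (0, 3, 1, 2))  # to channels_first
        x = F.avg_pool2d(x, kernel_size=2)
        if orig_format == "channels_last":
            x = torch.permute(x, (0, 2, 3, 1))  # back to channels_last
        return x

    out = pool_nhwc(torch.zeros(1, 8, 8, 3), "channels_last")
    assert out.shape == (1, 4, 4, 3)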
@@ -433,6 +454,7 @@ def conv(
     data_format=None,
     dilation_rate=1,
 ):
457+ """Convolution with fixed group handling."""
     inputs = convert_to_tensor(inputs)
     kernel = convert_to_tensor(kernel)
     num_spatial_dims = inputs.ndim - 2
@@ -441,53 +463,59 @@ def conv(
     data_format = backend.standardize_data_format(data_format)
     if data_format == "channels_last":
         inputs = _transpose_spatial_inputs(inputs)
-    # Transpose kernel from keras format to torch format.
+
     kernel = _transpose_conv_kernel(kernel)
-    if padding == "same" and any(d != 1 for d in tree.flatten(strides)):
-        # Torch does not support this case in conv2d().
-        # Manually pad the tensor.
+
+    # Compute the number of groups from the channel counts.
+    in_channels = inputs.shape[1]
+    kernel_in_channels = kernel.shape[1]
+    if in_channels % kernel_in_channels != 0:
+        raise ValueError(
+            f"Input channels ({in_channels}) must be divisible by "
+            f"kernel input channels ({kernel_in_channels})"
+        )
+    groups = in_channels // kernel_in_channels
+
+    # Resolve padding, applying `"same"` padding manually when needed.
480+ if padding == "same" :
449481 inputs , padding = _apply_same_padding (
450482 inputs ,
451483 kernel .shape [2 :],
452484 strides ,
453- operation_type = "conv" ,
454- dilation_rate = dilation_rate ,
455- )
456- channels = inputs .shape [1 ]
457- kernel_in_channels = kernel .shape [1 ]
458- if channels % kernel_in_channels > 0 :
459- raise ValueError (
460- "The number of input channels must be evenly divisible by "
461- f"kernel.shape[1]. Received: inputs.shape={ inputs .shape } , "
462- f"kernel.shape={ kernel .shape } "
485+ data_format ,
486+ "conv" ,
487+ dilation_rate ,
463488 )
464- groups = channels // kernel_in_channels
489+ else :
490+ padding = 0
491+
492+ # apply convolution
     if num_spatial_dims == 1:
         outputs = tnn.conv1d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     elif num_spatial_dims == 2:
         outputs = tnn.conv2d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     elif num_spatial_dims == 3:
         outputs = tnn.conv3d(
             inputs,
             kernel,
             stride=strides,
+            padding=padding,
             dilation=dilation_rate,
             groups=groups,
-            padding=padding,
         )
     else:
         raise ValueError(
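For reference, the `groups` inference above follows torch's grouped-conv
convention: `groups = in_channels // kernel.shape[1]`. A standalone sketch
with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 16)  # (N, C_in=8, L), channels_first
    w = torch.randn(12, 2, 3)  # (C_out=12, C_in // groups = 2, K=3)
    groups = x.shape[1] // w.shape[1]  # 8 // 2 = 4
    y = F.conv1d(x, w, stride=1, padding=0, groups=groups)
    assert y.shape == (1, 12, 14)  # L_out = 16 - 3 + 1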