Skip to content

Commit cdbe4fb

Browse files
feat(ops): Add keras.ops.numpy.rot90 operation (#20723) (#20745)
* feat(ops): Add keras.ops.image.rot90 operation Adds a new operation to rotate tensors by 90 degrees in the specified plane: - Implements rot90 operation in keras.ops.image module - Adds support for multiple rotations (k parameter) and custom axes - Matches numpy.rot90 behavior and API for consistency - Adds comprehensive test coverage including batch images support - Handles input validation for tensor dimensions and axes - Supports symbolic tensor execution The operation follows the same interface as numpy.rot90 and tf.image.rot90: rot90(array, k=1, axes=(0, 1)) * feat: add JAX, NumPy and PyTorch backends for rot90 Add implementations of rot90() for multiple backend frameworks: - JAX backend implementation - NumPy backend implementation - PyTorch backend implementation * Move rot90 from image to numpy ops Move rot90 operation to numpy.py files in backend implementations since it's a numpy op (https://numpy.org/doc/stable/reference/generated/numpy.rot90.html). Now exported as both keras.ops.rot90 and keras.ops.numpy.rot90. * Fix dtype conflict in PyTorch backend's rot90 function Resolved the 'Invalid dtype: object' error by explicitly using to avoid naming conflicts with the custom function. * Replace experimental NumPy rot90 with core TF ops Replace tf.experimental.numpy.rot90 with core TF ops for XLA compatibility. Use convert_to_tensor for input handling.
1 parent e67ac8f commit cdbe4fb

File tree

6 files changed

+233
-0
lines changed

6 files changed

+233
-0
lines changed

keras/src/backend/jax/numpy.py

+13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
from keras.src.backend.jax.core import convert_to_tensor
1616

1717

18+
def rot90(array, k=1, axes=(0, 1)):
19+
"""Rotate an array by 90 degrees in the specified plane."""
20+
if array.ndim < 2:
21+
raise ValueError(
22+
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
23+
)
24+
if len(axes) != 2 or axes[0] == axes[1]:
25+
raise ValueError(
26+
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
27+
)
28+
return jnp.rot90(array, k=k, axes=axes)
29+
30+
1831
@sparse.elementwise_binary_union(linear=True, use_sparsify=True)
1932
def add(x1, x2):
2033
x1 = convert_to_tensor(x1)

keras/src/backend/numpy/numpy.py

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@
88
from keras.src.backend.numpy.core import convert_to_tensor
99

1010

11+
def rot90(array, k=1, axes=(0, 1)):
12+
"""Rotate an array by 90 degrees in the specified plane."""
13+
if array.ndim < 2:
14+
raise ValueError(
15+
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
16+
)
17+
if len(axes) != 2 or axes[0] == axes[1]:
18+
raise ValueError(
19+
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
20+
)
21+
return np.rot90(array, k=k, axes=axes)
22+
23+
1124
def add(x1, x2):
1225
if not isinstance(x1, (int, float)):
1326
x1 = convert_to_tensor(x1)

keras/src/backend/tensorflow/numpy.py

+43
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,49 @@
2323
from keras.src.backend.tensorflow.core import shape as shape_op
2424

2525

26+
def rot90(array, k=1, axes=(0, 1)):
27+
"""Rotate an array by 90 degrees in the specified plane."""
28+
array = convert_to_tensor(array)
29+
30+
if array.shape.rank < 2:
31+
raise ValueError(
32+
f"Input array must have at least 2 dimensions. Received: array.ndim={array.shape.rank}"
33+
)
34+
35+
if len(axes) != 2 or axes[0] == axes[1]:
36+
raise ValueError(
37+
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
38+
)
39+
40+
k = k % 4
41+
if k == 0:
42+
return array
43+
44+
axes = tuple(axis if axis >= 0 else array.shape.rank + axis for axis in axes)
45+
46+
perm = [i for i in range(array.shape.rank) if i not in axes]
47+
perm.extend(axes)
48+
array = tf.transpose(array, perm)
49+
50+
shape = tf.shape(array)
51+
non_rot_shape = shape[:-2]
52+
rot_shape = shape[-2:]
53+
54+
array = tf.reshape(array, tf.concat([[-1], rot_shape], axis=0))
55+
56+
for _ in range(k):
57+
array = tf.transpose(array, [0, 2, 1])
58+
array = tf.reverse(array, axis=[1])
59+
array = tf.reshape(array, tf.concat([non_rot_shape, rot_shape], axis=0))
60+
61+
inv_perm = [0] * len(perm)
62+
for i, p in enumerate(perm):
63+
inv_perm[p] = i
64+
array = tf.transpose(array, inv_perm)
65+
66+
return array
67+
68+
2669
@sparse.elementwise_binary_union(tf.sparse.add)
2770
def add(x1, x2):
2871
if not isinstance(x1, (int, float)):

keras/src/backend/torch/numpy.py

+34
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,40 @@
2525
)
2626

2727

28+
def rot90(array, k=1, axes=(0, 1)):
29+
"""Rotate an array by 90 degrees in the specified plane using PyTorch.
30+
31+
Args:
32+
array: Input tensor
33+
k: Number of 90-degree rotations (default=1)
34+
axes: Tuple of two axes that define the plane of rotation (default=(0,1))
35+
36+
Returns:
37+
Rotated tensor
38+
"""
39+
array = convert_to_tensor(array)
40+
41+
if array.ndim < 2:
42+
raise ValueError(
43+
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
44+
)
45+
if len(axes) != 2 or axes[0] == axes[1]:
46+
raise ValueError(
47+
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
48+
)
49+
50+
axes = tuple(axis if axis >= 0 else array.ndim + axis for axis in axes)
51+
52+
if not builtins.all(0 <= axis < array.ndim for axis in axes):
53+
raise ValueError(f"Invalid axes {axes} for tensor with {array.ndim} dimensions")
54+
55+
rotated = torch.rot90(array, k=k, dims=axes)
56+
if isinstance(array, np.ndarray):
57+
rotated = rotated.cpu().numpy()
58+
59+
return rotated
60+
61+
2862
def add(x1, x2):
2963
x1 = convert_to_tensor(x1)
3064
x2 = convert_to_tensor(x2)

keras/src/ops/numpy.py

+63
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,69 @@
1616
from keras.src.ops.operation_utils import reduce_shape
1717

1818

19+
class Rot90(Operation):
20+
def __init__(self, k=1, axes=(0, 1)):
21+
super().__init__()
22+
self.k = k
23+
self.axes = axes
24+
25+
def call(self, array):
26+
return backend.numpy.rot90(array, k=self.k, axes=self.axes)
27+
28+
def compute_output_spec(self, array):
29+
array_shape = list(array.shape)
30+
if len(array_shape) < 2:
31+
raise ValueError(
32+
"Input array must have at least 2 dimensions. "
33+
f"Received: array.shape={array_shape}"
34+
)
35+
if len(self.axes) != 2 or self.axes[0] == self.axes[1]:
36+
raise ValueError(
37+
f"Invalid axes: {self.axes}. Axes must be a tuple of two different dimensions."
38+
)
39+
axis1, axis2 = self.axes
40+
array_shape[axis1], array_shape[axis2] = array_shape[axis2], array_shape[axis1]
41+
return KerasTensor(shape=array_shape, dtype=array.dtype)
42+
43+
44+
@keras_export(["keras.ops.rot90", "keras.ops.numpy.rot90"])
45+
def rot90(array, k=1, axes=(0, 1)):
46+
"""Rotate an array by 90 degrees in the plane specified by axes.
47+
48+
This function rotates an array counterclockwise by 90 degrees `k` times
49+
in the plane specified by `axes`. Supports arrays of two or more dimensions.
50+
51+
Args:
52+
array: Input array to rotate.
53+
k: Number of times the array is rotated by 90 degrees.
54+
axes: A tuple of two integers specifying the plane for rotation.
55+
56+
Returns:
57+
Rotated array.
58+
59+
Examples:
60+
61+
>>> import numpy as np
62+
>>> from keras import ops
63+
>>> m = np.array([[1, 2], [3, 4]])
64+
>>> rotated = ops.rot90(m)
65+
>>> rotated
66+
array([[2, 4],
67+
[1, 3]])
68+
69+
>>> m = np.arange(8).reshape((2, 2, 2))
70+
>>> rotated = ops.rot90(m, k=1, axes=(1, 2))
71+
>>> rotated
72+
array([[[1, 3],
73+
[0, 2]],
74+
[[5, 7],
75+
[4, 6]]])
76+
"""
77+
if any_symbolic_tensors((array,)):
78+
return Rot90(k=k, axes=axes).symbolic_call(array)
79+
return backend.numpy.rot90(array, k=k, axes=axes)
80+
81+
1982
def shape_equal(shape1, shape2, axis=None, allow_none=True):
2083
"""Check if two shapes are equal.
2184

keras/src/ops/numpy_test.py

+67
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,73 @@
1919
from keras.src.testing.test_utils import named_product
2020

2121

22+
class NumPyTestRot90(testing.TestCase):
23+
def test_basic(self):
24+
array = np.array([[1, 2], [3, 4]])
25+
rotated = knp.rot90(array)
26+
expected = np.array([[2, 4], [1, 3]])
27+
assert np.array_equal(rotated, expected), f"Failed basic 2D test: {rotated}"
28+
29+
def test_multiple_k(self):
30+
array = np.array([[1, 2], [3, 4]])
31+
32+
# k=2 (180 degrees rotation)
33+
rotated = knp.rot90(array, k=2)
34+
expected = np.array([[4, 3], [2, 1]])
35+
assert np.array_equal(rotated, expected), f"Failed k=2 test: {rotated}"
36+
37+
# k=3 (270 degrees rotation)
38+
rotated = knp.rot90(array, k=3)
39+
expected = np.array([[3, 1], [4, 2]])
40+
assert np.array_equal(rotated, expected), f"Failed k=3 test: {rotated}"
41+
42+
# k=4 (full rotation)
43+
rotated = knp.rot90(array, k=4)
44+
expected = array
45+
assert np.array_equal(rotated, expected), f"Failed k=4 test: {rotated}"
46+
47+
def test_axes(self):
48+
array = np.arange(8).reshape((2, 2, 2))
49+
rotated = knp.rot90(array, k=1, axes=(1, 2))
50+
expected = np.array([[[1, 3], [0, 2]], [[5, 7], [4, 6]]])
51+
assert np.array_equal(rotated, expected), f"Failed custom axes test: {rotated}"
52+
53+
def test_single_image(self):
54+
array = np.random.random((4, 4, 3))
55+
rotated = knp.rot90(array, k=1, axes=(0, 1))
56+
expected = np.rot90(array, k=1, axes=(0, 1))
57+
assert np.allclose(rotated, expected), "Failed single image test"
58+
59+
def test_batch_images(self):
60+
array = np.random.random((2, 4, 4, 3))
61+
rotated = knp.rot90(array, k=1, axes=(1, 2))
62+
expected = np.rot90(array, k=1, axes=(1, 2))
63+
assert np.allclose(rotated, expected), "Failed batch images test"
64+
65+
def test_invalid_axes(self):
66+
array = np.array([[1, 2], [3, 4]])
67+
try:
68+
knp.rot90(array, axes=(0, 0))
69+
except ValueError as e:
70+
assert (
71+
"Invalid axes: (0, 0). Axes must be a tuple of two different dimensions."
72+
in str(e)
73+
), f"Failed invalid axes test: {e}"
74+
else:
75+
raise AssertionError("Failed to raise error for invalid axes")
76+
77+
def test_invalid_rank(self):
78+
array = np.array([1, 2, 3]) # 1D array
79+
try:
80+
knp.rot90(array)
81+
except ValueError as e:
82+
assert (
83+
"Input array must have at least 2 dimensions." in str(e)
84+
), f"Failed invalid rank test: {e}"
85+
else:
86+
raise AssertionError("Failed to raise error for invalid input rank")
87+
88+
2289
class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
2390
def test_add(self):
2491
x = KerasTensor((None, 3))

0 commit comments

Comments
 (0)