Skip to content

Commit 2320a26

Browse files
committed
update test to pass with numpy-2.3
1 parent 2d27110 commit 2320a26

File tree

7 files changed

+30
-13
lines changed

7 files changed

+30
-13
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ env:
2222
test-env-name: 'test'
2323
rerun-tests-on-failure: 'true'
2424
rerun-tests-max-attempts: 2
25-
rerun-tests-timeout: 35
25+
rerun-tests-timeout: 40
2626

2727
jobs:
2828
build:

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,11 @@ def _norm_int_axis(x, ord, axis, keepdims):
11871187
if ord == dpnp.inf:
11881188
if x.shape[axis] == 0:
11891189
x = dpnp.moveaxis(x, axis, -1)
1190-
return dpnp.zeros_like(x, shape=x.shape[:-1])
1190+
res_shape = x.shape[:-1]
1191+
result = dpnp.zeros_like(x, shape=res_shape)
1192+
if keepdims:
1193+
result = result.reshape(res_shape + (1,))
1194+
return result
11911195
return dpnp.abs(x).max(axis=axis, keepdims=keepdims)
11921196
if ord == -dpnp.inf:
11931197
return dpnp.abs(x).min(axis=axis, keepdims=keepdims)
@@ -1222,11 +1226,16 @@ def _norm_tuple_axis(x, ord, row_axis, col_axis, keepdims):
12221226
12231227
"""
12241228

1229+
# pylint: disable=too-many-branches
12251230
axis = (row_axis, col_axis)
12261231
flag = x.shape[row_axis] == 0 or x.shape[col_axis] == 0
12271232
if flag and ord in [1, 2, dpnp.inf]:
12281233
x = dpnp.moveaxis(x, axis, (-2, -1))
1229-
return dpnp.zeros_like(x, shape=x.shape[:-2])
1234+
res_shape = x.shape[:-2]
1235+
result = dpnp.zeros_like(x, shape=res_shape)
1236+
if keepdims:
1237+
result = result.reshape(res_shape + (1, 1))
1238+
return result
12301239
if row_axis == col_axis:
12311240
raise ValueError("Duplicate axes given.")
12321241
if ord == 2:

dpnp/tests/test_linalg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
get_integer_float_dtypes,
2424
has_support_aspect64,
2525
is_cpu_device,
26-
is_cuda_device,
2726
numpy_version,
2827
requires_intel_mkl_version,
2928
)
@@ -2104,11 +2103,12 @@ def test_empty(self, shape, ord, axis, keepdims):
21042103
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
21052104
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
21062105
else:
2107-
# TODO: when similar changes in numpy are available, instead
2108-
# of assert_equal with zero, we should compare with numpy
2109-
# ord in [None, 1, 2]
2110-
assert_equal(dpnp.linalg.norm(ia, **kwarg), 0.0)
2111-
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
2106+
if numpy_version() >= "2.3.0":
2107+
result = dpnp.linalg.norm(ia, **kwarg)
2108+
expected = numpy.linalg.norm(a, **kwarg)
2109+
assert_dtype_allclose(result, expected)
2110+
else:
2111+
assert_equal(dpnp.linalg.norm(ia, **kwarg), 0.0)
21122112
else:
21132113
result = dpnp.linalg.norm(ia, **kwarg)
21142114
expected = numpy.linalg.norm(a, **kwarg)

dpnp/tests/test_product.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
assert_dtype_allclose,
1313
generate_random_numpy_array,
1414
get_all_dtypes,
15-
get_complex_dtypes,
16-
is_win_platform,
1715
numpy_version,
1816
)
1917
from .third_party.cupy import testing
@@ -845,6 +843,8 @@ def test_dtype_matrix(self, dt_in1, dt_in2, dt_out, shape1, shape2):
845843
assert_raises(TypeError, dpnp.matmul, ia, ib, out=iout)
846844
assert_raises(TypeError, numpy.matmul, a, b, out=out)
847845

846+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
847+
@testing.with_requires("numpy<2.3")
848848
@pytest.mark.parametrize("dtype", _selected_dtypes)
849849
@pytest.mark.parametrize("order1", ["C", "F", "A"])
850850
@pytest.mark.parametrize("order2", ["C", "F", "A"])
@@ -882,6 +882,8 @@ def test_order(self, dtype, order1, order2, order, shape1, shape2):
882882
assert result.flags.f_contiguous == expected.flags.f_contiguous
883883
assert_dtype_allclose(result, expected)
884884

885+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
886+
@testing.with_requires("numpy<2.3")
885887
@pytest.mark.parametrize("dtype", _selected_dtypes)
886888
@pytest.mark.parametrize(
887889
"stride",

dpnp/tests/testing/array.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def _assert(assert_func, result, expected, *args, **kwargs):
4949
]
5050
# For numpy < 2.0, some tests will fail for dtype mismatch
5151
dev = dpctl.select_default_device()
52-
if numpy.__version__ >= "2.0.0" and dev.has_aspect_fp64:
52+
if (
53+
numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0"
54+
and dev.has_aspect_fp64
55+
):
5356
strict = kwargs.setdefault("strict", True)
5457
if flag:
5558
if strict:

dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def test_unique_inverse(self, xp, dtype, attr):
340340
a = testing.shaped_random((100, 100), xp, dtype)
341341
return getattr(xp.unique_inverse(a), attr)
342342

343-
@testing.with_requires("numpy>=2.0")
343+
# TODO: include numpy-2.3 when dpnp-issue-2476 is addressed
344+
@testing.with_requires("numpy>=2.0", "numpy<2.3")
344345
@testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True)
345346
@testing.numpy_cupy_array_equal()
346347
def test_unique_values(self, xp, dtype):

dpnp/tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def test_cupy_matmul(self, xp, dtype1, dtype2):
9999
)
100100
class TestMatmulOut(unittest.TestCase):
101101

102+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
103+
@testing.with_requires("numpy<2.3")
102104
# no_int8=True is added to avoid overflow
103105
@testing.for_all_dtypes(name="dtype1", no_int8=True)
104106
@testing.for_all_dtypes(name="dtype2", no_int8=True)

0 commit comments

Comments
 (0)