Skip to content

Commit e6486c4

Browse files
committed
Numpy 2.0 compatibility
Numpy 2.0 changes scalar repr as per: https://numpy.org/neps/nep-0051-scalar-representation.html Make matplotlib_export convert to python scalars where appropriate to avoid 'np.*' in generated code.
1 parent e75bdf8 commit e6486c4

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

orangewidget/tests/test_matplotlib_export.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import numpy as np
12
import pyqtgraph as pg
23

34
from orangewidget.tests.base import GuiTest
4-
from orangewidget.utils.matplotlib_export import scatterplot_code
5+
from orangewidget.utils.matplotlib_export import (
6+
scatterplot_code, numpy_repr, compress_if_all_same, numpy_repr_int
7+
)
58

69

710
def add_intro(a):
@@ -15,8 +18,24 @@ class TestScatterPlot(GuiTest):
1518
def test_scatterplot_simple(self):
1619
plotWidget = pg.PlotWidget(background="w")
1720
scatterplot = pg.ScatterPlotItem()
18-
scatterplot.setData(x=[1, 2, 3], y=[3, 2, 1])
21+
scatterplot.setData(
22+
x=np.array([1., 2, 3]),
23+
y=np.array([3., 2, 1]),
24+
size=np.array([1., 1, 1])
25+
)
1926
plotWidget.addItem(scatterplot)
2027
code = scatterplot_code(scatterplot)
2128
self.assertIn("plt.scatter", code)
2229
exec(add_intro(code), {})
30+
31+
def test_utils(self):
32+
a = np.array([1.5, 2.5])
33+
self.assertIn("1.5, 2.5", numpy_repr(a))
34+
a = np.array([1, 1])
35+
v = compress_if_all_same(a)
36+
self.assertEqual(v, 1)
37+
self.assertEqual(repr(v), "1")
38+
self.assertIs(type(v), int)
39+
a = np.array([1, 2], dtype=int)
40+
v = numpy_repr_int(a)
41+
self.assertIn("1, 2", v)

orangewidget/utils/matplotlib_export.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def numpy_repr(a):
1414
# avoid numpy repr as it changes between versions
1515
# TODO handle numpy repr differences
1616
if isinstance(a, np.ndarray):
17-
return "array(" + repr(list(a)) + ")"
17+
return "array(" + repr(a.tolist()) + ")"
1818
try:
1919
np.set_printoptions(threshold=10**10)
2020
return repr(a)
@@ -25,12 +25,20 @@ def numpy_repr(a):
2525
def numpy_repr_int(a):
2626
# avoid numpy repr as it changes between versions
2727
# TODO handle numpy repr differences
28-
return "array(" + repr(list(a)) + ", dtype='int')"
28+
if isinstance(a, np.ndarray):
29+
a = a.tolist()
30+
else:
31+
a = list(a)
32+
return "array(" + repr(a) + ", dtype='int')"
2933

3034

3135
def compress_if_all_same(l):
3236
s = set(l)
33-
return s.pop() if len(s) == 1 else l
37+
if len(s) == 1:
38+
v = s.pop()
39+
return v.item() if isinstance(v, np.generic) else v
40+
else:
41+
return l
3442

3543

3644
def is_sequence_not_string(a):
@@ -188,6 +196,7 @@ def scene_code(scene):
188196
code = []
189197

190198
code.append("import matplotlib.pyplot as plt")
199+
code.append("import numpy as np")
191200
code.append("from numpy import array")
192201

193202
code.append("")

0 commit comments

Comments
 (0)