Skip to content

Commit 470635c

Browse files
committed
Numpy 2.0 compatibility
Numpy 2.0 changes scalar repr as per: https://numpy.org/neps/nep-0051-scalar-representation.html Make matplotlib_export convert to python scalars where appropriate to avoid 'np.*' in generated code.
1 parent 1d621d0 commit 470635c

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed
Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import numpy as np
12
import pyqtgraph as pg
23

34
from orangewidget.tests.base import GuiTest
4-
from orangewidget.utils.matplotlib_export import scatterplot_code
5+
from orangewidget.utils.matplotlib_export import (
6+
scatterplot_code, numpy_repr, compress_if_all_same, numpy_repr_int
7+
)
58

69

710
def add_intro(a):
811
r = "import matplotlib.pyplot as plt\n" + \
12+
"import numpy as np\n" + \
913
"from numpy import array\n" + \
1014
"plt.clf()"
1115
return r + a
@@ -15,8 +19,24 @@ class TestScatterPlot(GuiTest):
1519
def test_scatterplot_simple(self):
1620
plotWidget = pg.PlotWidget(background="w")
1721
scatterplot = pg.ScatterPlotItem()
18-
scatterplot.setData(x=[1, 2, 3], y=[3, 2, 1])
22+
scatterplot.setData(
23+
x=np.array([1., 2, 3]),
24+
y=np.array([3., 2, 1]),
25+
size=np.array([1., 1, 1])
26+
)
1927
plotWidget.addItem(scatterplot)
2028
code = scatterplot_code(scatterplot)
2129
self.assertIn("plt.scatter", code)
2230
exec(add_intro(code), {})
31+
32+
def test_utils(self):
33+
a = np.array([1.5, 2.5])
34+
self.assertIn("1.5, 2.5", numpy_repr(a))
35+
a = np.array([1, 1])
36+
v = compress_if_all_same(a)
37+
self.assertEqual(v, 1)
38+
self.assertEqual(repr(v), "1")
39+
self.assertIs(type(v), int)
40+
a = np.array([1, 2], dtype=int)
41+
v = numpy_repr_int(a)
42+
self.assertIn("1, 2", v)

orangewidget/utils/matplotlib_export.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def numpy_repr(a):
1414
# avoid numpy repr as it changes between versions
1515
# TODO handle numpy repr differences
1616
if isinstance(a, np.ndarray):
17-
return "array(" + repr(list(a)) + ")"
17+
return "array(" + repr(a.tolist()) + ")"
1818
try:
1919
np.set_printoptions(threshold=10**10)
2020
return repr(a)
@@ -25,12 +25,16 @@ def numpy_repr(a):
2525
def numpy_repr_int(a):
2626
# avoid numpy repr as it changes between versions
2727
# TODO handle numpy repr differences
28-
return "array(" + repr(list(a)) + ", dtype='int')"
28+
return "array(" + repr(a.tolist()) + ", dtype='int')"
2929

3030

3131
def compress_if_all_same(l):
3232
s = set(l)
33-
return s.pop() if len(s) == 1 else l
33+
if len(s) == 1:
34+
v = s.pop()
35+
return v.item() if isinstance(v, np.generic) else v
36+
else:
37+
return l
3438

3539

3640
def is_sequence_not_string(a):

0 commit comments

Comments
 (0)