Skip to content

Commit b458523

Browse files
feat: copy_behaviors to make sub-classing easy (#3137)
* feat: copy_behaviors to make sub-classing easy * pylint errors and tests * make 'copy_behaviors' safer by making it immutable --------- Co-authored-by: Jim Pivarski <[email protected]> Co-authored-by: Jim Pivarski <[email protected]>
1 parent a1da072 commit b458523

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

src/awkward/_util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import struct
88
import sys
9+
import typing
910
from collections.abc import Collection
1011

1112
import numpy as np # noqa: TID251
@@ -102,3 +103,18 @@ def unique_list(items: Collection[T]) -> list[T]:
102103
seen.add(item)
103104
result.append(item)
104105
return result
106+
107+
108+
def copy_behaviors(existing_class: typing.Any, new_class: typing.Any, behavior: dict):
109+
output = {}
110+
111+
oldname = existing_class.__name__
112+
newname = new_class.__name__
113+
114+
for key, value in behavior.items():
115+
if oldname in key:
116+
if not isinstance(key, str) and "*" not in key:
117+
new_tuple = tuple(newname if k == oldname else k for k in key)
118+
output[new_tuple] = value
119+
120+
return output

tests/test_2433_copy_behaviors.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import numpy
6+
import pytest
7+
8+
import awkward as ak
9+
10+
11+
def test():
12+
class SuperVector:
13+
def add(self, other):
14+
"""Add two vectors together elementwise using `x` and `y` components"""
15+
return ak.zip(
16+
{"x": self.x + other.x, "y": self.y + other.y},
17+
with_name="VectorTwoD",
18+
behavior=self.behavior,
19+
)
20+
21+
# first sub-class
22+
@ak.mixin_class(ak.behavior)
23+
class VectorTwoD(SuperVector):
24+
def __eq__(self, other):
25+
return ak.all(self.x == other.x) and ak.all(self.y == other.y)
26+
27+
v = ak.Array(
28+
[
29+
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
30+
[],
31+
[{"x": 3, "y": 3.3}],
32+
[
33+
{"x": 4, "y": 4.4},
34+
{"x": 5, "y": 5.5},
35+
{"x": 6, "y": 6.6},
36+
],
37+
],
38+
with_name="VectorTwoD",
39+
behavior=ak.behavior,
40+
)
41+
v_added = ak.Array(
42+
[
43+
[{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}],
44+
[],
45+
[{"x": 6, "y": 6.6}],
46+
[
47+
{"x": 8, "y": 8.8},
48+
{"x": 10, "y": 11},
49+
{"x": 12, "y": 13.2},
50+
],
51+
],
52+
with_name="VectorTwoD",
53+
behavior=ak.behavior,
54+
)
55+
56+
# add method works but the binary operator does not
57+
assert v.add(v) == v_added
58+
with pytest.raises(TypeError):
59+
v + v
60+
61+
# registering the operator makes everything work
62+
ak.behavior[numpy.add, "VectorTwoD", "VectorTwoD"] = lambda v1, v2: v1.add(v2)
63+
assert v + v == v_added
64+
65+
# second sub-class
66+
@ak.mixin_class(ak.behavior)
67+
class VectorTwoDAgain(VectorTwoD):
68+
pass
69+
70+
v = ak.Array(
71+
[
72+
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
73+
[],
74+
[{"x": 3, "y": 3.3}],
75+
[
76+
{"x": 4, "y": 4.4},
77+
{"x": 5, "y": 5.5},
78+
{"x": 6, "y": 6.6},
79+
],
80+
],
81+
with_name="VectorTwoDAgain",
82+
behavior=ak.behavior,
83+
)
84+
# add method works but the binary operator does not
85+
assert v.add(v) == v_added
86+
with pytest.raises(TypeError):
87+
v + v
88+
89+
# instead of registering every operator again, just copy the behaviors of
90+
# another class to this class
91+
ak.behavior.update(
92+
ak._util.copy_behaviors(VectorTwoD, VectorTwoDAgain, ak.behavior)
93+
)
94+
assert v + v == v_added
95+
96+
# third sub-class
97+
@ak.mixin_class(ak.behavior)
98+
class VectorTwoDAgainAgain(VectorTwoDAgain):
99+
pass
100+
101+
v = ak.Array(
102+
[
103+
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
104+
[],
105+
[{"x": 3, "y": 3.3}],
106+
[
107+
{"x": 4, "y": 4.4},
108+
{"x": 5, "y": 5.5},
109+
{"x": 6, "y": 6.6},
110+
],
111+
],
112+
with_name="VectorTwoDAgainAgain",
113+
behavior=ak.behavior,
114+
)
115+
# add method works but the binary operator does not
116+
assert v.add(v) == v_added
117+
with pytest.raises(TypeError):
118+
v + v
119+
120+
# instead of registering every operator again, just copy the behaviors of
121+
# another class to this class
122+
ak.behavior.update(
123+
ak._util.copy_behaviors(VectorTwoDAgain, VectorTwoDAgainAgain, ak.behavior)
124+
)
125+
assert v + v == v_added

0 commit comments

Comments
 (0)