Skip to content

Commit cdc980d

Browse files
committed
feat: guard undefined children behind Nullable
1 parent 938b335 commit cdc980d

File tree

10 files changed

+90
-63
lines changed

10 files changed

+90
-63
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828
import Reexport: @reexport
2929
macro ignore(args...) end
3030

31+
import .UtilsModule: Nullable
3132
import .ValueInterfaceModule:
3233
is_valid,
3334
is_valid_array,

src/Interfaces.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module InterfacesModule
33

44
using Interfaces: Interfaces, @interface, @implements, Arguments
55
using DispatchDoctor: @unstable
6+
using ..UtilsModule: Nullable
67
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
78
using ..NodeModule:
89
Node,
@@ -14,6 +15,7 @@ using ..NodeModule:
1415
with_type_parameters,
1516
with_max_degree,
1617
max_degree,
18+
unsafe_get_children,
1719
get_children,
1820
leaf_copy,
1921
leaf_convert,
@@ -228,9 +230,9 @@ function _check_create_node(tree::AbstractExpressionNode)
228230
end
229231
function _check_get_children(tree::AbstractExpressionNode{T,D}) where {T,D}
230232
tree.degree == 0 && return true
231-
return get_children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} &&
232-
get_children(tree, Val(D)) isa Tuple &&
233-
length(get_children(tree, Val(D))) == D &&
233+
return unsafe_get_children(tree) isa NTuple{D,Nullable{typeof(tree)}} &&
234+
get_children(tree, tree.degree) isa Tuple &&
235+
length(get_children(tree, tree.degree)) == tree.degree &&
234236
length(get_children(tree, Val(1))) == 1
235237
end
236238
function _check_copy(tree::AbstractExpressionNode)

src/Node.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module NodeModule
22

33
using DispatchDoctor: @unstable
44

5-
import ..UtilsModule: deprecate_varmap, Undefined
5+
import ..UtilsModule: deprecate_varmap, Undefined, Nullable
66

77
const DEFAULT_NODE_TYPE = Float32
88
const DEFAULT_MAX_DEGREE = 2
@@ -78,7 +78,7 @@ for N in (:Node, :GraphNode)
7878
val::T # If is a constant, this stores the actual value
7979
feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
8080
op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
81-
children::NTuple{D,$N{T,D}}
81+
children::NTuple{D,Nullable{$N{T,D}}}
8282

8383
#################
8484
## Constructors:
@@ -173,13 +173,7 @@ Accessing this node should trigger some kind of noticable error
173173
(e.g., default returns itself, which causes infinite recursion).
174174
"""
175175
function poison_node(n::AbstractNode)
176-
# We don't want to use `nothing` because the type instability
177-
# hits memory hard.
178-
# Setting itself as the right child is the best thing,
179-
# because it (1) doesn't allocate, and (2) will trigger
180-
# infinite recursion errors if someone is mistakenly trying
181-
# to access the right child when `.degree == 1`.
182-
return n
176+
return Nullable(true, n)
183177
end
184178

185179
"""
@@ -190,7 +184,7 @@ children may be "poisoned" nodes which you should not access,
190184
as they will trigger infinite recursion errors. Ensure to
191185
only access children only up to the `.degree` of the node.
192186
"""
193-
@inline function get_children(node::AbstractNode)
187+
@inline function unsafe_get_children(node::AbstractNode)
194188
return getfield(node, :children)
195189
end
196190

@@ -207,8 +201,8 @@ for total type stability.
207201
return get_children(node, Val(n))
208202
end
209203
@inline function get_children(node::AbstractNode{D}, ::Val{n}) where {D,n}
210-
cs = get_children(node)
211-
return ntuple(i -> cs[i], Val(n))
204+
cs = unsafe_get_children(node)
205+
return ntuple(i -> cs[i][], Val(n))
212206
end
213207

214208
"""
@@ -217,7 +211,7 @@ end
217211
Return the `i`-th child of a node (1-indexed).
218212
"""
219213
@inline function get_child(n::AbstractNode{D}, i::Int) where {D}
220-
return get_children(n)[i]
214+
return unsafe_get_children(n)[i][]
221215
end
222216

223217
"""
@@ -227,7 +221,7 @@ Replace the `i`-th child of a node (1-indexed) with the given child node.
227221
Returns the new child. Updates the children tuple in-place.
228222
"""
229223
@inline function set_child!(n::AbstractNode{D}, child::AbstractNode{D}, i::Int) where {D}
230-
set_children!(n, Base.setindex(get_children(n), child, i))
224+
set_children!(n, Base.setindex(unsafe_get_children(n), Nullable(false, child), i))
231225
return child
232226
end
233227

@@ -242,17 +236,21 @@ provided than the node's maximum degree, remaining slots are filled with poison
242236
) where {D}
243237
D2 = length(children)
244238
if D === D2
245-
n.children = children
239+
n.children = ntuple(i -> _ensure_nullable(@inbounds(children[i])), Val(D))
246240
else
247241
poison = poison_node(n)
248242
# We insert poison at the end of the tuple so that
249243
# errors will appear loudly if accessed.
250244
# This poison should be efficient to insert. So
251-
# for simplicity, we can just use poison == n, which
252-
# will trigger infinite recursion errors if accessed.
253-
n.children = ntuple(i -> i <= D2 ? @inbounds(children[i]) : poison, Val(D))
245+
# for simplicity, we can just use poison := Nullable(true, n)
246+
# which will raise an UndefRefError if accessed.
247+
n.children = ntuple(
248+
i -> i <= D2 ? _ensure_nullable(@inbounds(children[i])) : poison, Val(D)
249+
)
254250
end
255251
end
252+
@inline _ensure_nullable(x) = Nullable(false, x)
253+
@inline _ensure_nullable(x::Nullable) = x
256254

257255
macro make_accessors(node_type)
258256
esc(
@@ -491,7 +489,7 @@ function set_leaf!(tree::AbstractExpressionNode, new_leaf::AbstractExpressionNod
491489
end
492490
function set_branch!(tree::AbstractExpressionNode, new_branch::AbstractExpressionNode)
493491
tree.op = new_branch.op
494-
set_children!(tree, get_children(new_branch))
492+
set_children!(tree, unsafe_get_children(new_branch))
495493
return nothing
496494
end
497495

src/NodePreallocation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
5656
end
5757
# COV_EXCL_STOP
5858
function branch_copy_into!(
59-
dest::N, src::N, children::Vararg{N,M}
59+
dest::N, src::N, children::Vararg{Any,M}
6060
) where {T,D,N<:AbstractExpressionNode{T,D},M}
6161
dest.degree = M
6262
dest.op = src.op

src/NodeUtils.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module NodeUtilsModule
22

3+
using ..UtilsModule: Nullable
4+
35
import ..NodeModule:
46
AbstractNode,
57
AbstractExpressionNode,
@@ -147,7 +149,7 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D}
147149
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
148150
val::T # If is a constant, this stores the actual value
149151
# ------------------- (possibly undefined below)
150-
children::NTuple{D,NodeIndex{T,D}}
152+
children::NTuple{D,Nullable{NodeIndex{T,D}}}
151153

152154
function NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D}
153155
return new{_T,_D}(0, convert(_T, val))
@@ -166,10 +168,10 @@ NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T))
166168

167169
@inline function Base.getproperty(n::NodeIndex, k::Symbol)
168170
if k == :l
169-
# TODO: Should a depwarn be raised here? Or too slow?
170-
return getfield(n, :children)[1]
171+
# TODO: Should a deprecation warning be raised here? Or too slow?
172+
return getfield(n, :children)[1][]
171173
elseif k == :r
172-
return getfield(n, :children)[2]
174+
return getfield(n, :children)[2][]
173175
else
174176
return getfield(n, k)
175177
end

src/ParametricExpression.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
88
using ..ExpressionModule:
99
AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata
1010
using ..ChainRulesModule: NodeTangent
11+
using ..UtilsModule: Nullable
1112

1213
import ..NodeModule:
1314
constructorof,
@@ -58,7 +59,7 @@ mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D}
5859
parameter::UInt16 # Stores index of per-class parameter
5960

6061
op::UInt8
61-
children::NTuple{D,ParametricNode{T,D}} # Children nodes
62+
children::NTuple{D,Nullable{ParametricNode{T,D}}} # Children nodes
6263

6364
function ParametricNode{_T,_D}() where {_T,_D}
6465
n = new{_T,_D}()

src/Utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,16 @@ struct ResultOk2{A<:AbstractArray,B<:AbstractArray}
5858
ok::Bool
5959
end
6060

61+
struct Nullable{T}
62+
null::Bool
63+
x::T
64+
end
65+
@inline function Base.getindex(x::Nullable)
66+
x.null && throw(UndefRefError())
67+
return x.x
68+
end
69+
@inline function Base.convert(::Type{Nullable{T}}, x::Nullable) where {T}
70+
return Nullable(x.null, convert(T, x.x))
71+
end
72+
6173
end

test/test_custom_node_type.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
using DynamicExpressions
2+
using DynamicExpressions: Nullable
23
using Test
34

45
mutable struct MyCustomNode{A,B} <: AbstractNode{2}
56
degree::Int
67
val1::A
78
val2::B
8-
children::NTuple{2,MyCustomNode{A,B}}
9+
children::NTuple{2,Nullable{MyCustomNode{A,B}}}
910

1011
MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2)
1112
function MyCustomNode(val1, val2, l)
1213
n = MyCustomNode(val1, val2)
1314
poison = n
1415
n.degree = 1
15-
n.children = (l, poison)
16+
set_children!(n, (l, poison))
1617
return n
1718
end
1819
function MyCustomNode(val1, val2, l, r)
19-
return new{typeof(val1),typeof(val2)}(2, val1, val2, (l, r))
20+
n = new{typeof(val1),typeof(val2)}(2, val1, val2)
21+
set_children!(n, (l, r))
22+
return n
2023
end
2124
end
2225

@@ -31,7 +34,7 @@ node2 = MyCustomNode(1.5, 3, node1)
3134

3235
@test typeof(node2) == MyCustomNode{Float64,Int}
3336
@test node2.degree == 1
34-
@test node2.children[1].degree == 0
37+
@test get_child(node2, 1).degree == 0
3538
@test count_depth(node2) == 2
3639
@test count_nodes(node2) == 2
3740

@@ -50,7 +53,7 @@ mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2}
5053
val::T
5154
feature::UInt16
5255
op::UInt8
53-
children::NTuple{2,Base.RefValue{MyCustomNode2{T}}}
56+
children::NTuple{2,Nullable{MyCustomNode2{T}}}
5457
end
5558

5659
@test_throws ErrorException MyCustomNode2()

test/test_extra_node_fields.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using Test
44
using DynamicExpressions
5-
using DynamicExpressions: constructorof, max_degree
5+
using DynamicExpressions: Nullable, constructorof, max_degree
66

77
mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D}
88
degree::UInt8
@@ -11,7 +11,7 @@ mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D}
1111
frozen::Bool # Extra field!
1212
feature::UInt16
1313
op::UInt8
14-
children::NTuple{D,FrozenNode{T,D}}
14+
children::NTuple{D,Nullable{FrozenNode{T,D}}}
1515

1616
function FrozenNode{_T,_D}() where {_T,_D}
1717
n = new{_T,_D}()
@@ -104,5 +104,5 @@ ex = parse_expression(
104104

105105
@test string_tree(ex) == "x + sin(y + 2.1)"
106106
@test ex.tree.frozen == false
107-
@test ex.tree.children[2].frozen == true
108-
@test ex.tree.children[2].children[1].frozen == false
107+
@test ex.tree.children[2][].frozen == true
108+
@test ex.tree.children[2][].children[1][].frozen == false

0 commit comments

Comments
 (0)