Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to the latest broadcast implement. #284

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 49 additions & 35 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,33 +497,53 @@ end
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
using Base.Broadcast: combine_styles

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
@static if fieldcount(Base.Broadcast.Broadcasted) == 4
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
style::S
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style)
end
StructArrayStyle{N}(style::StructArrayStyle) where {N} = StructArrayStyle{N}(style.style)
parent_style(s::BroadcastStyle) = s
parent_style(s::StructArrayStyle) = s.style
style(bc::Broadcasted) = bc.style
const broadcasted = Broadcasted
else
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}()
end
StructArrayStyle{N}(style::StructArrayStyle{M, S}) where {N, M, S} = StructArrayStyle{N}(S())
parent_style(s::BroadcastStyle) = s
parent_style(::StructArrayStyle{N, S}) where {N, S} = S()
style(::Broadcasted{Style}) where {Style} = Style()
broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes)
end
StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S())
parent_style(bc::Broadcasted) = parent_style(style(bc))
ofstyle(s, bc::Broadcasted) = broadcasted(s, bc.f, bc.args, bc.axes)

# Here we define the dimension tracking behavior of StructArrayStyle
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
function StructArrayStyle{M, S}(::Val{N}) where {S, M, N}
T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
return StructArrayStyle{T, N}()
return StructArrayStyle{N, T}()
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{N, S}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
return StructArrayStyle{N′}(Broadcast.result_style(parent_style(a), b))
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{ndims(SA)}(cst(SA))

"""
always_struct_broadcast(style::BroadcastStyle)
Expand Down Expand Up @@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
"""
try_struct_copy(bc::Broadcasted) = copy(bc)

function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
if always_struct_broadcast(S())
function Base.copy(bc::Broadcasted{<:StructArrayStyle})
if always_struct_broadcast(parent_style(bc))
return invoke(copy, Tuple{Broadcasted}, bc)
else
return try_struct_copy(replace_structarray(bc))
Expand All @@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
function replace_structarray(bc::Broadcasted)
args = replace_structarray_args(bc.args)
Style′ = parent_style(Style())
return Broadcasted{Style′}(bc.f, args, bc.axes)
style = parent_style(bc)
return broadcasted(style, bc.f, args, bc.axes)
end
function replace_structarray(A::StructArray)
f = Instantiator(eltype(A))
args = Tuple(components(A))
Style = typeof(combine_styles(args...))
return Broadcasted{Style}(f, args, axes(A))
style = combine_styles(args...)
return broadcasted(style, f, args, axes(A))
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

parent_style(@nospecialize(x)) = typeof(x)
parent_style(::StructArrayStyle{S, N}) where {S, N} = S
parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))

# `instantiate` and `_axes` might be overloaded for static axes.
function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc))
return convert(Broadcasted{Style}, bc′)
function Broadcast.instantiate(bc::Broadcasted{<:StructArrayStyle})
bc′ = Broadcast.instantiate(ofstyle(parent_style(bc), bc))
return ofstyle(style(bc), bc′)
end

function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing)
function Broadcast._axes(bc::Broadcasted{<:StructArrayStyle}, ::Nothing)
return Broadcast._axes(ofstyle(parent_style(bc), bc), nothing)
end

# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
function Base.similar(bc::Broadcasted{<:StructArrayStyle}, ::Type{ElType}) where {ElType}
bc′ = ofstyle(parent_style(bc), bc)
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle})
ps = parent_style(bc)
bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc)
return copyto!(dest, bc′)
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
return Broadcast.materialize!(S(), dest, bc′)
@inline function Broadcast.materialize!(s::StructArrayStyle, dest, bc::Broadcasted)
ps = parent_style(s)
bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc)
return Broadcast.materialize!(ps, dest, bc′)
end

# for aliasing analysis during broadcast
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
ares = map(a->a.re, as)
aims = map(a->a.im, as)
style = Broadcast.combine_styles(ares...)
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
@test Broadcast.combine_styles(as...) === StructArrayStyle{1,typeof(style)}()
if !(style in tested_style)
push!(tested_style, style)
if style isa Broadcast.ArrayStyle{MyArray3}
Expand All @@ -1249,8 +1249,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

#parent_style
@test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2}
@test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple}
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.DefaultArrayStyle{0}}()) == Broadcast.DefaultArrayStyle{0}()
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.Style{Tuple}}()) == Broadcast.Style{Tuple}()

# allocation test for overloaded `broadcast_unalias`
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false
Expand Down