From ec06e4c663edbe102a3228dcd3ec68dcfc6e76e7 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:36:50 +0530 Subject: [PATCH 01/11] Add a BroadcastStyle for AbstractFill --- src/FillArrays.jl | 3 +- src/fillbroadcast.jl | 197 +++++++++++++++++++++++++------------------ test/runtests.jl | 46 ++++++++-- 3 files changed, 152 insertions(+), 94 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 66dad480..633baef9 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted, add_sum, accumulate, OneTo, permutedims + parent, similar, issorted, add_sum, accumulate, OneTo, permutedims, + real, imag, conj import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59c..426420f9 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -73,22 +73,79 @@ function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray... end -### Unary broadcasting +## BroadcastStyle + +abstract type AbstractFillStyle{N} <: Broadcast.AbstractArrayStyle{N} end +struct FillStyle{N} <: AbstractFillStyle{N} end +struct ZerosStyle{N} <: AbstractFillStyle{N} end +FillStyle{N}(::Val{M}) where {N,M} = FillStyle{M}() +ZerosStyle{N}(::Val{M}) where {N,M} = ZerosStyle{M}() +Broadcast.BroadcastStyle(::Type{<:AbstractFill{<:Any,N}}) where {N} = FillStyle{N}() +Broadcast.BroadcastStyle(::Type{<:AbstractZeros{<:Any,N}}) where {N} = ZerosStyle{N}() +Broadcast.BroadcastStyle(::FillStyle{M}, ::ZerosStyle{N}) where {M,N} = FillStyle{max(M,N)}() +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S + +_getindex_value(f::AbstractFill) = getindex_value(f) +_getindex_value(x::Number) = x +_getindex_value(x::Ref) = x[] +function _getindex_value(bc::Broadcast.Broadcasted) + bc.f(map(_getindex_value, bc.args)...) +end + +has_static_value(x) = false +has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true +has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) -function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} - return Fill(op(getindex_value(r)), axes(r)) +function _iszeros(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) end +# conservative check for zeros. In most cases we can't really compare with zero +_iszero(x::Union{Number, AbstractArray}) = iszero(x) +_iszero(_) = false -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(-), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractOnes) = r +function _isones(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _isone(_getindex_value(bc)) +end +# conservative check for ones. In most cases we can't really compare with one +_isone(x::Union{Number, AbstractArray}) = isone(x) +_isone(_) = false + +_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args) +_isfill(f::AbstractFill) = true +_isfill(f::Number) = true +_isfill(f::Ref) = true +_isfill(::Any) = false + +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} + if _iszeros(bc) + return Zeros(typeof(_getindex_value(bc)), axes(bc)) + elseif _isones(bc) + return Ones(typeof(_getindex_value(bc)), axes(bc)) + elseif _isfill(bc) + return Fill(_getindex_value(bc), axes(bc)) + else + # fallback style + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}} + copy(convert(S, bc)) + end +end +# make the zero-dimensional case consistent with Base +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} + copy(convert(S, bc)) +end -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractZeros{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractOnes{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractOnes{T,N}) where {T,N} = Ones{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractOnes{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) +# some cases that preserve 0d +function broadcast_preserving_0d(f, As...) + bc = Base.broadcasted(f, As...) + r = copy(bc) + length(axes(bc)) == 0 ? Fill(r) : r +end +for f in (:real, :imag, :conj) + @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) +end ### Binary broadcasting @@ -100,12 +157,6 @@ broadcasted_zeros(f, a, b, elt, ax) = Zeros{elt}(ax) broadcasted_ones(f, a, elt, ax) = Ones{elt}(ax) broadcasted_ones(f, a, b, elt, ax) = Ones{elt}(ax) -function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill) - val = op(getindex_value(a), getindex_value(b)) - ax = broadcast_shape(axes(a), axes(b)) - return broadcasted_fill(op, a, b, val, ax) -end - function _broadcasted_zeros(f, a, b) elt = Base.Broadcast.combine_eltypes(f, (a, b)) ax = broadcast_shape(axes(a), axes(b)) @@ -122,57 +173,40 @@ function _broadcasted_nan(f, a, b) return broadcasted_fill(f, a, b, val, ax) end -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractOnes) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractOnes) = _broadcasted_zeros(-, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractOnesVector) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(-, a, b) - - -broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) - # In following, need to restrict to <: Number as otherwise we cannot infer zero from type # TODO: generalise to things like SVector for op in (:*, :/) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) end end for op in (:*, :\) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) end end +broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) +broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b) +broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b) -for op in (:*, :/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) -end +# for op in (:*, :/, :\) +# @eval broadcasted(::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) +# end -for op in (:/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) -end +# for op in (:/, :\) +# @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) +# end # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a @@ -205,13 +239,13 @@ _range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{ # end # end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractOnesVector, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractOnes, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnesVector) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnes) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, a) @@ -219,51 +253,46 @@ end for op in (:+, :-) @eval begin - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::AbstractZerosVector) - broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(a) || throw(ArgumentError("cannot broadcast an array with size $(size(a)) with $b")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) # Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)` eltype(a) === TT ? a : broadcasted(TT ∘ (+), a) end - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractVector) - broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(b) || throw(ArgumentError("cannot broadcast $a with an array with size $(size(b))")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) $op === (+) && eltype(b) === TT ? b : broadcasted(TT ∘ ($op), b) end - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::AbstractZerosVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractFillVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) + Zeros(TT, ax) + end end end # Need to prevent array-valued fills from broadcasting over entry -_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a) -_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a)) - +_mayberef(x) = Ref(x) +_mayberef(x::Number) = x -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, _broadcast_getindex_value(a), b) + return broadcasted(*, _mayberef(getindex_value(a)), b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, a, _broadcast_getindex_value(b)) + return broadcasted(*, a, _mayberef(getindex_value(b))) end -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x, getindex_value(r)), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r)) - # support AbstractFill .^ k -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractFill{T,N}, ::Val{k}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractOnes{T,N}, ::Val{k}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{0}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{k}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) # supports structured broadcast if isdefined(LinearAlgebra, :fzero) diff --git a/test/runtests.jl b/test/runtests.jl index 088089d4..b1741ae6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -821,7 +821,7 @@ end @testset "maximum/minimum/svd/sort" begin @test maximum(Fill(1, 1_000_000_000)) == minimum(Fill(1, 1_000_000_000)) == 1 @test svdvals(fill(2,5,6)) ≈ svdvals(Fill(2,5,6)) - @test svdvals(Eye(5)) === Fill(1.0,5) + @test svdvals(Eye(5)) === Ones(5) @test sort(Ones(5)) == sort!(Ones(5)) @test_throws MethodError issorted(Fill(im, 2)) @@ -957,21 +957,21 @@ end rng = MersenneTwister(123456) sizes = [(5, 4), (5, 1), (1, 4), (1, 1), (5,)] - for sx in sizes, sy in sizes + @testset for sx in sizes, sy in sizes x, y = Fill(randn(rng), sx), Fill(randn(rng), sy) x_one, y_one = Ones(sx), Ones(sy) x_zero, y_zero = Zeros(sx), Zeros(sy) x_dense, y_dense = randn(rng, sx), randn(rng, sy) for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .+ y == collect(x) .+ collect(y) + @test x .+ y ≈ collect(x) .+ collect(y) end @test x_zero .+ y_zero isa Zeros @test x_zero .+ y_one isa Ones @test x_one .+ y_zero isa Ones for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .* y == collect(x) .* collect(y) + @test x .* y ≈ collect(x) .* collect(y) end for x in [x, x_one, x_zero, x_dense] @test x .* y_zero isa Zeros @@ -1091,7 +1091,7 @@ end @test_throws DimensionMismatch Zeros{Int}(2) .+ (1:5) @test_throws DimensionMismatch (1:5) .+ Zeros{Int}(2) - for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) + @testset "$(typeof(v)) $T" for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) TT = eltype(v + zeros(T, 5)) S = v isa SVector ? SVector{5,TT} : Vector{TT} @@ -1144,7 +1144,7 @@ end @testset "issue #208" begin TS = (Bool, Int, Float32, Float64) - for S in TS, T in TS + @testset for S in TS, T in TS u = rand(S, 2) v = Zeros(T, 2) if zero(S) + zero(T) isa S @@ -1177,6 +1177,32 @@ end end end end + + @testset "Zeros to Fill" begin + @test @inferred((f -> ((x -> (1,)).(f)))((Zeros(4)))) == Fill((1,), 4) + @test @inferred((f -> ((x -> Val(1)).(f)))((Zeros(4)))) == Fill(Val(1), 4) + end + + @testset "multi-element broadcast" begin + x = Fill(2, 2) + y = @. 2 * x * 2 + @test y === Fill(8, 2) + end + + @testset "nested broadcast" begin + bc = Broadcast.broadcasted(*, Zeros(4), Ones(4), Broadcast.broadcasted(*, Zeros(4), Ones(4), Zeros(4))) + @test copy(bc) === Zeros(4) + end + + @testset "0d" begin + @test real.(Fill(2)) == real.(fill(2)) + end + + @testset "preserve 0d" begin + @test real(Fill(4 + 5im)) == real(fill(4 + 5im)) + @test imag(Fill(4 + 5im)) == imag(fill(4 + 5im)) + @test conj(Fill(4 + 5im)) == conj(fill(4 + 5im)) + end end @testset "map" begin @@ -1185,7 +1211,7 @@ end @test map(isone,x1) === Fill(true,5) x0 = Zeros(5) - @test map(exp,x0) === exp.(x0) + @test map(exp,x0) == exp.(x0) x2 = Fill(2,5,3) @test map(exp,x2) === Fill(exp(2),5,3) @@ -2185,8 +2211,10 @@ end @test D - Zeros(5,5) isa Diagonal @test D .+ Zeros(5,5) isa Diagonal @test D .- Zeros(5,5) isa Diagonal - @test D .* Zeros(5,5) isa Diagonal - @test Zeros(5,5) .* D isa Diagonal + @test D .* Zeros(5,5) isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(D, Zeros(5,5)) isa Diagonal + @test Zeros(5,5) .* D isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(Zeros(5,5), D) isa Diagonal @test Zeros(5,5) - D isa Diagonal @test Zeros(5,5) + D isa Diagonal @test Zeros(5,5) .- D isa Diagonal From 684280d28a2dc6ab768881a8ae663eef9ba0f769 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:47:45 +0530 Subject: [PATCH 02/11] specialize real/imag/conj for real arrays --- src/fillbroadcast.jl | 6 ++++++ test/runtests.jl | 22 +++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 426420f9..f777b92d 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -145,6 +145,12 @@ function broadcast_preserving_0d(f, As...) end for f in (:real, :imag, :conj) @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) + @eval ($f)(A::AbstractZeros) = A +end +for T in (:AbstractOnes, :(AbstractFill{<:Real})) + @eval real(A::$T) = A + @eval imag(A::$T) = Zeros{eltype(A)}(axes(A)) + @eval conj(A::$T) = A end ### Binary broadcasting diff --git a/test/runtests.jl b/test/runtests.jl index b1741ae6..59cbba69 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1199,9 +1199,25 @@ end end @testset "preserve 0d" begin - @test real(Fill(4 + 5im)) == real(fill(4 + 5im)) - @test imag(Fill(4 + 5im)) == imag(fill(4 + 5im)) - @test conj(Fill(4 + 5im)) == conj(fill(4 + 5im)) + for f in (real, imag, conj), + (F, A) in ((Fill(4 + 5im), fill(4 + 5im)), + (Zeros{ComplexF64}(), zeros(ComplexF64)), + (Zeros(), zeros()), + (Ones(), ones()), + (Ones{ComplexF64}(), ones(ComplexF64)), + ) + x = f(F) + y = f(A) + @test x == y + @test x isa FillArrays.AbstractFill + if F[] isa Real + if f === imag + @test x isa Zeros + else + @test x isa typeof(F) + end + end + end end end From 2cc6485c31114c39e83cb9fd0b2322d288ec6b3e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:55:24 +0530 Subject: [PATCH 03/11] Binary broadcast test --- src/fillbroadcast.jl | 17 ++++++++++------- test/runtests.jl | 33 ++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index f777b92d..d1fdf4af 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -143,15 +143,18 @@ function broadcast_preserving_0d(f, As...) r = copy(bc) length(axes(bc)) == 0 ? Fill(r) : r end -for f in (:real, :imag, :conj) +for f in (:real, :imag) @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) - @eval ($f)(A::AbstractZeros) = A -end -for T in (:AbstractOnes, :(AbstractFill{<:Real})) - @eval real(A::$T) = A - @eval imag(A::$T) = Zeros{eltype(A)}(axes(A)) - @eval conj(A::$T) = A + @eval ($f)(A::AbstractZeros) = Zeros{real(eltype(A))}(axes(A)) end +conj(A::AbstractFill) = broadcast_preserving_0d(conj, A) +conj(A::AbstractZeros) = A +real(A::AbstractOnes) = Ones{real(eltype(A))}(axes(A)) +imag(A::AbstractOnes) = Zeros{real(eltype(A))}(axes(A)) +conj(A::AbstractOnes) = A +real(A::AbstractFill{<:Real}) = A +imag(A::AbstractFill{<:Real}) = Zeros{eltype(A)}(axes(A)) +conj(A::AbstractFill{<:Real}) = A ### Binary broadcasting diff --git a/test/runtests.jl b/test/runtests.jl index 59cbba69..c6299e3b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1199,26 +1199,41 @@ end end @testset "preserve 0d" begin - for f in (real, imag, conj), - (F, A) in ((Fill(4 + 5im), fill(4 + 5im)), - (Zeros{ComplexF64}(), zeros(ComplexF64)), - (Zeros(), zeros()), - (Ones(), ones()), - (Ones{ComplexF64}(), ones(ComplexF64)), - ) + @testset for f in (real, imag, conj), (F, A) in ( + (Fill(4), fill(4)), + (Fill(4 + 5im), fill(4 + 5im)), + (Fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4))), fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4)))), + (Zeros{ComplexF64}(), zeros(ComplexF64)), + (Zeros(), zeros()), + (Ones(), ones()), + (Ones{ComplexF64}(), ones(ComplexF64)), + ) x = f(F) y = f(A) @test x == y + @test eltype(x) == eltype(y) @test x isa FillArrays.AbstractFill - if F[] isa Real + if F isa Ones if f === imag @test x isa Zeros else - @test x isa typeof(F) + @test x isa Ones + end + end + if F[] isa Real + if f === imag + @test x isa Zeros end end end end + + @testset "issue #40" begin + f(x) = x + g(x, y) = x + F = Fill(1, 2) + @test g.(F, "a") === f.(F) + end end @testset "map" begin From d47ea932effd35a42eaf419e40c74c012fe20571 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:06:09 +0530 Subject: [PATCH 04/11] Bump version to v1.15.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d0d3e609..6a51fba7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.14.0" +version = "1.15.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 23123b22c0f0cf0f15965b9ba67dfa394af3fb92 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:36:50 +0530 Subject: [PATCH 05/11] Specialize scaling by a number --- src/fillalgebra.jl | 7 +++++++ test/runtests.jl | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index f98ae605..c7fbbe72 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -78,6 +78,13 @@ mult_zeros(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = mult_zeros( mult_zeros(a, b) = mult_fill(a, b, mult_axes(a, b)) mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) +# scaling +*(a::AbstractFill, b::Number) = Fill(getindex_value(a) * b, axes(a)) +*(a::Number, b::AbstractFill) = Fill(a * getindex_value(b), axes(b)) +*(a::AbstractZeros, b::Number) = Zeros(typeof(getindex_value(a) * b), axes(a)) +*(a::Number, b::AbstractZeros) = Zeros(typeof(a * getindex_value(b)), axes(b)) + +# matmul *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) diff --git a/test/runtests.jl b/test/runtests.jl index c6299e3b..421b514b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1196,6 +1196,11 @@ end @testset "0d" begin @test real.(Fill(2)) == real.(fill(2)) + @test (@. 2 * Fill(2) * 2) == (@. 2 * fill(2) * 2) + for (F, A) in ((Fill(2), fill(2)), (Zeros(), zeros()), (Ones(), ones())) + @test F * 2 == A * 2 + @test 2 * F == 2 * A + end end @testset "preserve 0d" begin From 0862a598cd430216c7e98c8c196cc420fed7b19e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:38:33 +0530 Subject: [PATCH 06/11] Update comment --- src/fillbroadcast.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index d1fdf4af..6ea61834 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -101,14 +101,14 @@ has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) function _iszeros(bc::Broadcast.Broadcasted) all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) end -# conservative check for zeros. In most cases we can't really compare with zero +# conservative check for zeros. In most cases, there isn't a zero element to compare with _iszero(x::Union{Number, AbstractArray}) = iszero(x) _iszero(_) = false function _isones(bc::Broadcast.Broadcasted) all(has_static_value, bc.args) && _isone(_getindex_value(bc)) end -# conservative check for ones. In most cases we can't really compare with one +# conservative check for ones. In most cases, there isn't a unit element to compare with _isone(x::Union{Number, AbstractArray}) = isone(x) _isone(_) = false From 004b17306f3e5cd305e30fcae8f6c66837e1fdc1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:39:59 +0530 Subject: [PATCH 07/11] Delete commented out code --- src/fillbroadcast.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 6ea61834..02dc892b 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -209,14 +209,6 @@ broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zero broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b) broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b) -# for op in (:*, :/, :\) -# @eval broadcasted(::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) -# end - -# for op in (:/, :\) -# @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) -# end - # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a _range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a)) From 3b0102b5f769c43d41779b6acea55a088b6c0612 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:18:02 +0530 Subject: [PATCH 08/11] Process Fill broadcasting before others --- src/fillbroadcast.jl | 20 ++++++++++++++------ test/runtests.jl | 5 +++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 02dc892b..dd23c057 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -118,7 +118,18 @@ _isfill(f::Number) = true _isfill(f::Ref) = true _isfill(::Any) = false -function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} +_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) +_broadcast_maybecopy(x) = x + +function _fallback_copy(bc) + # treat the fill components + bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) + # fallback style + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}} + copy(convert(S, bc2)) +end + +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) if _iszeros(bc) return Zeros(typeof(_getindex_value(bc)), axes(bc)) elseif _isones(bc) @@ -126,15 +137,12 @@ function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} elseif _isfill(bc) return Fill(_getindex_value(bc), axes(bc)) else - # fallback style - S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}} - copy(convert(S, bc)) + _fallback_copy(bc) end end # make the zero-dimensional case consistent with Base function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) - S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} - copy(convert(S, bc)) + _fallback_copy(bc) end # some cases that preserve 0d diff --git a/test/runtests.jl b/test/runtests.jl index 421b514b..c3c36985 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1239,6 +1239,11 @@ end F = Fill(1, 2) @test g.(F, "a") === f.(F) end + + @testset "early binding" begin + A = ones(2) .+ (x -> rand()).(Fill(2,2)) + @test all(==(A[1]), A) + end end @testset "map" begin From 2c53cc3c1d4dd51f561381362fe1dceec3784ed8 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:23:35 +0530 Subject: [PATCH 09/11] Recursively process fill components --- src/fillbroadcast.jl | 1 + test/runtests.jl | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index dd23c057..13074846 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -119,6 +119,7 @@ _isfill(f::Ref) = true _isfill(::Any) = false _broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) +_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) _broadcast_maybecopy(x) = x function _fallback_copy(bc) diff --git a/test/runtests.jl b/test/runtests.jl index c3c36985..c9707a15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1243,6 +1243,8 @@ end @testset "early binding" begin A = ones(2) .+ (x -> rand()).(Fill(2,2)) @test all(==(A[1]), A) + A = ones(1,5) .+ (ones(1) .+ (_ -> rand()).(Fill("vec", 2))) + @test all(==(A[1]), A) end end From 10507075bfb8ed38627f5a1a52be8f5a786418ef Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:41:27 +0530 Subject: [PATCH 10/11] Refactor common parts --- src/fillbroadcast.jl | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 13074846..115fc6ac 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -118,33 +118,36 @@ _isfill(f::Number) = true _isfill(f::Ref) = true _isfill(::Any) = false -_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) -_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) -_broadcast_maybecopy(x) = x +function _copy_fill(bc) + v = _getindex_value(bc) + if _iszeros(bc) + return Zeros(typeof(v), axes(bc)) + elseif _isones(bc) + return Ones(typeof(v), axes(bc)) + end + return Fill(v, axes(bc)) +end + +# recursively copy the purely fill components +function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) + _isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +end +_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +_preprocess_fill(x) = x function _fallback_copy(bc) - # treat the fill components - bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) + # copy the purely fill components + bc2 = Base.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) # fallback style S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}} copy(convert(S, bc2)) end function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) - if _iszeros(bc) - return Zeros(typeof(_getindex_value(bc)), axes(bc)) - elseif _isones(bc) - return Ones(typeof(_getindex_value(bc)), axes(bc)) - elseif _isfill(bc) - return Fill(_getindex_value(bc), axes(bc)) - else - _fallback_copy(bc) - end + _isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc) end # make the zero-dimensional case consistent with Base -function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) - _fallback_copy(bc) -end +Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc) # some cases that preserve 0d function broadcast_preserving_0d(f, As...) From 7296e788d2f5dead0a747d0c70537c92e4a03427 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 31 Dec 2024 20:50:13 +0530 Subject: [PATCH 11/11] Specialize broadcasting adjoint vector --- src/fillbroadcast.jl | 39 ++++++++++++++++++++++++--------------- test/runtests.jl | 6 ++++++ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 115fc6ac..aca41de4 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -87,39 +87,48 @@ Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2} Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S -_getindex_value(f::AbstractFill) = getindex_value(f) -_getindex_value(x::Number) = x -_getindex_value(x::Ref) = x[] -function _getindex_value(bc::Broadcast.Broadcasted) - bc.f(map(_getindex_value, bc.args)...) +# Obtain the fill value of a broadcasted object by recursively evaluating the fill components +broadcast_getindex_value(f::AbstractFill) = getindex_value(f) +broadcast_getindex_value(f::Transpose{<:Any,<:AbstractFill}) = getindex_value(parent(f)) +broadcast_getindex_value(f::Adjoint{<:Any,<:AbstractFill}) = getindex_value(parent(f)) +broadcast_getindex_value(x::Number) = x +broadcast_getindex_value(x::Ref) = x[] +function broadcast_getindex_value(bc::Broadcast.Broadcasted) + bc.f(map(broadcast_getindex_value, bc.args)...) end has_static_value(x) = false has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) +# _iszeros and _isones are conservative checks for zeros and ones, +# which are used to determine if a broadcasted object is a Fill, Zeros or Ones. function _iszeros(bc::Broadcast.Broadcasted) - all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) + all(has_static_value, bc.args) && _iszero(broadcast_getindex_value(bc)) end # conservative check for zeros. In most cases, there isn't a zero element to compare with _iszero(x::Union{Number, AbstractArray}) = iszero(x) _iszero(_) = false function _isones(bc::Broadcast.Broadcasted) - all(has_static_value, bc.args) && _isone(_getindex_value(bc)) + all(has_static_value, bc.args) && _isone(broadcast_getindex_value(bc)) end # conservative check for ones. In most cases, there isn't a unit element to compare with _isone(x::Union{Number, AbstractArray}) = isone(x) _isone(_) = false -_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args) -_isfill(f::AbstractFill) = true -_isfill(f::Number) = true -_isfill(f::Ref) = true -_isfill(::Any) = false +# wrappers that are equivalent to an `AbstractFill` may opt in to the broadcasting behavior +# of `AbstractFill` by specializing `isfill` and `broadcast_getindex_value` +isfill(bc::Broadcast.Broadcasted) = all(isfill, bc.args) +isfill(f::AbstractFill) = true +isfill(f::Transpose) = isfill(parent(f)) +isfill(f::Adjoint) = isfill(parent(f)) +isfill(f::Number) = true +isfill(f::Ref) = true +isfill(::Any) = false function _copy_fill(bc) - v = _getindex_value(bc) + v = broadcast_getindex_value(bc) if _iszeros(bc) return Zeros(typeof(v), axes(bc)) elseif _isones(bc) @@ -130,7 +139,7 @@ end # recursively copy the purely fill components function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) - _isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) + isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) end _preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) _preprocess_fill(x) = x @@ -144,7 +153,7 @@ function _fallback_copy(bc) end function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) - _isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc) + isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc) end # make the zero-dimensional case consistent with Base Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc) diff --git a/test/runtests.jl b/test/runtests.jl index c9707a15..9298b276 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1246,6 +1246,12 @@ end A = ones(1,5) .+ (ones(1) .+ (_ -> rand()).(Fill("vec", 2))) @test all(==(A[1]), A) end + + @testset "wrappers" begin + f = Fill(3, 4) + @test f * f' === Fill(9,4,4) + @test f * transpose(f) === Fill(9,4,4) + end end @testset "map" begin