diff --git a/Project.toml b/Project.toml index cbff73dc..1ec9b48c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Aqua = "0.8" ArrayLayouts = "1.0.8" Documenter = "1" FillArrays = "1" +JET = "0.4, 0.6, 0.7, 0.8" LinearAlgebra = "1.6" OffsetArrays = "1" Random = "1.6" @@ -23,6 +24,7 @@ julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -30,4 +32,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Documenter", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"] +test = ["Aqua", "Documenter", "JET", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"] diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index 306b05d2..7bf796b9 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -140,8 +140,13 @@ end @inline _bview(arg, ::Vararg) = arg @inline _bview(A::AbstractArray, I...) = view(A, I...) -@inline function Base.Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {BS<:AbstractBlockStyle} - return copyto!(dest, Base.Broadcast.instantiate(Base.Broadcast.Broadcasted{BS}(bc.f, bc.args, combine_blockaxes.(axes(dest),axes(bc))))) +@inline function Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {NDims, BS<:AbstractBlockStyle{NDims}} + dest_reshaped = ndims(dest) == NDims ? dest : reshape(dest, size(bc)) + bc2 = Broadcast.instantiate( + Broadcast.Broadcasted{BS}(bc.f, bc.args, + map(combine_blockaxes, axes(dest_reshaped), axes(bc)))) + copyto!(dest_reshaped, bc2) + return dest end function _generic_blockbroadcast_copyto!(dest::AbstractArray, diff --git a/test/test_blockbroadcast.jl b/test/test_blockbroadcast.jl index 0d58aa3d..15a012eb 100644 --- a/test/test_blockbroadcast.jl +++ b/test/test_blockbroadcast.jl @@ -1,5 +1,6 @@ using BlockArrays, FillArrays, Test import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal +using JET @testset "broadcast" begin @testset "BlockArray" begin @@ -24,6 +25,24 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal @test axes(A + A) == axes(A .+ A) == axes(A) @test axes(A .+ 1) == axes(A) + + @testset "mismatched ndims" begin + u = BlockArray(randn(5), [2,3]) + dest = zeros(size(u)..., 1) + @test (dest .= u) isa typeof(dest) + @static if isdefined(JET, :test_opt) + @test_opt ((dest,u) -> dest .= u)(dest,u) + end + @test reshape(dest, size(u)) == u + + u = BlockArray(randn(3,3), [1,2], [1,2]) + dest = zeros(length(u)) + @test (dest .= u) isa typeof(dest) + @static if isdefined(JET, :test_opt) + @test_opt ((dest,u) -> dest .= u)(dest,u) + end + @test reshape(dest, size(u)) == u + end end @testset "PseudoBlockArray" begin @@ -180,8 +199,13 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal @testset "type inference" begin u = BlockArray(randn(5), [2,3]); + A = zeros(size(u)) @inferred(copyto!(similar(u), Base.broadcasted(exp, u))) @test exp.(u) == exp.(Vector(u)) + # test_opt isn't available on JET v0.4, which is installed on Julia v1.6 + @static if isdefined(JET, :test_opt) + @test_opt ((A,B) -> A .= B)(A,u) + end end @testset "adjtrans" begin