Skip to content

Commit

Permalink
Improve type-inference in broadcasting (#351)
Browse files Browse the repository at this point in the history
* Improve type-inference in broadcasting

* Add test using JET

* Allow more JET versions

* Reshape dest if ndims don't match

* Use ndims from type param

* Rename variable
  • Loading branch information
jishnub authored Mar 25, 2024
1 parent 74e4928 commit 772c307
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Aqua = "0.8"
ArrayLayouts = "1.0.8"
Documenter = "1"
FillArrays = "1"
JET = "0.4, 0.6, 0.7, 0.8"
LinearAlgebra = "1.6"
OffsetArrays = "1"
Random = "1.6"
Expand All @@ -23,11 +24,12 @@ julia = "1.6"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"]
test = ["Aqua", "Documenter", "JET", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"]
9 changes: 7 additions & 2 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,13 @@ end
@inline _bview(arg, ::Vararg) = arg
@inline _bview(A::AbstractArray, I...) = view(A, I...)

@inline function Base.Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {BS<:AbstractBlockStyle}
return copyto!(dest, Base.Broadcast.instantiate(Base.Broadcast.Broadcasted{BS}(bc.f, bc.args, combine_blockaxes.(axes(dest),axes(bc)))))
@inline function Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {NDims, BS<:AbstractBlockStyle{NDims}}
dest_reshaped = ndims(dest) == NDims ? dest : reshape(dest, size(bc))
bc2 = Broadcast.instantiate(
Broadcast.Broadcasted{BS}(bc.f, bc.args,
map(combine_blockaxes, axes(dest_reshaped), axes(bc))))
copyto!(dest_reshaped, bc2)
return dest
end

function _generic_blockbroadcast_copyto!(dest::AbstractArray,
Expand Down
24 changes: 24 additions & 0 deletions test/test_blockbroadcast.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BlockArrays, FillArrays, Test
import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
using JET

@testset "broadcast" begin
@testset "BlockArray" begin
Expand All @@ -24,6 +25,24 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal

@test axes(A + A) == axes(A .+ A) == axes(A)
@test axes(A .+ 1) == axes(A)

@testset "mismatched ndims" begin
u = BlockArray(randn(5), [2,3])
dest = zeros(size(u)..., 1)
@test (dest .= u) isa typeof(dest)
@static if isdefined(JET, :test_opt)
@test_opt ((dest,u) -> dest .= u)(dest,u)
end
@test reshape(dest, size(u)) == u

u = BlockArray(randn(3,3), [1,2], [1,2])
dest = zeros(length(u))
@test (dest .= u) isa typeof(dest)
@static if isdefined(JET, :test_opt)
@test_opt ((dest,u) -> dest .= u)(dest,u)
end
@test reshape(dest, size(u)) == u
end
end

@testset "PseudoBlockArray" begin
Expand Down Expand Up @@ -180,8 +199,13 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal

@testset "type inference" begin
u = BlockArray(randn(5), [2,3]);
A = zeros(size(u))
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
@test exp.(u) == exp.(Vector(u))
# test_opt isn't available on JET v0.4, which is installed on Julia v1.6
@static if isdefined(JET, :test_opt)
@test_opt ((A,B) -> A .= B)(A,u)
end
end

@testset "adjtrans" begin
Expand Down

0 comments on commit 772c307

Please sign in to comment.