Skip to content

Commit

Permalink
Attempt to fix Broadcast.broadcast_shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 20, 2023
1 parent ec8c7f5 commit 92567da
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(


# sortedunion can assume inputs are already sorted so this could be improved
include("tuple_tools.jl")
sortedunion(a,b) = sort!(union(a,b))
sortedunion(a::Tuple, b::Tuple) = tuple_sort(tuple_union(a,b))
sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b)))
sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b))
combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b)))
Expand Down
51 changes: 51 additions & 0 deletions src/tuple_tools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#####
##### From TupleTools.jl
#####
function _split(t::Tuple)
N = length(t)
M = N >> 1
return ntuple(i -> t[i], M), ntuple(i -> t[i + M], N - M)
end
function _merge(t1::Tuple, t2::Tuple, lt, by, rev)
if lt(by(first(t1)), by(first(t2))) != rev
return (first(t1), _merge(tail(t1), t2, lt, by, rev)...)
else
return (first(t2), _merge(t1, tail(t2), lt, by, rev)...)
end
end
_merge(::Tuple{}, t2::Tuple, lt, by, rev) = t2
_merge(t1::Tuple, ::Tuple{}, lt, by, rev) = t1
_merge(::Tuple{}, ::Tuple{}, lt, by, rev) = ()

tuple_sort(t::Tuple; lt=isless, by=identity, rev::Bool=false) = _tuple_sort(t, lt, by, rev)
@inline function _tuple_sort(t::Tuple, lt=isless, by=identity, rev::Bool=false)
t1, t2 = _split(t)
t1s = _tuple_sort(t1, lt, by, rev)
t2s = _tuple_sort(t2, lt, by, rev)
return _merge(t1s, t2s, lt, by, rev)
end
_tuple_sort(t::Tuple{Any}, lt=isless, by=identity, rev::Bool=false) = t
_tuple_sort(t::Tuple{}, lt=isless, by=identity, rev::Bool=false) = t


######
###### tuple_union
######

struct DistinctElems{T<:Tuple}
elems::T
end
tuple_union(a::Tuple, b::Tuple) = distinct_elems(DistinctElems(()), a..., b...).elems

distinct_elems(x::DistinctElems) = x

distinct_elems(x::DistinctElems, r1) =
r1 in x.elems ? x : DistinctElems((x.elems..., r1))

function distinct_elems(x::DistinctElems, r1, remaining...)
return if r1 in x.elems
distinct_elems(x, remaining...)
else
distinct_elems(DistinctElems((x.elems..., r1)), remaining...)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ include("test_blockproduct.jl")
include("test_blockreduce.jl")
include("test_blockdeque.jl")
include("test_blockcholesky.jl")
include("test_tuple_tools.jl")
5 changes: 5 additions & 0 deletions test/test_blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
u = BlockArray(randn(5), [2,3]);
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
@test exp.(u) == exp.(Vector(u))

shape1 = (BlockArrays._BlockedUnitRange((2,)),);
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
@inferred Base.Broadcast.broadcast_shape(shape1, shape2)
@code_warntype Base.Broadcast.broadcast_shape(shape1, shape2)
end

@testset "adjtrans" begin
Expand Down
35 changes: 35 additions & 0 deletions test/test_tuple_tools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using BlockArrays
using Test
using Random

@testset "Tuple Tools" begin
@testset "tuple_sort" begin
n = 10
p = randperm(n)
t = (p...,)
@test @inferred(BlockArrays.tuple_sort((1,))) == (1,)
@test @inferred(BlockArrays.tuple_sort(())) == ()
@inferred(BlockArrays.tuple_sort(t; rev=true)) == (sort(p; rev=true)...,)
@test @inferred(BlockArrays.tuple_sort(t; rev=false)) == (sort(p; rev=false)...,)
@test BlockArrays.tuple_sort((2, 1, 3.0)) === (1, 2, 3.0)

shape1 = (BlockArrays._BlockedUnitRange((2,)),);
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
bl1 = BlockArrays.blocklasts(shape1[1])
bl2 = BlockArrays.blocklasts(shape2[1])
# @show BlockArrays.tuple_union(bl1,bl2)
@test BlockArrays.sortedunion(bl1, bl2) == (2,)
end

# from Base
@testset "tuple_union" begin
for S in (identity,)
s = BlockArrays.tuple_union(S((1,2)), S((3,4)))
@test s == S((1,2,3,4))
s = BlockArrays.tuple_union(S((5,6,7,8)), S((7,8,9)))
@test s == S((5,6,7,8,9))
s = BlockArrays.tuple_union(S((1,3,5,7)), (2,3,4,5))
@test s == S((1,3,5,7,2,4))
end
end
end

0 comments on commit 92567da

Please sign in to comment.