Skip to content

Commit 92567da

Browse files
Attempt to fix Broadcast.broadcast_shape inference
1 parent ec8c7f5 commit 92567da

File tree

5 files changed

+94
-0
lines changed

5 files changed

+94
-0
lines changed

src/blockbroadcast.jl

+2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(
2929

3030

3131
# sortedunion can assume inputs are already sorted so this could be improved
32+
include("tuple_tools.jl")
3233
sortedunion(a,b) = sort!(union(a,b))
34+
sortedunion(a::Tuple, b::Tuple) = tuple_sort(tuple_union(a,b))
3335
sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b)))
3436
sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b))
3537
combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b)))

src/tuple_tools.jl

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#####
2+
##### From TupleTools.jl
3+
#####
4+
function _split(t::Tuple)
5+
N = length(t)
6+
M = N >> 1
7+
return ntuple(i -> t[i], M), ntuple(i -> t[i + M], N - M)
8+
end
9+
function _merge(t1::Tuple, t2::Tuple, lt, by, rev)
10+
if lt(by(first(t1)), by(first(t2))) != rev
11+
return (first(t1), _merge(tail(t1), t2, lt, by, rev)...)
12+
else
13+
return (first(t2), _merge(t1, tail(t2), lt, by, rev)...)
14+
end
15+
end
16+
_merge(::Tuple{}, t2::Tuple, lt, by, rev) = t2
17+
_merge(t1::Tuple, ::Tuple{}, lt, by, rev) = t1
18+
_merge(::Tuple{}, ::Tuple{}, lt, by, rev) = ()
19+
20+
tuple_sort(t::Tuple; lt=isless, by=identity, rev::Bool=false) = _tuple_sort(t, lt, by, rev)
21+
@inline function _tuple_sort(t::Tuple, lt=isless, by=identity, rev::Bool=false)
22+
t1, t2 = _split(t)
23+
t1s = _tuple_sort(t1, lt, by, rev)
24+
t2s = _tuple_sort(t2, lt, by, rev)
25+
return _merge(t1s, t2s, lt, by, rev)
26+
end
27+
_tuple_sort(t::Tuple{Any}, lt=isless, by=identity, rev::Bool=false) = t
28+
_tuple_sort(t::Tuple{}, lt=isless, by=identity, rev::Bool=false) = t
29+
30+
31+
######
32+
###### tuple_union
33+
######
34+
35+
struct DistinctElems{T<:Tuple}
36+
elems::T
37+
end
38+
tuple_union(a::Tuple, b::Tuple) = distinct_elems(DistinctElems(()), a..., b...).elems
39+
40+
distinct_elems(x::DistinctElems) = x
41+
42+
distinct_elems(x::DistinctElems, r1) =
43+
r1 in x.elems ? x : DistinctElems((x.elems..., r1))
44+
45+
function distinct_elems(x::DistinctElems, r1, remaining...)
46+
return if r1 in x.elems
47+
distinct_elems(x, remaining...)
48+
else
49+
distinct_elems(DistinctElems((x.elems..., r1)), remaining...)
50+
end
51+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ include("test_blockproduct.jl")
2929
include("test_blockreduce.jl")
3030
include("test_blockdeque.jl")
3131
include("test_blockcholesky.jl")
32+
include("test_tuple_tools.jl")

test/test_blockbroadcast.jl

+5
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
182182
u = BlockArray(randn(5), [2,3]);
183183
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
184184
@test exp.(u) == exp.(Vector(u))
185+
186+
shape1 = (BlockArrays._BlockedUnitRange((2,)),);
187+
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
188+
@inferred Base.Broadcast.broadcast_shape(shape1, shape2)
189+
@code_warntype Base.Broadcast.broadcast_shape(shape1, shape2)
185190
end
186191

187192
@testset "adjtrans" begin

test/test_tuple_tools.jl

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using BlockArrays
2+
using Test
3+
using Random
4+
5+
@testset "Tuple Tools" begin
6+
@testset "tuple_sort" begin
7+
n = 10
8+
p = randperm(n)
9+
t = (p...,)
10+
@test @inferred(BlockArrays.tuple_sort((1,))) == (1,)
11+
@test @inferred(BlockArrays.tuple_sort(())) == ()
12+
@inferred(BlockArrays.tuple_sort(t; rev=true)) == (sort(p; rev=true)...,)
13+
@test @inferred(BlockArrays.tuple_sort(t; rev=false)) == (sort(p; rev=false)...,)
14+
@test BlockArrays.tuple_sort((2, 1, 3.0)) === (1, 2, 3.0)
15+
16+
shape1 = (BlockArrays._BlockedUnitRange((2,)),);
17+
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
18+
bl1 = BlockArrays.blocklasts(shape1[1])
19+
bl2 = BlockArrays.blocklasts(shape2[1])
20+
# @show BlockArrays.tuple_union(bl1,bl2)
21+
@test BlockArrays.sortedunion(bl1, bl2) == (2,)
22+
end
23+
24+
# from Base
25+
@testset "tuple_union" begin
26+
for S in (identity,)
27+
s = BlockArrays.tuple_union(S((1,2)), S((3,4)))
28+
@test s == S((1,2,3,4))
29+
s = BlockArrays.tuple_union(S((5,6,7,8)), S((7,8,9)))
30+
@test s == S((5,6,7,8,9))
31+
s = BlockArrays.tuple_union(S((1,3,5,7)), (2,3,4,5))
32+
@test s == S((1,3,5,7,2,4))
33+
end
34+
end
35+
end

0 commit comments

Comments
 (0)