diff --git a/src/FillArrays.jl b/src/FillArrays.jl index cbf9405a..ddd10c35 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -589,6 +589,7 @@ end include("fillalgebra.jl") include("fillbroadcast.jl") include("trues.jl") +include("fillcat.jl") ## # print diff --git a/src/fillcat.jl b/src/fillcat.jl new file mode 100644 index 00000000..4f30d824 --- /dev/null +++ b/src/fillcat.jl @@ -0,0 +1,53 @@ + +function Base.cat_t(::Type{T}, fs::Fill...; dims) where T + allvals = unique([f.value for f in fs]) + length(allvals) > 1 && return Base._cat_t(dims, T, fs...) + + catdims = Base.dims2cat(dims) + + # When dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros + # There might be some cases when it does not get padded which are not considered here + + if sum(catdims) > 1 + allvals[] isa Number || return Base._cat_t(dims, T, fs...) + allvals[] !== zero(T) && return Base._cat_t(dims, T, fs...) + end + + shape = cat_shape_fill(catdims, fs) + return Fill(convert(T, fs[1].value), shape) +end + +Base.vcat(vs::Fill...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2)) + + +function Base.cat_t(::Type{T}, fs::Zeros...; dims) where T + catdims = Base.dims2cat(dims) + shape = cat_shape_fill(catdims, fs) + return Zeros{T}(shape) +end + +Base.vcat(vs::Zeros...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Zeros...) = cat(vs...;dims=Val(2)) + + +function Base.cat_t(::Type{T}, fs::Ones...; dims) where T + catdims = Base.dims2cat(dims) + + # When dims is a tuple the output gets zero padded so we can't return a Ones + # There might be some cases when it does not get padded which are not considered here + sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) + + shape = cat_shape_fill(catdims, fs) + return Ones{T}(shape) +end + +Base.vcat(vs::Ones...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Ones...) = cat(vs...;dims=Val(2)) + + +if VERSION < v"1.6-" + cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, (), map(Base.cat_size, fs)...) +else + cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims +end diff --git a/test/runtests.jl b/test/runtests.jl index 37d32c82..1d8e7bff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1232,3 +1232,149 @@ end @test convert(Fill, transpose(a)) ≡ Fill(2.0,1,5) end end + +@testset "Concatenation" begin + + @testset "Fill" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in (1,2,3, Val(4)) + res = cat(Fill(1, s), Fill(1, s); dims=dims) + @test res isa Fill + @test res == cat(fill(1, s), fill(1, s); dims=dims) + + res = cat(Fill(1.0, s), Fill(1, s); dims=dims) + @test res isa Fill + @test res == cat(fill(1.0, s), fill(1, s); dims=dims) + + res = cat(Fill(:a, s), Fill(:a, s); dims=dims) + @test res isa Fill + @test res == cat(fill(:a, s), fill(:a, s);dims=dims) + + @test cat(Fill(1, s), Fill(2, s);dims=dims) == cat(fill(1, s), fill(2, s);dims=dims) + end + @testset "Dim $dims" for dims in ( + (1,2), + (2,3), + (1,3,4), + Iterators.take(3:5, 2) + ) + # This inserts a bunch of zeros so we can no longer assume the answer is a Fill + @test cat(Fill(1, s), Fill(1, s); dims=dims) == cat(fill(1, s), fill(1,s); dims=dims) + @test cat(Fill(0, s), Fill(0, s); dims=dims) isa Fill + @test cat(Fill(0.0, s), Fill(0.0, s); dims=dims) isa Fill + + @test cat(Fill(1, s), Fill(2, s);dims=dims) == cat(fill(1, s), fill(2, s);dims=dims) + end + end + + @testset "vcat" begin + # vcat just delegates to cat, so we basically just test that here + res = vcat(Fill(1, 3), Fill(1, 4)) + @test res isa Fill + @test res == vcat(fill(1, 3), fill(1,4)) + end + + @testset "hcat" begin + # hcat just delegates to cat, so we basically just test that here + res = hcat(Fill(1, 2), Fill(1, 2)) + @test res isa Fill + @test res == hcat(fill(1, 2), fill(1,2)) + end + end + + @testset "Zeros" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in ( + 1, + 2, + Val(3), + (1,2), + (2,3), + Iterators.take(3:5, 2) + ) + res = cat(Zeros(s), Zeros(s); dims=dims) + @test res isa Zeros + @test res == cat(zeros(s), zeros(s); dims=dims) + + res = cat(Zeros{Float64}(s), Zeros{Int}(s); dims=dims) + @test res isa Zeros + @test res == cat(zeros(Float64, s), zeros(Int, s); dims=dims) + end + end + + @testset "vcat" begin + # vcat just delegates to cat, so we basically just test that here + res = vcat(Zeros(3), Zeros(4)) + @test res isa Zeros + @test res == vcat(zeros(3), zeros(4)) + end + + @testset "hcat" begin + # hcat just delegates to cat, so we basically just test that here + res = vcat(Zeros(2), Zeros(2)) + @test res isa Zeros + @test res == vcat(zeros(2), zeros(2)) + end + end + + @testset "Ones" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in (1,2,3, Val(4)) + res = cat(Ones(s), Ones(s); dims=dims) + @test res isa Ones + @test res == cat(ones(s), ones(s); dims=dims) + + res = cat(Ones{Float64}(s), Ones{Int}(s); dims=dims) + @test res isa Ones + @test res == cat(ones(Float64, s), ones(Int, s); dims=dims) + end + @testset "Dim $dims" for dims in ( + (1,2), + (2,3), + (1,3,4), + Iterators.take(3:5, 2) + ) + # This inserts a bunch of zeros so we can no longer assume the answer is a Fill + @test cat(Ones(s), Ones(s); dims=dims) == cat(ones(s), ones(s); dims=dims) + end + end + + @testset "vcat" begin + # vcat just delegates to cat, so we basically just test that here + res = vcat(Ones(3), Ones(4)) + @test res isa Ones + @test res == vcat(ones(3), fill(1,4)) + end + + @testset "hcat" begin + # hcat just delegates to cat, so we basically just test that here + res = hcat(Ones(2), Ones(2)) + @test res isa Ones + @test res == hcat(ones(2), fill(1,2)) + end + end +end