Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add insertdims #830

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Compat"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0"
version = "4.16.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ changes in `julia`.

## Supported features

* `insertdims(D; dims)` is the opposite of `dropdims` ([#45793]) (since Compat 4.16.0)

* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)

* `logrange(lo, hi; length)` is like `range` but with a constant ratio, not difference. ([#39071]) (since Compat 4.14.0) Note that on Julia 1.8 and earlier, the version from Compat has slightly lower floating-point accuracy than the one in Base (Julia 1.11 and later).
Expand Down Expand Up @@ -192,3 +194,4 @@ Note that you should specify the correct minimum version for `Compat` in the
[#47679]: https://github.com/JuliaLang/julia/pull/47679
[#48038]: https://github.com/JuliaLang/julia/issues/48038
[#50105]: https://github.com/JuliaLang/julia/issues/50105
[#45793]: https://github.com/JuliaLang/julia/pull/45793
31 changes: 31 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,37 @@ if VERSION < v"1.8.0-DEV.1016"
export chopprefix, chopsuffix
end

if VERSION < v"1.12.0-DEV.974" # contrib/commit-name.sh 2635dea

insertdims(A; dims) = _insertdims(A, dims)

function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
for i in eachindex(dims)
1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1"))
dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
end
end

# acc is a tuple, where the first entry is the final shape
# the second entry off acc is a counter for the axes of A
inds = Base._foldoneto((acc, i) ->
i ∈ dims
? ((acc[1]..., Base.OneTo(1)), acc[2])
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
((), 1), Val(N+M))
new_shape = inds[1]
return reshape(A, new_shape)
end

_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))

export insertdims
else
using Base: insertdims, _insertdims
end

include("deprecated.jl")

end # module Compat
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,34 @@ end
@test isa(chopsuffix(S("foo"), "oo"), SubString)
end
end

# https://github.com/JuliaLang/julia/pull/45793
@testset "insertdims" begin
a = rand(8, 7)
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))

@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
@test_throws UndefKeywordError insertdims(a)
@test_throws ArgumentError insertdims(a, dims=0)
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
@test_throws ArgumentError insertdims(a, dims=4)
@test_throws ArgumentError insertdims(a, dims=6)

# insertdims and dropdims are inverses
b = rand(1,1,1,5,1,1,7)
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)]
@test dropdims(insertdims(a; dims); dims) == a
@test insertdims(dropdims(b; dims); dims) == b
end
end
Loading