Skip to content

Commit

Permalink
Add eachslice for Julia 1.1+ (invenia#58)
Browse files Browse the repository at this point in the history
* Add `eachslice` for Julia 1.1+

* Give each testset their own data

* Bump version
  • Loading branch information
nickrobinson251 committed Jul 21, 2019
1 parent 31a8195 commit 2d2d2a7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDims"
uuid = "356022a1-0364-5f58-8944-0da4b18d706f"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.3"
version = "0.2.4"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
13 changes: 13 additions & 0 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ for (mod, funs) in (
end
end

if VERSION > v"1.1-"
function Base.eachslice(a::NamedDimsArray{L}; dims, kwargs...) where L
numerical_dims = dim(a, dims)
slices = eachslice(parent(a); dims=numerical_dims, kwargs...)
return Base.Generator(slices) do slice
# For unknown reasons (something to do with hoisting?) having this in the
# function passed to `Generator` actually results in less memory being allocated
names = remaining_dimnames_after_dropping(L, numerical_dims)
return NamedDimsArray(slice, names)
end
end
end

# 1 arg before - no default for `dims` keyword
for (mod, funs) in (
(:Base, (:mapslices,)),
Expand Down
54 changes: 45 additions & 9 deletions test/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using Test
using Statistics

@testset "Base" begin
a = [10 20; 31 40]
nda = NamedDimsArray(a, (:x, :y))
a = [10 20; 31 40]
nda = NamedDimsArray(a, (:x, :y))

@testset "$f" for f in (sum, prod, maximum, minimum, extrema)
@test f(nda) == f(a)
Expand All @@ -25,8 +25,8 @@ nda = NamedDimsArray(a, (:x, :y))
end

@testset "sort!" begin
a2 = [1 9; 7 3]
nda2 = NamedDimsArray(a2, (:x, :y))
a = [1 9; 7 3]
nda = NamedDimsArray(a, (:x, :y))

# Vector case
veca = [1, 9, 7, 3]
Expand All @@ -35,16 +35,45 @@ nda = NamedDimsArray(a, (:x, :y))

# Higher-dim case: `dims` keyword in `sort!` requires Julia v1.1+
if VERSION > v"1.1-"
sort!(nda2, dims=:y)
@test issorted(a2[2, :])
@test_throws UndefKeywordError sort!(nda2)
sort!(nda, dims=:y)
@test issorted(a[2, :])
@test_throws UndefKeywordError sort!(nda)

sort!(nda2; dims=:x, order=Base.Reverse)
@test issorted(a2[:, 1]; order=Base.Reverse)
sort!(nda; dims=:x, order=Base.Reverse)
@test issorted(a[:, 1]; order=Base.Reverse)
end
end

@testset "eachslice" begin
if VERSION > v"1.1-"
slices = [[111 121; 211 221], [112 122; 212 222]]
a = cat(slices...; dims=3)
nda = NamedDimsArray(a, (:a, :b, :c))

@test (
sum(eachslice(nda; dims=:c)) ==
sum(eachslice(nda; dims=3)) ==
sum(eachslice(a; dims=3)) ==
slices[1] + slices[2]
)
@test_throws ArgumentError eachslice(nda; dims=(1, 2))
@test_throws ArgumentError eachslice(a; dims=(1, 2))

@test_throws UndefKeywordError eachslice(nda)
@test_throws UndefKeywordError eachslice(a)

@test (
names(first(eachslice(nda; dims=:b))) ==
names(first(eachslice(nda; dims=2))) ==
(:a, :c)
)
end
end

@testset "mapslices" begin
a = [10 20; 31 40]
nda = NamedDimsArray(a, (:x, :y))

@test (
mapslices(join, nda; dims=:x) ==
mapslices(join, nda; dims=1) ==
Expand All @@ -71,6 +100,9 @@ nda = NamedDimsArray(a, (:x, :y))
end

@testset "mapreduce" begin
a = [10 20; 31 40]
nda = NamedDimsArray(a, (:x, :y))

@test mapreduce(isodd, |, nda) == true == mapreduce(isodd, |, a)
@test (
mapreduce(isodd, |, nda; dims=:x) ==
Expand All @@ -90,13 +122,17 @@ nda = NamedDimsArray(a, (:x, :y))
end

@testset "zero" begin
a = [10 20; 31 40]
nda = NamedDimsArray(a, (:x, :y))

@test zero(nda) == [0 0; 0 0] == zero(a)
@test names(zero(nda)) == (:x, :y)
end

@testset "count" begin
a = [true false; true true]
nda = NamedDimsArray(a, (:x, :y))

@test count(nda) == count(a) == 3
@test_throws ErrorException count(nda; dims=:x)
@test_throws ErrorException count(a; dims=1)
Expand Down

0 comments on commit 2d2d2a7

Please sign in to comment.