Skip to content

Commit

Permalink
Merge pull request #20 from darsnack/length-as-field
Browse files Browse the repository at this point in the history
Store number of labels as a field
  • Loading branch information
darsnack authored Oct 12, 2022
2 parents 69ca6fa + afd2ac0 commit ddbba63
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.1.1"
version = "0.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
103 changes: 64 additions & 39 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,65 +1,84 @@
"""
OneHotArray{T, L, N, M, I} <: AbstractArray{Bool, M}
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
OneHotArray(indices, L)
A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`)
stored as a compact `N == M-1`-dimensional array of indices.
Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
Parameter `I` is the type of the underlying storage, and `T` its eltype.
"""
struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
struct OneHotArray{T<:Integer, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
indices::I
nlabels::Int
end
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices)
OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices)
OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L)
OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L)
OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L)

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
reshape(parent(x).indices, x.dims[2:end])

"""
OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
OneHotVector{T} = OneHotArray{T, 0, 1, T}
OneHotVector(indices, L)
A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`) typically constructed by [`onehot`](@ref).
Stored efficiently as a single index of type `T`, usually `UInt32`.
"""
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotVector{T} = OneHotArray{T, 0, 1, T}
OneHotVector(idx, L) = OneHotArray(idx, L)

"""
OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I}
OneHotMatrix(indices, L)
A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref).
Stored efficiently as a vector of indices with type `I` and eltype `T`.
"""
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

OneHotVector(idx, L) = OneHotArray(idx, L)
const OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I}
OneHotMatrix(indices, L) = OneHotArray(indices, L)

# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
const OneHotLike{T, N, var"N+1", I} =
Union{OneHotArray{T, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, <:Any, <:Any, I}}}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = (size(x, 1) == parent(x).nlabels)

Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
_check_nlabels(L, xs::OneHotLike...) = all(size.(xs, 1) .== L)

function Base.getindex(x::OneHotArray{<:Any, <:Any, N}, i::Integer, I::Vararg{Any, N}) where N
@boundscheck checkbounds(x, i, I...)
return x.indices[I...] .== i
end
_nlabels(x::OneHotArray) = size(x, 1)
function _nlabels(x::OneHotLike, xs::OneHotLike...)
L = size(x, 1)
_check_nlabels(L, xs...) ||
throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))

function Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L
@boundscheck checkbounds(x, :, I...)
return OneHotArray(x.indices[I...], L)
return L
end

Base.size(x::OneHotArray) = (x.nlabels, size(x.indices)...)

function Base.getindex(x::OneHotArray{<:Any, N}, i::Int, I::Vararg{Int, N}) where N
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
return x.indices[I...] .== i
end
# the method above is faster on the CPU but will scalar index on the GPU
# so we define the method below to pass the extra indices directly to GPU array
function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray},
i::Int,
I::Vararg{Any, N}) where N
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
return x.indices[I...] .== i
end
function Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, I::Vararg{Any, N}) where N
return OneHotArray(x.indices[I...], x.nlabels)
end
Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
Base.getindex(x::OneHotArray{<:Any, <:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x

function Base.showarg(io::IO, x::OneHotArray, toplevel)
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
Expand All @@ -77,38 +96,44 @@ end
# copy CuArray versions back before trying to print them:
for fun in (:show, :print_array) # print_array is used by 3-arg show
@eval begin
Base.$fun(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} =
Base.$fun(io, adapt(Array, X))
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}}) where {T, L, N} =
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} =
Base.$fun(io, adapt(Array, X))
end
end

_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}

_notall_onehot(x::OneHotArray, xs::OneHotArray...) = false
_notall_onehot(x::OneHotLike, xs::OneHotLike...) = any(x -> !_isonehot(x), (x, xs...))

function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
function Base.cat(x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...; dims::Int) where N
if isone(dims) || _notall_onehot(x, xs...)
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
else
L = _nlabels(x, xs...)

return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
end
end

Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)
Base.vcat(x::OneHotLike, xs::OneHotLike...) =
vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...)

# optimized concatenation for matrices and vectors of same parameters
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))
Base.hcat(x::OneHotVector, xs::OneHotVector...) =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))

MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
MLUtils.batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(_indices.(xs), _nlabels(xs...))

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)

function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
S = Base.BroadcastStyle(T)
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
Expand Down
20 changes: 10 additions & 10 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
function Base.:(*)(A::AbstractMatrix, B::OneHotLike)
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
return A[:, onecold(B)]
end

function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
return NNlib.gather(A, _indices(B))
end

Expand All @@ -18,16 +18,16 @@ end

for wrapper in [:Adjoint, :Transpose]
@eval begin
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}
size(A, 2) == L ||
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
size(A, 2) == length(b) ||
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))"))

return A[:, onecold(b)]
end

function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T}
size(A, 2) == L ||
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector) where T
size(A, 2) == length(b) ||
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))"))

return A[onecold(b)]
end
Expand Down
7 changes: 4 additions & 3 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ julia> hcat(αβγ...) # preserves sparsity
function onehot(x, labels)
i = _findval(x, labels)
isnothing(i) && error("Value $x is not in labels")
OneHotVector{UInt32, length(labels)}(i)
OneHotVector{UInt32}(i, length(labels))
end

function onehot(x, labels, default)
i = _findval(x, labels)
isnothing(i) && return onehot(default, labels)
OneHotVector{UInt32, length(labels)}(i)
OneHotVector{UInt32}(i, length(labels))
end

_findval(val, labels) = findfirst(isequal(val), labels)
Expand Down Expand Up @@ -135,6 +135,7 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = _indices(x)
function _fast_argmax(x::OneHotLike)
if _isonehot(x)
return _indices(x)
Expand All @@ -147,4 +148,4 @@ ChainRulesCore.@non_differentiable onehot(::Any...)
ChainRulesCore.@non_differentiable onehotbatch(::Any...)
ChainRulesCore.@non_differentiable onecold(::Any...)

ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Int)
2 changes: 1 addition & 1 deletion test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
# vector indexing
@test ov[3] == (ov.indices == 3)
@test ov[:] == ov

# matrix indexing
@test om[3, 3] == (om.indices[3] == 3)
@test om[:, 3] == OneHotVector(om.indices[3], 10)
Expand Down
6 changes: 6 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
@test onecold(onehot(0.0, floats)) == 1
@test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal
@test onecold(onehot(Inf, floats)) == 5

# inferrabiltiy tests
@test @inferred(onehot(20, 10:10:30)) == [false, true, false]
@test @inferred(onehot(40, (10,20,30), 20)) == [false, true, false]
@test @inferred(onehotbatch([20, 10], 10:10:30)) == Bool[0 1; 1 0; 0 0]
@test @inferred(onehotbatch([40, 10], (10,20,30), 20)) == Bool[0 1; 1 0; 0 0]
end

@testset "onecold" begin
Expand Down

0 comments on commit ddbba63

Please sign in to comment.