From 0fa4728b522627979bde2f6b395d62a149859a49 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 11:07:02 -0600 Subject: [PATCH 1/3] Make getindex/length the default interface. --- src/MLUtils.jl | 2 +- src/batchview.jl | 2 +- src/observation.jl | 22 ++++++++++++++-------- src/obstransform.jl | 20 ++++++++++---------- src/obsview.jl | 3 +-- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 8454231..95b42bd 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -22,7 +22,7 @@ export mapobs, groupobs, joinobs, shuffleobs - + include("batchview.jl") export batchsize, BatchView diff --git a/src/batchview.jl b/src/batchview.jl index 90c27dc..bf7bc33 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -100,7 +100,7 @@ Return the fixed size of each batch in `data`. """ batchsize(A::BatchView) = A.batchsize -numobs(A::BatchView) = A.count +Base.length(A::BatchView) = A.count getobs(A::BatchView) = getobs(A.data) getobs(A::BatchView, i::Int) = getobs(A.data, _batchrange(A, i)) diff --git a/src/observation.jl b/src/observation.jl index ff4d7fd..e362368 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -2,8 +2,14 @@ numobs(data) Return the total number of observations contained in `data`. + If `data` does not have `numobs` defined, then this function falls back to `length(data)`. +Authors of custom data containers should implement +`Base.length` for their type instead of `numobs`. +`numobs` should only be implemented for types where there is a +difference between `numobs` and `Base.length` +(such as multi-dimensional arrays). See also [`getobs`](@ref) """ @@ -18,16 +24,20 @@ numobs(data) = length(data) Return the observations corresponding to the observation-index `idx`. Note that `idx` can be any type as long as `data` has defined `getobs` for that type. + If `data` does not have `getobs` defined, then this function falls back to `data[idx]`. +Authors of custom data containers should implement +`Base.getindex` for their type instead of `getobs`. +`getobs` should only be implemented for types where there is a +difference between `getobs` and `Base.getindex` +(such as multi-dimensional arrays). The returned observation(s) should be in the form intended to be passed as-is to some learning algorithm. There is no strict interface requirement on how this "actual data" must look like. - Every author behind some custom data container can make this decision themselves. - The output should be consistent when `idx` is a scalar vs vector. See also [`getobs!`](@ref) and [`numobs`](@ref) @@ -64,13 +74,9 @@ getobs!(buffer, data, idx) = getobs(data, idx) abstract type AbstractDataContainer end -Base.getindex(x::AbstractDataContainer, i) = getobs(x, i) -Base.length(x::AbstractDataContainer) = numobs(x) -Base.size(x::AbstractDataContainer) = (length(x),) - Base.iterate(x::AbstractDataContainer, state = 1) = - (state > length(x)) ? nothing : (x[state], state + 1) -Base.lastindex(x::AbstractDataContainer) = length(x) + (state > numobs(x)) ? nothing : (getobs(x, state), state + 1) +Base.lastindex(x::AbstractDataContainer) = numobs(x) # -------------------------------------------------------------------- # Arrays diff --git a/src/obstransform.jl b/src/obstransform.jl index a5630f5..24f521e 100644 --- a/src/obstransform.jl +++ b/src/obstransform.jl @@ -1,7 +1,7 @@ # mapobs -struct MappedData{F,D} +struct MappedData{F,D} <: AbstractDataContainer f::F data::D end @@ -9,9 +9,9 @@ end Base.show(io::IO, data::MappedData) = print(io, "mapobs($(data.f), $(summary(data.data)))") Base.show(io::IO, data::MappedData{F,<:AbstractArray}) where {F} = print(io, "mapobs($(data.f), $(ShowLimit(data.data, limit=80)))") -numobs(data::MappedData) = numobs(data.data) -getobs(data::MappedData, idx::Int) = data.f(getobs(data.data, idx)) -getobs(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs)) +Base.length(data::MappedData) = numobs(data.data) +Base.getindex(data::MappedData, idx::Int) = data.f(getobs(data.data, idx)) +Base.getindex(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs)) """ @@ -38,14 +38,14 @@ Returns a tuple of transformed data containers. mapobs(fs::Tuple, data) = Tuple(mapobs(f, data) for f in fs) -struct NamedTupleData{TData,F} +struct NamedTupleData{TData,F} <: AbstractDataContainer data::TData namedfs::NamedTuple{F} end -numobs(data::NamedTupleData) = numobs(getfield(data, :data)) +Base.length(data::NamedTupleData) = numobs(getfield(data, :data)) -function getobs(data::NamedTupleData{TData,F}, idx::Int) where {TData,F} +function Base.getindex(data::NamedTupleData{TData,F}, idx::Int) where {TData,F} obs = getobs(getfield(data, :data), idx) namedfs = getfield(data, :namedfs) return NamedTuple{F}(f(obs) for f in namedfs) @@ -126,16 +126,16 @@ end # joinumobs -struct JoinedData{T,N} +struct JoinedData{T,N} <: AbstractDataContainer datas::NTuple{N,T} ns::NTuple{N,Int} end JoinedData(datas) = JoinedData(datas, numobs.(datas)) -numobs(data::JoinedData) = sum(data.ns) +Base.length(data::JoinedData) = sum(data.ns) -function getobs(data::JoinedData, idx) +function Base.getindex(data::JoinedData, idx) for (i, n) in enumerate(data.ns) if idx <= n return getobs(data.datas[i], idx) diff --git a/src/obsview.jl b/src/obsview.jl index 0127403..a8ba1af 100644 --- a/src/obsview.jl +++ b/src/obsview.jl @@ -178,11 +178,10 @@ end Base.IteratorEltype(::Type{<:ObsView}) = Base.EltypeUnknown() -# override AbstractDataContainer defaults Base.getindex(subset::ObsView, idx) = obsview(subset.data, subset.indices[idx]) -numobs(subset::ObsView) = length(subset.indices) +Base.length(subset::ObsView) = length(subset.indices) getobs(subset::ObsView) = getobs(subset.data, subset.indices) getobs(subset::ObsView, idx) = getobs(subset.data, subset.indices[idx]) From a771ceaadad9b753641e7e91fc6c651a23a93452 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 11:13:06 -0600 Subject: [PATCH 2/3] Add back size(::AbstractDataContainer) --- src/observation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/observation.jl b/src/observation.jl index e362368..f831a56 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -74,6 +74,7 @@ getobs!(buffer, data, idx) = getobs(data, idx) abstract type AbstractDataContainer end +Base.size(x::AbstractDataContainer) = (numobs(x),) Base.iterate(x::AbstractDataContainer, state = 1) = (state > numobs(x)) ? nothing : (getobs(x, state), state + 1) Base.lastindex(x::AbstractDataContainer) = numobs(x) From c0a19e33e8f69383e881d7c5436f5887b42b6f47 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 11:27:27 -0600 Subject: [PATCH 3/3] Override iterate for BatchView --- src/batchview.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/batchview.jl b/src/batchview.jl index bf7bc33..c1cd63a 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -119,6 +119,10 @@ function Base.getindex(A::BatchView, is::AbstractVector) obsview(A.data, obsindices) end +# override AbstractDataContainer default +Base.iterate(A::BatchView, state = 1) = + (state > numobs(A)) ? nothing : (A[state], state + 1) + obsview(A::BatchView) = A obsview(A::BatchView, i) = A[i]