diff --git a/src/batchview.jl b/src/batchview.jl
index 7778dae..f5c1bad 100644
--- a/src/batchview.jl
+++ b/src/batchview.jl
@@ -134,26 +134,54 @@ Base.@propagate_inbounds function getobs(A::BatchView)
     return _getbatch(A, 1:numobs(A.data))
 end
 
-Base.@propagate_inbounds function Base.getindex(A::BatchView, i::Int)
-    obsindices = _batchrange(A, i)
+Base.@propagate_inbounds function Base.getindex(A::BatchView, i)
+    obsindices = _batchindexes(A, i)
     _getbatch(A, obsindices)
 end
 
-Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector)
-    obsindices = union((_batchrange(A, i) for i in is)...)::Vector{Int}
-    _getbatch(A, obsindices)
-end
-
-function _getbatch(A::BatchView{TElem, TData, Val{true}}, obsindices) where {TElem, TData}
+function _getbatch(A::BatchView{<:Any, <:Any, Val{true}}, obsindices)
     batch([getobs(A.data, i) for i in obsindices])
 end
-function _getbatch(A::BatchView{TElem, TData, Val{false}}, obsindices) where {TElem, TData}
+function _getbatch(A::BatchView{<:Any, <:Any, Val{false}}, obsindices)
     return [getobs(A.data, i) for i in obsindices]
 end
-function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData}
+function _getbatch(A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices)
     getobs(A.data, obsindices)
 end
 
+# In-place counterpart of `getindex`: load batch(es) `i` of `A` into `buffer`.
+# `i` is a batch index or a vector of batch indices, resolved by `_batchindexes`.
+function getobs!(buffer, A::BatchView, i)
+    obsindices = _batchindexes(A, i)
+    return _getbatch!(buffer, A, obsindices)
+end
+
+# collate=nothing: delegate to the parent's `getobs!` with all indices at once.
+function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices)
+    return getobs!(buffer, A.data, obsindices)
+end
+
+# collate=true: fill one buffer slot per observation, then collate with `batch`.
+# The returned value comes from `batch(buffer)`, not `buffer` itself, so this
+# specialization is rarely useful; prefer collate=nothing instead.
+function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{true}}, obsindices)
+    for (i, idx) in enumerate(obsindices)
+        # Keep the returned value: `getobs!` is allowed to ignore the buffer it
+        # is given and return a fresh observation instead.
+        buffer[i] = getobs!(buffer[i], A.data, idx)
+    end
+    return batch(buffer)
+end
+
+# collate=false: `buffer` holds one slot per observation and is returned as-is.
+function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{false}}, obsindices)
+    for (i, idx) in enumerate(obsindices)
+        # Store the returned value in case `getobs!` did not write in place.
+        buffer[i] = getobs!(buffer[i], A.data, idx)
+    end
+    return buffer
+end
+
 Base.parent(A::BatchView) = A.data
 Base.eltype(::BatchView{Tel}) where Tel = Tel
 
@@ -169,6 +197,10 @@ Base.iterate(A::BatchView, state = 1) =
     return startidx:endidx
 end
 
+# Translate a batch index, or a vector of batch indices, into observation indices.
+@inline _batchindexes(A::BatchView, i::Integer) = _batchrange(A, i)
+@inline _batchindexes(A::BatchView, is::AbstractVector{<:Integer}) = union((_batchrange(A, i) for i in is)...)::Vector{Int}
+
 function Base.showarg(io::IO, A::BatchView, toplevel)
     print(io, "BatchView(")
     Base.showarg(io, parent(A), false)
@@ -178,5 +210,3 @@ function Base.showarg(io::IO, A::BatchView, toplevel)
     print(io, ')')
     toplevel && print(io, " with eltype ", nameof(eltype(A))) # simplify
 end
-
-# --------------------------------------------------------------------
diff --git a/test/batchview.jl b/test/batchview.jl
index 39fbdcf..f4ca79d 100644
--- a/test/batchview.jl
+++ b/test/batchview.jl
@@ -116,4 +116,36 @@ using MLUtils: obsview
         @test bv[2] == 6:10
         @test_throws BoundsError bv[3]
     end
+
+
+    @testset "getobs!" begin
+        buf1 = rand(4, 3)
+        bv = BatchView(X, batchsize=3)
+        @test @inferred(getobs!(buf1, bv, 2)) === buf1
+        @test buf1 == getobs(bv, 2)
+
+        buf12 = [rand(4) for _=1:3]
+        bv12 = BatchView(X, batchsize=3, collate=false)
+        @test @inferred(getobs!(buf12, bv12, 2)) === buf12
+        @test buf12 == getobs(bv12, 2)
+
+        buf2 = rand(4, 6)
+        @test @inferred(getobs!(buf2, bv, [1,3])) === buf2
+        @test buf2 == getobs(bv, [1,3])
+
+        @testset "custom type" begin # issue #156
+            struct DummyData{X}
+                x::X
+            end
+            MLUtils.numobs(data::DummyData) = numobs(data.x)
+            MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx)
+            MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx)
+
+            data = DummyData(X)
+            buf = rand(4, 3)
+            bv = BatchView(data, batchsize=3)
+            @test @inferred(getobs!(buf, bv, 2)) === buf
+            @test buf == getobs(bv, 2)
+        end
+    end
 end