Skip to content

Commit 67d9d08

Browse files
add batch_sequence (#197)
1 parent 8d3b7d3 commit 67d9d08

File tree

8 files changed

+274
-212
lines changed

8 files changed

+274
-212
lines changed

Diff for: docs/src/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ shuffleobs
2727

2828
```@docs
2929
batch
30+
batch_sequence
3031
batchsize
3132
batchseq
3233
BatchView

Diff for: src/MLUtils.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,14 @@ export slidingwindow
6363
include("splitobs.jl")
6464
export splitobs
6565

66-
include("utils.jl")
66+
include("batch.jl")
6767
export batch,
6868
batchseq,
69-
chunk,
69+
batch_sequence,
70+
unbatch
71+
72+
include("utils.jl")
73+
export chunk,
7074
falses_like,
7175
fill_like,
7276
flatten,
@@ -79,7 +83,6 @@ export batch,
7983
rpad_constant,
8084
stack, # in Base since julia v1.9
8185
trues_like,
82-
unbatch,
8386
unsqueeze,
8487
unstack,
8588
zeros_like

Diff for: src/batch.jl

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
2+
"""
3+
batch(xs)
4+
5+
Batch the arrays in `xs` into a single array with
6+
an extra dimension.
7+
8+
If the elements of `xs` are tuples, named tuples, or dicts,
9+
the output will be of the same type.
10+
11+
See also [`unbatch`](@ref) and [`batch_sequence`](@ref).
12+
13+
# Examples
14+
15+
```jldoctest
16+
julia> batch([[1,2,3],
17+
[4,5,6]])
18+
3×2 Matrix{Int64}:
19+
1 4
20+
2 5
21+
3 6
22+
23+
julia> batch([(a=[1,2], b=[3,4])
24+
(a=[5,6], b=[7,8])])
25+
(a = [1 5; 2 6], b = [3 7; 4 8])
26+
```
27+
"""
28+
function batch(xs)
29+
# Fallback for generric iterables
30+
@assert length(xs) > 0 "Input should be non-empty"
31+
data = first(xs) isa AbstractArray ?
32+
similar(first(xs), size(first(xs))..., length(xs)) :
33+
Vector{eltype(xs)}(undef, length(xs))
34+
for (i, x) in enumerate(xs)
35+
data[batchindex(data, i)...] = x
36+
end
37+
return data
38+
end
39+
40+
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
41+
42+
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)
43+
44+
function batch(xs::Vector{<:Tuple})
45+
@assert length(xs) > 0 "Input should be non-empty"
46+
n = length(first(xs))
47+
@assert all(length.(xs) .== n) "Cannot batch tuples with different lengths"
48+
return ntuple(i -> batch([x[i] for x in xs]), n)
49+
end
50+
51+
function batch(xs::Vector{<:NamedTuple})
52+
@assert length(xs) > 0 "Input should be non-empty"
53+
all_keys = [sort(collect(keys(x))) for x in xs]
54+
ks = all_keys[1]
55+
@assert all(==(ks), all_keys) "Cannot batch named tuples with different keys"
56+
return NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
57+
end
58+
59+
function batch(xs::Vector{<:Dict})
60+
@assert length(xs) > 0 "Input should be non-empty"
61+
all_keys = [sort(collect(keys(x))) for x in xs]
62+
ks = all_keys[1]
63+
@assert all(==(ks), all_keys) "cannot batch dicts with different keys"
64+
return Dict(k => batch([x[k] for x in xs]) for k in ks)
65+
end
66+
67+
"""
68+
unbatch(x)
69+
70+
Reverse of the [`batch`](@ref) operation,
71+
unstacking the last dimension of the array `x`.
72+
73+
See also [`unstack`](@ref) and [`chunk`](@ref).
74+
75+
# Examples
76+
77+
```jldoctest
78+
julia> unbatch([1 3 5 7;
79+
2 4 6 8])
80+
4-element Vector{Vector{Int64}}:
81+
[1, 2]
82+
[3, 4]
83+
[5, 6]
84+
[7, 8]
85+
```
86+
"""
87+
unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
88+
unbatch(x::AbstractVector) = x
89+
90+
"""
91+
batchseq(seqs, val = 0)
92+
93+
Take a list of `N` sequences, and turn them into a single sequence where each
94+
item is a batch of `N`. Short sequences will be padded by `val`.
95+
96+
# Examples
97+
98+
```jldoctest
99+
julia> batchseq([[1, 2, 3], [4, 5]], 0)
100+
3-element Vector{Vector{Int64}}:
101+
[1, 4]
102+
[2, 5]
103+
[3, 0]
104+
```
105+
"""
106+
function batchseq(xs, val = 0)
107+
n = maximum(numobs, xs)
108+
xs_ = [rpad_constant(x, n, val; dims=ndims(x)) for x in xs]
109+
return [batch([getobs(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
110+
end
111+
112+
"""
113+
batch_sequence(seqs; pad = 0)
114+
115+
Take a list of `N` sequences `seqs`,
116+
where the `i`-th sequence is an array with last dimension `Li`,
117+
and turn the into a single array with size `(..., Lmax, N)`.
118+
119+
The sequences need to have the same size, except for the last dimension.
120+
121+
Short sequences will be padded by `pad`.
122+
123+
See also [`batch`](@ref).
124+
125+
# Examples
126+
127+
```jldoctest
128+
julia> batch_sequence([[1, 2, 3], [10, 20]])
129+
3×2 Matrix{Int64}:
130+
1 10
131+
2 20
132+
3 0
133+
134+
julia> seqs = (ones(2, 3), fill(2.0, (2, 5)))
135+
([1.0 1.0 1.0; 1.0 1.0 1.0], [2.0 2.0 … 2.0 2.0; 2.0 2.0 … 2.0 2.0])
136+
137+
julia> batch_sequence(seqs, pad=-1)
138+
2×5×2 Array{Float64, 3}:
139+
[:, :, 1] =
140+
1.0 1.0 1.0 -1.0 -1.0
141+
1.0 1.0 1.0 -1.0 -1.0
142+
143+
[:, :, 2] =
144+
2.0 2.0 2.0 2.0 2.0
145+
2.0 2.0 2.0 2.0 2.0
146+
```
147+
"""
148+
function batch_sequence(xs; pad = 0)
149+
sz = size(xs[1])[1:end-1]
150+
@assert all(x -> size(x)[1:end-1] == sz, xs) "Array dimensions do not match."
151+
n = ndims(xs[1])
152+
Lmax = maximum(numobs, xs)
153+
padded_seqs = [rpad_constant(x, Lmax, pad, dims=n) for x in xs]
154+
return batch(padded_seqs)
155+
end

Diff for: src/folds.jl

+19-20
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,29 @@ Compute the train/validation assignments for `k` repartitions of
77
first vector contains the index-vectors for the training subsets,
88
and the second vector the index-vectors for the validation subsets
99
respectively. A general rule of thumb is to use either `k = 5` or
10-
`k = 10`. The following code snippet generates the indices
11-
assignments for `k = 5`
12-
13-
```julia
14-
julia> train_idx, val_idx = kfolds(10, 5);
15-
```
10+
`k = 10`.
1611
1712
Each observation is assigned to the validation subset once (and
1813
only once). Thus, a union over all validation index-vectors
1914
reproduces the full range `1:n`. Note that there is no random
2015
assignment of observations to subsets, which means that adjacent
2116
observations are likely to be part of the same validation subset.
2217
23-
```julia
18+
# Examples
19+
20+
```jldoctest
21+
julia> train_idx, val_idx = kfolds(10, 5);
22+
2423
julia> train_idx
25-
5-element Array{Array{Int64,1},1}:
26-
[3,4,5,6,7,8,9,10]
27-
[1,2,5,6,7,8,9,10]
28-
[1,2,3,4,7,8,9,10]
29-
[1,2,3,4,5,6,9,10]
30-
[1,2,3,4,5,6,7,8]
24+
5-element Vector{Vector{Int64}}:
25+
[3, 4, 5, 6, 7, 8, 9, 10]
26+
[1, 2, 5, 6, 7, 8, 9, 10]
27+
[1, 2, 3, 4, 7, 8, 9, 10]
28+
[1, 2, 3, 4, 5, 6, 9, 10]
29+
[1, 2, 3, 4, 5, 6, 7, 8]
3130
3231
julia> val_idx
33-
5-element Array{UnitRange{Int64},1}:
32+
5-element Vector{UnitRange{Int64}}:
3433
1:2
3534
3:4
3635
5:6
@@ -42,7 +41,7 @@ function kfolds(n::Integer, k::Integer = 5)
4241
2 <= k <= n || throw(ArgumentError("n must be positive and k must to be within 2:$(max(2,n))"))
4342
# Compute the size of each fold. This is important because
4443
# in general the number of total observations might not be
45-
# divideable by k. In such cases it is custom that the remaining
44+
# divisible by k. In such cases it is custom that the remaining
4645
# observations are divided among the folds. Thus some folds
4746
# have one more observation than others.
4847
sizes = fill(floor(Int, n/k), k)
@@ -52,15 +51,15 @@ function kfolds(n::Integer, k::Integer = 5)
5251
# Compute start offset for each fold
5352
offsets = cumsum(sizes) .- sizes .+ 1
5453
# Compute the validation indices using the offsets and sizes
55-
val_indices = map((o,s)->(o:o+s-1), offsets, sizes)
54+
val_indices = map((o,s) -> (o:o+s-1), offsets, sizes)
5655
# The train indices are then the indicies not in validation
57-
train_indices = map(idx->setdiff(1:n,idx), val_indices)
56+
train_indices = map(idx -> setdiff(1:n, idx), val_indices)
5857
# We return a tuple of arrays
59-
train_indices, val_indices
58+
return train_indices, val_indices
6059
end
6160

6261
"""
63-
kfolds(data, [k = 5])
62+
kfolds(data, k = 5)
6463
6564
Repartition a `data` container `k` times using a `k` folds
6665
strategy and return the sequence of folds as a lazy iterator.
@@ -96,7 +95,7 @@ By default the folds are created using static splits. Use
9695
folds.
9796
9897
```julia
99-
for (x_train, x_val) in kfolds(shuffleobs(X), k = 10)
98+
for (x_train, x_val) in kfolds(shuffleobs(X), k=10)
10099
# ...
101100
end
102101
```

Diff for: src/utils.jl

+1-110
Original file line numberDiff line numberDiff line change
@@ -326,116 +326,6 @@ function group_indices(classes::T) where T<:AbstractVector
326326
end
327327

328328

329-
"""
330-
batch(xs)
331-
332-
Batch the arrays in `xs` into a single array with
333-
an extra dimension.
334-
335-
If the elements of `xs` are tuples, named tuples, or dicts,
336-
the output will be of the same type.
337-
338-
See also [`unbatch`](@ref).
339-
340-
# Examples
341-
342-
```jldoctest
343-
julia> batch([[1,2,3],
344-
[4,5,6]])
345-
3×2 Matrix{Int64}:
346-
1 4
347-
2 5
348-
3 6
349-
350-
julia> batch([(a=[1,2], b=[3,4])
351-
(a=[5,6], b=[7,8])])
352-
(a = [1 5; 2 6], b = [3 7; 4 8])
353-
```
354-
"""
355-
function batch(xs)
356-
# Fallback for generric iterables
357-
@assert length(xs) > 0 "Input should be non-empty"
358-
data = first(xs) isa AbstractArray ?
359-
similar(first(xs), size(first(xs))..., length(xs)) :
360-
Vector{eltype(xs)}(undef, length(xs))
361-
for (i, x) in enumerate(xs)
362-
data[batchindex(data, i)...] = x
363-
end
364-
return data
365-
end
366-
367-
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
368-
369-
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)
370-
371-
function batch(xs::Vector{<:Tuple})
372-
@assert length(xs) > 0 "Input should be non-empty"
373-
n = length(first(xs))
374-
@assert all(length.(xs) .== n) "Cannot batch tuples with different lengths"
375-
return ntuple(i -> batch([x[i] for x in xs]), n)
376-
end
377-
378-
function batch(xs::Vector{<:NamedTuple})
379-
@assert length(xs) > 0 "Input should be non-empty"
380-
all_keys = [sort(collect(keys(x))) for x in xs]
381-
ks = all_keys[1]
382-
@assert all(==(ks), all_keys) "Cannot batch named tuples with different keys"
383-
return NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
384-
end
385-
386-
function batch(xs::Vector{<:Dict})
387-
@assert length(xs) > 0 "Input should be non-empty"
388-
all_keys = [sort(collect(keys(x))) for x in xs]
389-
ks = all_keys[1]
390-
@assert all(==(ks), all_keys) "cannot batch dicts with different keys"
391-
return Dict(k => batch([x[k] for x in xs]) for k in ks)
392-
end
393-
394-
"""
395-
unbatch(x)
396-
397-
Reverse of the [`batch`](@ref) operation,
398-
unstacking the last dimension of the array `x`.
399-
400-
See also [`unstack`](@ref) and [`chunk`](@ref).
401-
402-
# Examples
403-
404-
```jldoctest
405-
julia> unbatch([1 3 5 7;
406-
2 4 6 8])
407-
4-element Vector{Vector{Int64}}:
408-
[1, 2]
409-
[3, 4]
410-
[5, 6]
411-
[7, 8]
412-
```
413-
"""
414-
unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
415-
unbatch(x::AbstractVector) = x
416-
417-
"""
418-
batchseq(seqs, val = 0)
419-
420-
Take a list of `N` sequences, and turn them into a single sequence where each
421-
item is a batch of `N`. Short sequences will be padded by `val`.
422-
423-
# Examples
424-
425-
```jldoctest
426-
julia> batchseq([[1, 2, 3], [4, 5]], 0)
427-
3-element Vector{Vector{Int64}}:
428-
[1, 4]
429-
[2, 5]
430-
[3, 0]
431-
```
432-
"""
433-
function batchseq(xs, val = 0)
434-
n = maximum(numobs, xs)
435-
xs_ = [rpad_constant(x, n, val; dims=ndims(x)) for x in xs]
436-
return [batch([getobs(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
437-
end
438-
439329
"""
440330
rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
441331
@@ -765,3 +655,4 @@ function rrule(::typeof(fill_like), x::AbstractArray, val, T::Type, sz)
765655
end
766656
return fill_like(x, val, T, sz), fill_like_pullback
767657
end
658+

0 commit comments

Comments
 (0)