Skip to content

Commit 04d08aa

Browse files
authored
Add more base types (FluxML#47)
* add more base types * add tests
1 parent 75fef99 commit 04d08aa

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

src/base.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11

22
@functor Base.RefValue
33

4+
@functor Base.Pair
5+
6+
@functor Base.Generator # aka Iterators.map
7+
48
functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)
59

10+
@functor Base.Fix1
11+
@functor Base.Fix2
12+
613
###
714
### Array wrappers
815
###
@@ -36,3 +43,26 @@ end
3643
_PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm)
3744
_PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent
3845
_PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm)
46+
47+
###
48+
### Iterators
49+
###
50+
51+
@functor Iterators.Accumulate
52+
# Count
53+
@functor Iterators.Cycle
54+
@functor Iterators.Drop
55+
@functor Iterators.DropWhile
56+
@functor Iterators.Enumerate
57+
@functor Iterators.Filter
58+
@functor Iterators.Flatten
59+
# IterationCutShort
60+
@functor Iterators.PartitionIterator
61+
@functor Iterators.ProductIterator
62+
@functor Iterators.Repeated
63+
@functor Iterators.Rest
64+
@functor Iterators.Reverse
65+
# Stateful
66+
@functor Iterators.Take
67+
@functor Iterators.TakeWhile
68+
@functor Iterators.Zip

test/base.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ end
3434
@test fmap(x -> x + 10, f1 f2) == Foo(11.1, 12.2) Bar(13.3)
3535
end
3636

37+
@testset "Pair, Fix12" begin
38+
@test fmap(sqrt, 4 => 9) === (2.0 => 3.0)
39+
40+
exclude = x -> x isa Number
41+
@test fmap(sqrt, Base.Fix1(/, 4); exclude)(10) == 0.2
42+
@test fmap(sqrt, Base.Fix2(/, 4); exclude)(10) == 5.0
43+
end
44+
3745
@testset "LinearAlgebra containers" begin
3846
@test fmapstructure(identity, [1,2,3]') == (parent = [1, 2, 3],)
3947
@test fmapstructure(identity, transpose([1,2,3])) == (parent = [1, 2, 3],)
@@ -84,3 +92,63 @@ end
8492
@test fmapstructure(identity, PermutedDimsArray([1 2; 3 4], (2,1))) == (parent = [1 2; 3 4],)
8593
@test fmap(exp, PermutedDimsArray([1 2; 3 4], (2,1))) isa PermutedDimsArray{Float64}
8694
end
95+
96+
@testset "Iterators" begin
97+
exclude = x -> x isa Array
98+
99+
x = fmap(complex, Iterators.map(sqrt, [1,2,3]); exclude) # Base.Generator
100+
@test x.iter isa Vector{<:Complex}
101+
@test collect(x) isa Vector{<:Complex}
102+
103+
x = fmap(complex, Iterators.accumulate(/, [1,2,3]); exclude)
104+
@test x.itr isa Vector{<:Complex}
105+
@test collect(x) isa Vector{<:Complex}
106+
107+
x = fmap(complex, Iterators.cycle([1,2,3]))
108+
@test x.xs isa Vector{<:Complex}
109+
@test first(x) isa Complex
110+
111+
x = fmap(complex, Iterators.drop([1,2,3], 1); exclude)
112+
@test x.xs isa Vector{<:Complex}
113+
@test collect(x) isa Vector{<:Complex}
114+
115+
116+
x = fmap(complex, Iterators.drop([1,2,3], 1); exclude)
117+
@test x.xs isa Vector{<:Complex}
118+
@test collect(x) isa Vector{<:Complex}
119+
120+
x = fmap(float, Iterators.dropwhile(<(2), [1,2,3]); exclude)
121+
@test x.xs isa Vector{Float64}
122+
@test collect(x) isa Vector{Float64}
123+
124+
x = fmap(complex, enumerate([1,2,3]))
125+
@test first(x) === (1, 1+0im)
126+
127+
x = fmap(float, Iterators.filter(<(3), [1,2,3]); exclude)
128+
@test collect(x) isa Vector{Float64}
129+
130+
x = fmap(complex, Iterators.flatten(([1,2,3], [4,5])))
131+
@test collect(x) isa Vector{<:Complex}
132+
133+
x = fmap(complex, Iterators.partition([1,2,3],2); exclude)
134+
@test first(x) isa AbstractVector{<:Complex}
135+
136+
x = fmap(complex, Iterators.product([1,2,3],[4,5]))
137+
@test first(x) === (1 + 0im, 4 + 0im)
138+
139+
x = fmap(complex, Iterators.repeated([1,2,3], 4); exclude) # Iterators.Take{Iterators.Repeated}
140+
@test first(x) isa Vector{<:Complex}
141+
142+
x = fmap(complex, Iterators.rest([1,2,3], 2); exclude)
143+
@test collect(x) isa Vector{<:Complex}
144+
145+
x = fmap(complex, Iterators.reverse([1,2,3]))
146+
@test collect(x) isa Vector{<:Complex}
147+
148+
x = fmap(float, Iterators.takewhile(<(2), [1,2,3]); exclude)
149+
@test collect(x) isa Vector{Float64}
150+
151+
x = fmap(complex, zip([1,2,3], [4,5]))
152+
@test x.is[1] isa Vector{<:Complex}
153+
@test collect(x) isa Vector{<:Tuple{Complex, Complex}}
154+
end

0 commit comments

Comments
 (0)