diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bfec232..2a4cf15 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,8 +17,9 @@ jobs: fail-fast: false matrix: version: - - '1.5' # Replace this with the minimum Julia version that your package supports. - # - '1' # automatically expands to the latest stable 1.x release of Julia + - '1.0' + - '1.6' # Replace this with the minimum Julia version that your package supports. + - '1' # automatically expands to the latest stable 1.x release of Julia - 'nightly' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 1f8b8a9..84dc030 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,11 @@ name = "Functors" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" authors = ["Mike J Innes "] -version = "0.2.7" +version = "0.2.8" [compat] -julia = "1" Documenter = "0.27" +julia = "1" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/src/Functors.jl b/src/Functors.jl index dcf21f8..5247df1 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -23,6 +23,7 @@ usually using the macro [@functor](@ref). """ functor +@static if VERSION >= v"1.5" # var"@functor" doesn't work on 1.0, temporarily disable """ @functor T @functor T (x,) @@ -65,6 +66,7 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560) ``` """ var"@functor" +end # VERSION """ Functors.isleaf(x) @@ -182,6 +184,16 @@ This function walks (maps) over `xs` calling the continuation `f'` to continue t julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7)))) ``` + +The behaviour when the same node appears twice can be altered by giving a value +to the `prune` keyword, which is then used in place of all but the first: + +```jldoctest +julia> twice = [1, 2]; + +julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing) +(x = [1.0, 2.0], y = [1.0, 2.0], z = missing) +``` """ fmap diff --git a/src/functor.jl b/src/functor.jl index 608dbd3..b8cb8f5 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x functor(x) = functor(typeof(x), x) functor(::Type{<:Tuple}, x) = x, y -> y -functor(::Type{<:NamedTuple}, x) = x, y -> y +functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity functor(::Type{<:AbstractArray}, x) = x, y -> y functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x @@ -43,12 +43,11 @@ function _default_walk(f, x) re(map(f, func)) end -function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) - haskey(cache, x) && return cache[x] - y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) - cache[x] = y +struct NoKeyword end - return y +function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword()) + haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune + cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x) end ### @@ -74,27 +73,16 @@ end ### Vararg forms ### -function fmap(f, x, dx...; cache = IdDict()) - haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...) +function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword()) + haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune + cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...) end -function functor_tuple(f, x::Tuple, dx::Tuple) - map(x, dx) do x, x̄ - _default_walk(f, x, x̄) - end -end -functor_tuple(f, x, dx) = f(x, dx) -functor_tuple(f, x, ::Nothing) = x - -function _default_walk(f, x, dx) +function _default_walk(f, x, ys...) func, re = functor(x) - map(func, dx) do x, x̄ - # functor_tuple(f, x, x̄) - f(x, x̄) - end |> re + yfuncs = map(y -> functor(typeof(x), y)[1], ys) + re(map(f, func, yfuncs...)) end -_default_walk(f, ::Nothing, ::Nothing) = nothing ### ### FlexibleFunctors.jl @@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield) func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields)) return func, re end - end - end function flexiblefunctorm(T, pfield = :params) diff --git a/test/basics.jl b/test/basics.jl index 21aa445..914a5cd 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,25 +1,16 @@ -struct Foo - x - y -end + +using Functors: functor + +struct Foo; x; y; end @functor Foo -struct Bar - x -end +struct Bar; x; end @functor Bar -struct Baz - x - y - z -end -@functor Baz (y,) +struct OneChild3; x; y; z; end +@functor OneChild3 (y,) -struct NoChildren - x - y -end +struct NoChildren2; x; y; end @static if VERSION >= v"1.6" @testset "ComposedFunction" begin @@ -31,6 +22,10 @@ end end end +### +### Basic functionality +### + @testset "Nested" begin model = Bar(Foo(1, [1, 2, 3])) @@ -53,6 +48,73 @@ end @test fmap(f, x; exclude = x -> x isa AbstractArray) == x end +@testset "Property list" begin + model = OneChild3(1, 2, 3) + model′ = fmap(x -> 2x, model) + + @test (model′.x, model′.y, model′.z) == (1, 4, 3) +end + +@testset "cache" begin + shared = [1,2,3] + m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3]))) + m1f = fmap(float, m1) + @test m1f.x === m1f.y.y.x + @test m1f.x !== m1f.y.x + m1p = fmapstructure(identity, m1; prune = nothing) + @test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3]))) + + # A non-leaf node can also be repeated: + m2 = Foo(Foo(shared, 4), Foo(shared, 4)) + @test m2.x === m2.y + m2f = fmap(float, m2) + @test m2f.x.x === m2f.y.x + m2p = fmapstructure(identity, m2; prune = Bar(0)) + @test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0)) + + # Repeated isbits types should not automatically be regarded as shared: + m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared)) + m3p = fmapstructure(identity, m3; prune = 0) + @test m3p.y.y == 0 + @test_broken m3p.y.x == 1:3 +end + +@testset "functor(typeof(x), y) from @functor" begin + nt1, re1 = functor(Foo, (x=1, y=2, z=3)) + @test nt1 == (x = 1, y = 2) + @test re1((x = 10, y = 20)) == Foo(10, 20) + re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug? + + nt2, re2 = functor(Foo, (z=33, x=1, y=2)) + @test nt2 == (x = 1, y = 2) + @test re2((x = 10, y = 20)) == Foo(10, 20) + + @test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y + + nt3, re3 = functor(OneChild3, (x=1, y=2, z=3)) + @test nt3 == (y = 2,) + @test re3((y = 20,)) == OneChild3(1, 20, 3) + re3(22) # gives OneChild3(1, 22, 3), is that a bug? +end + +@testset "functor(typeof(x), y) for Base types" begin + nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3)) + @test nt11 == (x = 1, y = 2) + @test re11((x = 10, y = 20)) == (x = 10, y = 20) + re11((y = 22, x = 11)) + re11((11, 22)) # passes right through + + nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2)) + @test nt12 == (x = 1, y = 2) + @test re12((x = 10, y = 20)) == (x = 10, y = 20) + + @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) +end + +### +### Extras +### + @testset "Walk" begin model = Foo((0, Bar([1, 2, 3])), [4, 5]) @@ -60,13 +122,6 @@ end @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) end -@testset "Property list" begin - model = Baz(1, 2, 3) - model′ = fmap(x -> 2x, model) - - @test (model′.x, model′.y, model′.z) == (1, 4, 3) -end - @testset "fcollect" begin m1 = [1, 2, 3] m2 = 1 @@ -78,7 +133,7 @@ end m1 = [1, 2, 3] m2 = Bar(m1) - m0 = NoChildren(:a, :b) + m0 = NoChildren2(:a, :b) m3 = Foo(m2, m0) m4 = Bar(m3) @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) @@ -89,6 +144,79 @@ end @test all(fcollect(m3) .=== [m3, m1, m2]) end +### +### Vararg forms +### + +@testset "fmap(f, x, y)" begin + m1 = (x = [1,2], y = 3) + n1 = (x = [4,5], y = 6) + @test fmap(+, m1, n1) == (x = [5, 7], y = 9) + + # Reconstruction type comes from the first argument + foo1 = Foo([7,8], 9) + @test fmap(+, m1, foo1) == (x = [8, 10], y = 12) + @test fmap(+, foo1, n1) isa Foo + @test fmap(+, foo1, n1).x == [11, 13] + + # Mismatched trees should be an error + m2 = (x = [1,2], y = (a = [3,4], b = 5)) + n2 = (x = [6,7], y = 8) + @test_throws Exception fmap(first∘tuple, m2, n2) # ERROR: type Int64 has no field a + @test_throws Exception fmap(first∘tuple, m2, n2) + + # The cache uses IDs from the first argument + shared = [1,2,3] + m3 = (x = shared, y = [4,5,6], z = shared) + n3 = (x = shared, y = shared, z = [7,8,9]) + @test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6]) + z3 = fmap(+, m3, n3) + @test z3.x === z3.z + + # Pruning of duplicates: + @test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing) + + # More than two arguments: + z4 = fmap(+, m3, n3, m3, n3) + @test z4 == fmap(x -> 2x, z3) + @test z4.x === z4.z + + @test fmap(+, foo1, m1, n1) isa Foo + @static if VERSION >= v"1.6" # fails on Julia 1.0 + @test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9) + end +end + +@static if VERSION >= v"1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope +@testset "old test update.jl" begin + struct M{F,T,S} + σ::F + W::T + b::S + end + + @functor M + + (m::M)(x) = m.σ.(m.W * x .+ m.b) + + m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3)) + x = ones(Float32, 4, 2) + m̄, _ = gradient((m,x) -> sum(m(x)), m, x) + m̂ = Functors.fmap(m, m̄) do x, y + isnothing(x) && return y + isnothing(y) && return x + x .- 0.1f0 .* y + end + + @test m̂.W ≈ fill(0.8f0, size(m.W)) + @test m̂.b ≈ fill(-0.2f0, size(m.b)) +end +end # VERSION + +### +### FlexibleFunctors.jl +### + struct FFoo x y @@ -102,13 +230,13 @@ struct FBar end @flexiblefunctor FBar p -struct FBaz +struct FOneChild4 x y z p end -@flexiblefunctor FBaz p +@flexiblefunctor FOneChild4 p @testset "Flexible Nested" begin model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) @@ -132,7 +260,7 @@ end end @testset "Flexible Property list" begin - model = FBaz(1, 2, 3, (:x, :z)) + model = FOneChild4(1, 2, 3, (:x, :z)) model′ = fmap(x -> 2x, model) @test (model′.x, model′.y, model′.z) == (2, 2, 6) @@ -147,7 +275,7 @@ end @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) - m0 = NoChildren(:a, :b) + m0 = NoChildren2(:a, :b) m1 = [1, 2, 3] m2 = FBar(m1, ()) m3 = FFoo(m2, m0, (:x, :y,)) diff --git a/test/runtests.jl b/test/runtests.jl index 394cfbf..5b853fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,6 @@ using Zygote include("basics.jl") include("base.jl") - include("update.jl") if VERSION < v"1.6" # || VERSION > v"1.7-" @warn "skipping doctests, on Julia $VERSION" diff --git a/test/update.jl b/test/update.jl deleted file mode 100644 index 0ed6bca..0000000 --- a/test/update.jl +++ /dev/null @@ -1,23 +0,0 @@ -@testset "Generalized fmap over equivalent functors" begin - struct M{F,T,S} - σ::F - W::T - b::S - end - - @functor M - - (m::M)(x) = m.σ.(m.W * x .+ m.b) - - m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3)) - x = ones(Float32, 4, 2) - m̄, _ = gradient((m,x) -> sum(m(x)), m, x) - m̂ = Functors.fmap(m, m̄) do x, y - isnothing(x) && return y - isnothing(y) && return x - x .- 0.1f0 .* y - end - - @test m̂.W ≈ fill(0.8f0, size(m.W)) - @test m̂.b ≈ fill(-0.2f0, size(m.b)) -end