Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make fmap(f, x, y) useful #37

Merged
merged 13 commits into from
Feb 9, 2022
36 changes: 11 additions & 25 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x
functor(x) = functor(typeof(x), x)

functor(::Type{<:Tuple}, x) = x, y -> y
functor(::Type{<:NamedTuple}, x) = x, y -> y
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity

functor(::Type{<:AbstractArray}, x) = x, y -> y
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
Expand Down Expand Up @@ -43,12 +43,11 @@ function _default_walk(f, x)
re(map(f, func))
end

function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y
const INIT = Base._InitialValue() # sentinel value for keyword not supplied
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe less opaque to just define struct NoPrune end?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do. Are you OK with the interface being that the absence of this keyword is how you turn this off? That's how init often works; and here, the values you might use for "don't prune" like false or nothing are all reasonable choices for the value to use during pruning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is weird to me, but the API makes sense. I'm happy.


return y
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = INIT)
haskey(cache, x) && return prune === INIT ? cache[x] : prune
cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To use keywords like this, this package cannot claim to support Julia 1.0. At present it is only tested on 1.5+. Maybe we should just move to 1.6?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have on CI, might as well make it official.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what repo I was looking at, because I just checked back and it's 1.5 still. Do we need to cut a breaking release for minimum version bumps like this again?

Copy link
Member Author

@mcabbott mcabbott Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's 1.0 here: https://github.com/FluxML/Functors.jl/blob/master/Project.toml#L7

Tests do in fact pass on 1.0, for Functors v0.2.7, despite the lack of CI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And, people say not to do this on a patch release, because it closes the door on a bugfix-for-1.0 release. Do we care?

We may also need a breaking release for #33, and to make the cache not used on isbits arguments. We could gang these together.

end

###
Expand All @@ -74,27 +73,16 @@ end
### Vararg forms
###

function fmap(f, x, dx...; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = INIT)
haskey(cache, x) && return prune === INIT ? cache[x] : prune
cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
end

function functor_tuple(f, x::Tuple, dx::Tuple)
map(x, dx) do x, x̄
_default_walk(f, x, x̄)
end
end
functor_tuple(f, x, dx) = f(x, dx)
functor_tuple(f, x, ::Nothing) = x

function _default_walk(f, x, dx)
function _default_walk(f, x, ys...)
func, re = functor(x)
map(func, dx) do x, x̄
# functor_tuple(f, x, x̄)
f(x, x̄)
end |> re
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(map(f, func, yfuncs...))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing

###
### FlexibleFunctors.jl
Expand All @@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield)
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
return func, re
end

end

end

function flexiblefunctorm(T, pfield = :params)
Expand Down
137 changes: 108 additions & 29 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
struct Foo
x
y
end

struct Foo; x; y; end
@functor Foo

struct Bar
x
end
struct Bar; x; end
@functor Bar

struct Baz
x
y
z
end
@functor Baz (y,)
struct OneChild3; x; y; z; end
@functor OneChild3 (y,)

struct NoChildren
x
y
end
struct NoChildren2; x; y; end

@static if VERSION >= v"1.6"
@testset "ComposedFunction" begin
Expand All @@ -31,6 +20,10 @@ end
end
end

###
### Basic functionality
###

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand All @@ -53,20 +46,48 @@ end
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
end

@testset "Property list" begin
model = OneChild3(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "cache" begin
shared = [1,2,3]
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
m1f = fmap(float, m1)
@test m1f.x === m1f.y.y.x
@test m1f.x !== m1f.y.x
m1p = fmapstructure(identity, m1; prune = nothing)
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))

# A non-leaf node can also be repeated:
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
@test m2.x === m2.y
m2f = fmap(float, m2)
@test m2f.x.x === m2f.y.x
m2p = fmapstructure(identity, m2; prune = Bar(0))
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))

# Repeated isbits types should not automatically be regarded as shared:
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
m3p = fmapstructure(identity, m3; prune = 0)
@test m3p.y.y == 0
@test_broken m3p.y.x == 1:3
end

###
### Extras
###

@testset "Walk" begin
model = Foo((0, Bar([1, 2, 3])), [4, 5])

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
end

@testset "Property list" begin
model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

@testset "fcollect" begin
m1 = [1, 2, 3]
m2 = 1
Expand All @@ -78,7 +99,7 @@ end

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
Expand All @@ -89,6 +110,64 @@ end
@test all(fcollect(m3) .=== [m3, m1, m2])
end

###
### Vararg forms
###

@testset "fmap(f, x, y)" begin
m1 = (x = [1,2], y = 3)
n1 = (x = [4,5], y = 6)
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)

# Reconstruction type comes from the first argument
foo1 = Foo([7,8], 9)
@test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
@test fmap(+, foo1, n1) isa Foo
@test fmap(+, foo1, n1).x == [11, 13]

# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_throws Exception fmap(first∘tuple, m2, n2)
@test_throws Exception fmap(first∘tuple, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
m3 = (x = shared, y = [4,5,6], z = shared)
n3 = (x = shared, y = shared, z = [7,8,9])
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
z3 = fmap(+, m3, n3)
@test z3.x === z3.z
end

@testset "old test update.jl" begin
struct M{F,T,S}
σ::F
W::T
b::S
end

@functor M

(m::M)(x) = m.σ.(m.W * x .+ m.b)

m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
x = ones(Float32, 4, 2)
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
m̂ = Functors.fmap(m, m̄) do x, y
isnothing(x) && return y
isnothing(y) && return x
x .- 0.1f0 .* y
end

@test m̂.W ≈ fill(0.8f0, size(m.W))
@test m̂.b ≈ fill(-0.2f0, size(m.b))
end

###
### FlexibleFunctors.jl
###

struct FFoo
x
y
Expand All @@ -102,13 +181,13 @@ struct FBar
end
@flexiblefunctor FBar p

struct FBaz
struct FOneChild4
x
y
z
p
end
@flexiblefunctor FBaz p
@flexiblefunctor FOneChild4 p

@testset "Flexible Nested" begin
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
Expand All @@ -132,7 +211,7 @@ end
end

@testset "Flexible Property list" begin
model = FBaz(1, 2, 3, (:x, :z))
model = FOneChild4(1, 2, 3, (:x, :z))
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (2, 2, 6)
Expand All @@ -147,7 +226,7 @@ end
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m0 = NoChildren2(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Zygote

include("basics.jl")
include("base.jl")
include("update.jl")

if VERSION < v"1.6" # || VERSION > v"1.7-"
@warn "skipping doctests, on Julia $VERSION"
Expand Down
23 changes: 0 additions & 23 deletions test/update.jl

This file was deleted.