diff --git a/Project.toml b/Project.toml index de630d6..5296f32 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Zygote"] +test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"] diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 5767fdf..deebc2a 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -647,6 +647,7 @@ function __init__() @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl") @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") end end diff --git a/src/tracker.jl b/src/tracker.jl new file mode 100644 index 0000000..4f2173b --- /dev/null +++ b/src/tracker.jl @@ -0,0 +1,33 @@ +using .Tracker: Tracker + +""" + TrackerBackend + +AD backend that uses reverse mode with Tracker.jl. +""" +struct TrackerBackend <: AbstractReverseMode end + +function second_lowest(::TrackerBackend) + return throw(ArgumentError("Tracker backend does not support nested differentiation.")) +end + +@primitive function pullback_function(ba::TrackerBackend, f, xs...) + value, back = Tracker.forward(f, xs...) + function pullback(ws) + if ws isa Tuple && !(value isa Tuple) + @assert length(ws) == 1 + map(Tracker.data, back(ws[1])) + else + map(Tracker.data, back(ws)) + end + end + return pullback +end + +function derivative(ba::TrackerBackend, f, xs::Number...) + return Tracker.data.(Tracker.gradient(f, xs...)) +end + +function gradient(ba::TrackerBackend, f, xs::AbstractVector...) + return Tracker.data.(Tracker.gradient(f, xs...)) +end diff --git a/test/runtests.jl b/test/runtests.jl index d72abb3..0348523 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ using Test include("forwarddiff.jl") include("reversediff.jl") include("finitedifferences.jl") + include("tracker.jl") end diff --git a/test/tracker.jl b/test/tracker.jl new file mode 100644 index 0000000..6f09c0f --- /dev/null +++ b/test/tracker.jl @@ -0,0 +1,37 @@ +using AbstractDifferentiation +using Test +using Tracker + +@testset "TrackerBackend" begin + backends = [@inferred(AD.TrackerBackend())] + @testset for backend in backends + @testset "errors when nested" begin + @test_throws ArgumentError AD.second_lowest(backend) + @test_throws ArgumentError AD.hessian(backend, sum, randn(3)) + end + @testset "Derivative" begin + test_derivatives(backend) + end + @testset "Gradient" begin + test_gradients(backend) + end + @testset "Jacobian" begin + test_jacobians(backend) + end + @testset "jvp" begin + test_jvp(backend) + end + @testset "j′vp" begin + test_j′vp(backend) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(backend) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(backend) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(backend) + end + end +end