diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 47ef8549..7d0aa4ae 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -9,8 +9,8 @@ on: jobs: test: runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: + fail-fast: false matrix: version: - '1.6' @@ -23,13 +23,13 @@ jobs: AD: - Enzyme - ForwardDiff - - Tapir + - Mooncake - Tracker - ReverseDiff - Zygote exclude: - version: 1.6 - AD: Tapir + AD: Mooncake # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 diff --git a/.github/workflows/Interface.yml b/.github/workflows/Interface.yml index ef1f4dc7..b305124e 100644 --- a/.github/workflows/Interface.yml +++ b/.github/workflows/Interface.yml @@ -10,8 +10,8 @@ on: jobs: test: runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: + fail-fast: false matrix: version: - '1.6' diff --git a/Project.toml b/Project.toml index d13289bf..283b706c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.18" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -26,21 +26,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" -BijectorsEnzymeExt = "Enzyme" +BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" +BijectorsMooncakeExt = "Mooncake" BijectorsTrackerExt = "Tracker" -BijectorsTapirExt = "Tapir" BijectorsZygoteExt = "Zygote" [compat] @@ -53,6 +54,7 @@ Distributions = "0.25.33" DistributionsAD = "0.6" DocStringExtensions = "0.9" Enzyme = "0.12.22" +EnzymeCore = "0.7.8" ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" @@ -65,7 +67,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" -Tapir = "0.2.23" +Mooncake = "0.4.19" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -73,9 +75,10 @@ julia = "1.6" [extras] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 1e8d8aa3..303fd92f 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -1,14 +1,18 @@ module BijectorsEnzymeExt if isdefined(Base, :get_extension) - using Enzyme: @import_frule, @import_rrule + using Enzyme: @import_rrule, @import_frule using Bijectors: find_alpha else - using ..Enzyme: @import_frule, @import_rrule + using ..Enzyme: @import_rrule, @import_frule using ..Bijectors: find_alpha end -@import_rrule typeof(find_alpha) Real Real Real -@import_frule typeof(find_alpha) Real Real Real - +@static if v"1.11.1" <= VERSION < v"1.12" + @warn "Bijectors and Enzyme do not work together on Julia $VERSION" +else + @import_rrule typeof(find_alpha) Real Real Real + @import_frule typeof(find_alpha) Real Real Real end + +end # module diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsMooncakeExt.jl similarity index 77% rename from ext/BijectorsTapirExt.jl rename to ext/BijectorsMooncakeExt.jl index 70805a82..d7285bf6 100644 --- a/ext/BijectorsTapirExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,10 +1,11 @@ -module BijectorsTapirExt +module BijectorsMooncakeExt if isdefined(Base, :get_extension) - using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore else - using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule using ..Bijectors: find_alpha, ChainRulesCore end @@ -19,20 +20,20 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -function Tapir.rrule!!( +function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} # Require that the integer is non-differentiable. - if tangent_type(I) != Tapir.NoTangent + if tangent_type(I) != Mooncake.NoTangent msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." throw(ArgumentError(msg)) end out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) function find_alpha_pb(dout::P) _, dx, dy, _ = pb(dout) - return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData() end - return Tapir.zero_fcodual(out), find_alpha_pb + return Mooncake.zero_fcodual(out), find_alpha_pb end end diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index bcdb9523..a2c13df1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,9 +27,9 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - if @isdefined Tapir + if @isdefined Mooncake rng = Xoshiro(123456) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -37,9 +37,9 @@ end z; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -47,9 +47,9 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -57,7 +57,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 3e21e693..2e709491 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) b in ( :ForwardDiff, :Zygote, - :Tapir, + :Mooncake, :ReverseDiff, :Enzyme, :EnzymeForward, @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" - rule = Tapir.build_rrule(f, x; safety_on=false) - if :tapir in broken - @test_broken( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) - else - @test( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) + if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + try + Mooncake.build_rrule(f, x) + catch exc + # TODO(penelopeysm): + # @test_throws AssertionError (expr...) doesn't work, unclear why + @test exc isa AssertionError end + # TODO: The above @test_throws happens because of + # https://github.com/compintell/Mooncake.jl/issues/319. If that test + # fails, it probably means that the issue was fixed, in which case + # we can remove that block and uncomment the following instead. + + # rule = Mooncake.build_rrule(f, x) + # if :Mooncake in broken + # @test_broken ( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # else + # @test( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # end end return nothing diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index b2115fe2..60354005 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -127,12 +127,12 @@ end end end # Check that the quantiles are reasonable, i.e. within - # 5 standard errors of the true quantiles (and that the MCSE is + # 6 standard errors of the true quantiles (and that the MCSE is # not too large). for i in 1:k for j in 1:length(qts) @test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2 - @test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j] + @test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j] end end end diff --git a/test/runtests.jl b/test/runtests.jl index 914c0e32..638bd15c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,12 +34,12 @@ if VERSION < v"1.9" using Compat: stack end -# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing +# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing # on at least version 1.10. if VERSION >= v"1.10" using Pkg - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake end const GROUP = get(ENV, "GROUP", "All")