Skip to content

Commit 5397d28

Browse files
authored
Run everything on 1.10 (#385)
* Run everything on 1.10 * Expand buildkite runner into matrix * Also run integration and ext testing on LTS * Try LTS * Just write 1.10 * Julia version number in matrix * Document support policy * Document support policy * Restrict Julia compat properly * Fix DynamicPPL tests on 1.10 * Simplify gitignore * Formatting * Relax performance bounds on nnlib scatter * Improve error message * Fix a problem * Fix error introduced in the last PR * Fix coverage * Fix coverage * Fix LuxLib * Prevent small union inlining * use static if everywhere * Revert previous change * Patch the problem * Run error printing * Bump patch
1 parent eaf1fb8 commit 5397d28

20 files changed

+170
-48
lines changed

.buildkite/pipeline.yml

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ env:
22
SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg=="
33

44
steps:
5-
- label: "Julia v1"
5+
- label: "Julia v{{matrix}}"
66
plugins:
77
- JuliaCI/julia#v1:
8-
version: "1"
8+
version: "{{matrix}}"
99
- JuliaCI/julia-coverage#v1:
1010
dirs:
1111
- src
@@ -19,3 +19,6 @@ steps:
1919
env:
2020
LABEL: cuda
2121
TEST_TYPE: ext
22+
matrix:
23+
- "1"
24+
- "1.10"

.github/workflows/CI.yml

+2-7
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,11 @@ jobs:
4040
'rrules/twice_precision',
4141
]
4242
version:
43+
- 'lts'
4344
- '1'
4445
arch:
4546
- x64
4647
include:
47-
- test_group: 'basic'
48-
version: '1.10'
49-
arch: x64
5048
- test_group: 'basic'
5149
version: '1.10'
5250
arch: x86
@@ -95,12 +93,9 @@ jobs:
9593
]
9694
version:
9795
- '1'
96+
- 'lts'
9897
arch:
9998
- x64
100-
include:
101-
- test_group: {test_type: 'integration_testing', label: 'turing'}
102-
version: '1.10'
103-
arch: x64
10499
steps:
105100
- uses: actions/checkout@v4
106101
- uses: julia-actions/setup-julia@v2

.gitignore

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
/Manifest.toml
2-
/Manifest-v1.11.toml
1+
Manifest*
32
dev
4-
bench/Manifest.toml
53
analysis_results
64
.vscode
75
profile.pb.gz
86
scratch.jl
97
docs/build/
108
docs/site/
11-
docs/Manifest.toml
12-
Manifest.toml

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.48"
4+
version = "0.4.49"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -66,7 +66,7 @@ Setfield = "1"
6666
SpecialFunctions = "2"
6767
StableRNGs = "1"
6868
Test = "1"
69-
julia = "1"
69+
julia = "1.10"
7070

7171
[extras]
7272
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ If you encounter a new version of `Mooncake.jl` in the wild, please consult this
1818

1919
# Getting Started
2020

21+
Check that you're running a version of Julia that Mooncake.jl supports.
22+
See the `SUPPORT_POLICY.md` file for more info.
23+
2124
There are several ways to interact with `Mooncake.jl`.
2225
The one that we recommend people begin with is [`DifferentiationInterface.jl`](https://github.com/gdalle/DifferentiationInterface.jl/).
2326
For example, use it as follows to compute the gradient of a function mapping a `Vector{Float64}` to `Float64`.

SUPPORT_POLICY.md

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Summary
2+
3+
At any given point in time, `Mooncake.jl` supports the current Long Term Support (LTS) release of Julia, and the latest release version of Julia 1.
4+
Consequently, the versions of Julia which are officially supported by `Mooncake.jl` will change (almost) _immediately_ whenever a new Julia LTS version is declared, or a minor release of Julia is made.
5+
6+
For example, the LTS is 1.10 and the latest release is 1.11 at the time of writing. When 1.12 is released, we will
7+
1. bump the Julia compat bounds in `Mooncake.jl` to require either 1.10 or 1.12,
8+
1. cease to run CI on 1.11,
9+
1. cease to provide bug fixes for 1.11,
10+
1. cease to accept 1.11-specific bug fixes, as we will not be running CI for 1.11 and therefore will not be able to test that they have worked.
11+
12+
In short: as far as `Mooncake.jl`'s future releases, 1.11 ceases to exist the moment 1.12 is released.
13+
14+
Note that these changes are not applied retrospectively to existing releases of `Mooncake.jl`.
15+
Suppose that `Mooncake.jl` is at `v0.4.50` when 1.12 is released.
16+
Then the above changes would be relevant to `Mooncake.jl` versions `v0.4.51` and higher.
17+
18+
# Patch Versions
19+
20+
The above only discussed minor versions of Julia (1.10, 1.11, 1.12, etc).
21+
However, it also applies to patch versions of Julia.
22+
For example, at the time of writing, Julia version 1.10.6 is _actually_ the LTS, and 1.11.1 the current release of Julia.
23+
The moment that 1.10.7 is released, we will cease to run any CI on 1.10.6, and will not accept fixes for it.
24+
The same is true of 1.11.2.
25+
26+
Since patch releases of Julia are less invasive than minor releases, this should generally not cause users problems.
27+
28+
# Context
29+
30+
In order to support a particular version of Julia, we must
31+
1. always run CI for that version,
32+
1. accept and proactively produce fixes for that version,
33+
1. maintain version-specific code in the `Mooncake.jl` codebase.
34+
35+
This requires a surprisingly amount of overhead to the development of `Mooncake.jl`, and has the potential to substantially increase the complexity of the codebase.
36+
All of this makes it harder to improve `Mooncake.jl`.
37+
Consequently, this policy represents a decision to tradeoff support for a range of minor Julia versions in exchange for easing the development burden associated to `Mooncake.jl`.
38+
39+
## Why not gently drop support?
40+
41+
In the JuliaGaussianProcesses ecosystem, we had a loosely-defined policy of keeping support for an older version until we ran into a large problem which could not be fixed easily, at which point we would drop support.
42+
While this sounds appealing, in practice it makes it hard to know exactly when to drop support for a particular version of Julia, increases the burden for maintainers, and makes it hard for users to know exactly what to expect.

ext/MooncakeLuxLibSLEEFPiratesExtension.jl

+27
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using LuxLib, Mooncake, SLEEFPirates
44
using Base: IEEEFloat
55
using Mooncake: @from_rrule, DefaultCtx
66

7+
@static if VERSION >= v"1.11"
8+
79
# Workaround for package load order problems. See
810
# https://github.com/JuliaLang/julia/issues/56204#issuecomment-2419553167 for more context.
911
function __init__()
@@ -31,4 +33,29 @@ function __init__()
3133
end
3234
end
3335

36+
else
37+
38+
for f in Any[
39+
LuxLib.NNlib.sigmoid_fast,
40+
LuxLib.NNlib.softplus,
41+
LuxLib.NNlib.logsigmoid,
42+
LuxLib.NNlib.swish,
43+
LuxLib.NNlib.lisht,
44+
Base.tanh,
45+
LuxLib.NNlib.tanh_fast,
46+
]
47+
f_fast = LuxLib.Impl.sleefpirates_fast_act(f)
48+
@eval @from_rrule DefaultCtx Tuple{typeof($f_fast), IEEEFloat}
49+
@eval @from_rrule(
50+
DefaultCtx,
51+
Tuple{
52+
typeof(Broadcast.broadcasted),
53+
typeof($f_fast),
54+
Union{IEEEFloat, Array{<:IEEEFloat}},
55+
},
56+
)
57+
end
58+
59+
end
60+
3461
end

src/interpreter/abstract_interpretation.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function CC.method_table(interp::MooncakeInterpreter)
107107
return CC.OverlayMethodTable(interp.world, mooncake_method_table)
108108
end
109109

110-
if VERSION < v"1.11.0"
110+
@static if VERSION < v"1.11.0"
111111
CC.get_world_counter(interp::MooncakeInterpreter) = interp.world
112112
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp)
113113
else
@@ -160,7 +160,7 @@ function Core.Compiler.abstract_call_gf_by_type(
160160
end
161161
end
162162

163-
if VERSION < v"1.11-"
163+
@static if VERSION < v"1.11-"
164164

165165
function CC.inlining_policy(
166166
interp::MooncakeInterpreter{C},

src/interpreter/contexts.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Observe that this information means that whether or not something is a primitive
3232
particular context depends only on static information, not any run-time information that
3333
might live in a particular instance of `Ctx`.
3434
"""
35-
is_primitive(::Type{MinimalCtx}, ::Any) = false
35+
is_primitive(::Type{MinimalCtx}, sig::Type{<:Tuple}) = false
3636
is_primitive(::Type{DefaultCtx}, sig) = is_primitive(MinimalCtx, sig)
3737

3838
"""

src/interpreter/ir_normalisation.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ __get_arg(x::QuoteNode) = x.value
214214
__get_arg(x) = x
215215

216216
# memoryrefget and memoryrefset! were introduced in 1.11.
217-
if VERSION >= v"1.11-"
217+
@static if VERSION >= v"1.11-"
218218

219219
"""
220220
lift_memoryrefget_and_memoryrefset_builtins(inst)

src/interpreter/ir_utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
190190
ir = __strip_coverage!(ir)
191191
ir = CC.sroa_pass!(ir, inline_state)
192192

193-
if VERSION < v"1.11-"
193+
@static if VERSION < v"1.11-"
194194
ir = CC.adce_pass!(ir, inline_state)
195195
else
196196
ir, _ = CC.adce_pass!(ir, inline_state)
@@ -227,7 +227,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u
227227
asts = []
228228
for match in get_matches(matches.matches)
229229
match = match::Core.MethodMatch
230-
if VERSION < v"1.11-"
230+
@static if VERSION < v"1.11-"
231231
meth = Base.func_for_method_checked(match.method, tt, match.sparams)
232232
(code, ty) = CC.typeinf_ircode(
233233
interp, meth, match.spec_types, match.sparams, optimize_until

src/interpreter/s2s_reverse_mode_ad.jl

+47-11
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ get_primal_type(::ADInfo, x) = _typeof(x)
213213
function get_primal_type(::ADInfo, x::GlobalRef)
214214
return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty
215215
end
216+
function get_primal_type(::ADInfo, x::Expr)
217+
x.head === :boundscheck && return Bool
218+
error("Unrecognised expression $x found in argument slot.")
219+
end
216220

217221
"""
218222
get_rev_data_id(info::ADInfo, x)
@@ -394,7 +398,11 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
394398
rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2)))
395399
return ad_stmt_info(line, nothing, inc_args(stmt), rvs)
396400
else
397-
fwds = ReturnNode(const_codual(stmt.val, info))
401+
const_id = ID()
402+
fwds = [
403+
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
404+
(ID(), new_inst(ReturnNode(const_id))),
405+
]
398406
return ad_stmt_info(line, nothing, fwds, nothing)
399407
end
400408
end
@@ -457,7 +465,11 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
457465
else
458466
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
459467
# do on the reverse-pass.
460-
fwds = PiNode(const_codual(stmt.val, info), fcodual_type(_type(stmt.typ)))
468+
const_id = ID()
469+
fwds = [
470+
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
471+
(line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))),
472+
]
461473
rvs = nothing
462474
end
463475

@@ -475,11 +487,11 @@ end
475487
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo)
476488
isconst(stmt) && return const_ad_stmt(stmt, line, info)
477489

478-
x = const_codual(getglobal(stmt.mod, stmt.name), info)
479-
globalref_id = ID()
490+
const_id, globalref_id = ID(), ID()
480491
fwds = [
481492
(globalref_id, new_inst(stmt)),
482-
(line, new_inst(Expr(:call, __verify_const, globalref_id, x))),
493+
(const_id, new_inst(const_codual_stmt(getglobal(stmt.mod, stmt.name), info))),
494+
(line, new_inst(Expr(:call, __verify_const, globalref_id, const_id))),
483495
]
484496
return ad_stmt_info(line, nothing, fwds, nothing)
485497
end
@@ -502,8 +514,22 @@ make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)
502514
Implementation of `make_ad_stmts!` used for constants.
503515
"""
504516
function const_ad_stmt(stmt, line::ID, info::ADInfo)
505-
x = const_codual(stmt, info)
506-
return ad_stmt_info(line, nothing, x isa ID ? Expr(:call, identity, x) : x, nothing)
517+
return ad_stmt_info(line, nothing, const_codual_stmt(stmt, info), nothing)
518+
end
519+
520+
"""
521+
const_codual_stmt(stmt, info::ADInfo)
522+
523+
Returns a `:call` expression which will return a `CoDual` whose primal is `stmt`, and whose
524+
tangent is whatever `uninit_tangent` returns.
525+
"""
526+
function const_codual_stmt(stmt, info::ADInfo)
527+
v = get_const_primal_value(stmt)
528+
if safe_for_literal(v)
529+
return Expr(:call, uninit_fcodual, v)
530+
else
531+
return Expr(:call, identity, add_data!(info, uninit_fcodual(v)))
532+
end
507533
end
508534

509535
"""
@@ -519,10 +545,21 @@ function const_codual(stmt, info::ADInfo)
519545
return safe_for_literal(v) ? x : add_data!(info, x)
520546
end
521547

522-
safe_for_literal(v) = v isa String || v isa Type || isbitstype(_typeof(v))
548+
function safe_for_literal(v)
549+
v isa Expr && v.head === :boundscheck && return true
550+
v isa String && return true
551+
v isa Type && return true
552+
v isa Tuple && all(safe_for_literal, v) && return true
553+
isbitstype(_typeof(v)) && return true
554+
return false
555+
end
523556

524557
inc_or_const(stmt, info::ADInfo) = is_active(stmt) ? __inc(stmt) : const_codual(stmt, info)
525558

559+
function inc_or_const_stmt(stmt, info::ADInfo)
560+
return is_active(stmt) ? Expr(:call, identity, __inc(stmt)) : const_codual_stmt(stmt, info)
561+
end
562+
526563
"""
527564
get_const_primal_value(x::GlobalRef)
528565
@@ -616,7 +653,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
616653
# Make arguments to rrule call. Things which are not already CoDual must be made so.
617654
codual_arg_ids = map(_ -> ID(), collect(args))
618655
codual_args = map(args, codual_arg_ids) do arg, id
619-
return (id, new_inst(Expr(:call, identity, inc_or_const(arg, info))))
656+
return (id, new_inst(inc_or_const_stmt(arg, info)))
620657
end
621658

622659
# Make call to rule.
@@ -691,8 +728,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
691728

692729
elseif Meta.isexpr(stmt, :copyast)
693730
# Get constant out and shove it in shared storage.
694-
x = const_codual(stmt.args[1], info)
695-
return ad_stmt_info(line, nothing, Expr(:call, identity, x), nothing)
731+
return ad_stmt_info(line, nothing, const_codual_stmt(stmt.args[1], info), nothing)
696732

697733
elseif Meta.isexpr(stmt, :loopinfo)
698734
# Cannot pass loopinfo back through the optimiser for some reason.

src/rrules/builtins.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,17 @@ function rrule!!(f::CoDual{<:Core.Builtin}, args...)
2929
"which is specialised to this case. " *
3030
"Either way, please consider commenting on " *
3131
"https://github.com/compintell/Mooncake.jl/issues/208/ so that the issue can be " *
32-
"fixed more widely."
32+
"fixed more widely.\n" *
33+
"For reproducibility, note that the full signature is:\n" *
34+
"$(typeof((f, args...)))"
3335
))
3436
end
3537

38+
function Base.showerror(io::IO, err::MissingRuleForBuiltinException)
39+
print(io, "MissingRuleForBuiltinException: ")
40+
println(io, err.msg)
41+
end
42+
3643
module IntrinsicsWrappers
3744

3845
using Base: IEEEFloat

src/rrules/fastmath.jl

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P<
2626
return CoDual(y, NoFData()), sincos_fast_adj!!
2727
end
2828

29+
@is_primitive MinimalCtx Tuple{typeof(Base.log), Union{IEEEFloat, Int}}
30+
@zero_adjoint MinimalCtx Tuple{typeof(log), Int}
31+
2932
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
3033
test_cases = Any[
3134
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5),

src/rrules/misc.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
# Required to avoid an ambiguity.
3535
@zero_adjoint MinimalCtx Tuple{Type{Symbol}, TypeVar, Type}
3636

37-
if VERSION >= v"1.11-"
37+
@static if VERSION >= v"1.11-"
3838
@zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed), Vararg}
3939
@zero_adjoint MinimalCtx Tuple{typeof(Base.dataids), Memory}
4040
end

src/test_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ end
610610
_new_excluded(::Type) = false
611611
_new_excluded(::Type{<:Union{String}}) = true
612612

613-
if VERSION < v"1.11-"
613+
@static if VERSION < v"1.11-"
614614
# Prior to 1.11, Arrays are special objects, with special constructors that don't
615615
# involve calling the `:new` instruction. From 1.11 onwards, they behave more like
616616
# regular mutable composite types, so calling `_new_` becomes meaningful.

0 commit comments

Comments
 (0)