Commit a5cb83b

gdalle and willtebbutt authored
Clarify public symbols (#477)
* Clarify public symbols
* Imports
* Imports
* macros
* Format
* Add preparation
* Handle macros
* Tweak imports
* Refine export list
* Fix CUDA loading
* Fix benchmark loading
* Fix docs
* Actually fix docs build
* Formatting
* Do not export Config
* Tweak canonical settings
* Do interface properly
* Sort out canonicalisation
* Bump patch version
* Bump patch version again
* Bump patch _again_

Co-authored-by: willtebbutt <[email protected]>
Co-authored-by: Will Tebbutt <[email protected]>
1 parent 75785aa commit a5cb83b

21 files changed: +211 −82 lines

Project.toml (+1 −1)

@@ -1,7 +1,7 @@
 name = "Mooncake"
 uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 authors = ["Will Tebbutt, Hong Ge, and contributors"]
-version = "0.4.97"
+version = "0.4.98"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

bench/run_benchmarks.jl (+4 −1)

@@ -23,7 +23,10 @@ using Mooncake:
     generate_hand_written_rrule!!_test_cases,
     generate_derived_rrule!!_test_cases,
     TestUtils,
-    _typeof
+    _typeof,
+    primal,
+    tangent,
+    zero_codual

 using Mooncake.TestUtils: _deepcopy

docs/make.jl (+5 −0)

@@ -5,6 +5,10 @@ DocMeta.setdocmeta!(
     :DocTestSetup,
     quote
         using Random, Mooncake
+        using Mooncake: tangent_type, fdata_type, rdata_type
+        using Mooncake: zero_tangent
+        using Mooncake: NoTangent, NoFData, NoRData, MutableTangent, Tangent
+        using Mooncake: build_rrule, Config
     end;
     recursive=true,
 )

@@ -30,6 +34,7 @@ makedocs(;
     pages=[
         "Mooncake.jl" => "index.md",
         "Tutorial" => "tutorial.md",
+        "Interface" => "interface.md",
         "Understanding Mooncake.jl" => [
             joinpath("understanding_mooncake", "introduction.md"),
             joinpath("understanding_mooncake", "algorithmic_differentiation.md"),

docs/src/developer_documentation/developer_tools.md (+1 −1)

@@ -6,7 +6,7 @@ which save you from having to dig in to the objects created by `build_rrule`.

 Since these provide access to internals, they do not follow the usual rules of semver, and
 may change without notice!
-```@docs
+```@docs; canonical=false
 Mooncake.primal_ir
 Mooncake.fwd_ir
 Mooncake.rvs_ir
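
(Context for the `canonical` flags used throughout this commit: Documenter.jl treats one rendered copy of each docstring as canonical and warns when two pages both claim it, so blocks that merely re-display a docstring are marked `canonical=false`, leaving the canonical copy on the page this commit designates, such as the new Interface page.)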

docs/src/developer_documentation/internal_docstrings.md (+4 −3)

@@ -5,11 +5,12 @@ Consequently, they can change between non-breaking changes to Mooncake.jl withou

 The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

-```@autodocs; canonical=false
+```@autodocs; canonical=true
 Modules = [Mooncake]
 Public = false
+Filter = t -> !(t in [Mooncake.value_and_pullback!!, Mooncake.prepare_pullback_cache, Mooncake.Config])
 ```

-```@docs
+```@docs; canonical=true
 Mooncake.IntrinsicsWrappers
-```
+```

docs/src/developer_documentation/misc_internals_notes.md (+1 −1)

@@ -132,7 +132,7 @@ Last checked: 09/02/2025, Julia v1.10.8 / v1.11.3, Mooncake 0.4.82.
 Mooncake handles recursive function calls by delaying code generation for generic function calls until the first time that they are actually run.
 The docstring below contains a thorough explanation:

-```@docs
+```@docs; canonical=false
 Mooncake.LazyDerivedRule
 ```

docs/src/interface.md (+14 −0)

@@ -0,0 +1,14 @@
+# Interface
+
+This is the public interface that day-to-day users of AD are expected to interact with if,
+for some reason, DifferentiationInterface.jl does not suffice.
+If you have not tried using Mooncake.jl via DifferentiationInterface.jl, please do so.
+See [Tutorial](@ref) for more info.
+
+```@docs; canonical=true
+Mooncake.Config
+Mooncake.value_and_gradient!!(::Mooncake.Cache, f::F, x::Vararg{Any, N}) where {F, N}
+Mooncake.value_and_pullback!!(::Mooncake.Cache, ȳ, f::F, x::Vararg{Any, N}) where {F, N}
+Mooncake.prepare_gradient_cache
+Mooncake.prepare_pullback_cache
+```
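
For readers who have not yet tried the recommended route, a hedged sketch of driving Mooncake through DifferentiationInterface.jl (the API names follow DifferentiationInterface's documented usage and are not part of this diff):

```julia
using DifferentiationInterface  # exports value_and_gradient and the ADTypes backends
import Mooncake

backend = AutoMooncake(; config=nothing)  # nothing => Mooncake's default Config

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
val, grad = value_and_gradient(f, backend, x)  # expect (14.0, [2.0, 4.0, 6.0])
```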

docs/src/known_limitations.md (+1 −0)

@@ -7,6 +7,7 @@ Mooncake.jl has a number of known qualitative limitations, which we document her
 ```@meta
 DocTestSetup = quote
     using Mooncake
+    using Mooncake: NoTangent, build_rrule
 end
 ```

docs/src/understanding_mooncake/rule_system.md (+4 −4)

@@ -279,8 +279,8 @@ _**Representing Gradients**_
 This package assigns to each type in Julia a unique `tangent_type`, the purpose of which is to contain the gradients computed during reverse mode AD.
 The extended docstring for [`tangent_type`](@ref) provides the best introduction to the types which are used to represent tangents / gradients.

-```@docs
-tangent_type(P)
+```@docs; canonical=false
+Mooncake.tangent_type(P)
 ```

@@ -295,7 +295,7 @@ Conversely, the gradient w.r.t. a value type resides in another value type.

 The following docstring provides the best in-depth explanation.

-```@docs
+```@docs; canonical=false
 Mooncake.fdata_type(T)
 ```

@@ -327,7 +327,7 @@ Now that you've seen what data structures are used to represent gradients, we ca
 ```@meta
 DocTestSetup = quote
     using Mooncake
-    using Mooncake: CoDual
+    using Mooncake: CoDual, NoFData, NoRData
     import Mooncake: rrule!!
 end
 ```
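
To make the `tangent_type` reference above concrete, a few representative mappings (these reflect Mooncake's conventions as I understand them; the extended docstring is authoritative):

```julia
import Mooncake
using Mooncake: NoTangent

Mooncake.tangent_type(Float64)          # Float64: floats are their own tangent type
Mooncake.tangent_type(Int)              # NoTangent: integers are non-differentiable
Mooncake.tangent_type(Vector{Float64})  # Vector{Float64}
```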

docs/src/utilities/debug_mode.md (+2 −2)

@@ -33,12 +33,12 @@ _**The Solution**_
 Check that the types of the fdata / rdata associated to arguments are exactly what `tangent_type` / `fdata_type` / `rdata_type` require upon entry to / exit from rules and pullbacks.

 This is implemented via `DebugRRule`:
-```@docs
+```@docs; canonical=false
 Mooncake.DebugRRule
 ```

 You can straightforwardly enable it when building a rule via the `debug_mode` kwarg in the following:
-```@docs
+```@docs; canonical=false
 Mooncake.build_rrule
 ```
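
A small sketch of the `debug_mode` kwarg described above (the function being differentiated is illustrative):

```julia
import Mooncake

f(x) = sum(x)
x = [1.0, 2.0]

# Wraps the derived rule in Mooncake.DebugRRule, enabling the fdata / rdata type checks.
rule = Mooncake.build_rrule(f, x; debug_mode=true)
```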

docs/src/utilities/debugging_and_mwes.md (+1 −1)

@@ -5,7 +5,7 @@ In order to debug what is going on when this happens, or to produce an MWE, it i

 We recommend making use of Mooncake.jl's testing functionality to generate your test cases:

-```@docs
+```@docs; canonical=false
 Mooncake.TestUtils.test_rule
 ```
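
For example, a hedged sketch of calling `test_rule` (the rng-first signature and the `is_primitive` kwarg are taken from its docstring as I recall it; check the rendered docs):

```julia
using Mooncake, Random

f(x) = sin(x) + x^2

# Check that AD through `f` agrees with finite differences; is_primitive=false
# because we are testing a derived rule rather than a hand-written rrule!!.
Mooncake.TestUtils.test_rule(Xoshiro(123), f, 3.0; is_primitive=false)
```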

docs/src/utilities/defining_rules.md (+5 −5)

@@ -6,7 +6,7 @@ In this section, we detail some useful strategies which can help you avoid havin

 ## Simplifying Code via Overlays

-```@docs
+```@docs; canonical=false
 Mooncake.@mooncake_overlay
 ```

@@ -15,7 +15,7 @@ Mooncake.@mooncake_overlay
 If the above strategy does not work, but you find yourself in the surprisingly common
 situation that the adjoint of the derivative of your function is always zero, you can very
 straightforwardly write a rule by making use of the following:
-```@docs
+```@docs; canonical=false
 Mooncake.@zero_adjoint
 Mooncake.zero_adjoint
 ```

@@ -28,18 +28,18 @@ There are some instances where it is most convenient to implement a `Mooncake.rr

 There is enough similarity between these two systems that most of the boilerplate code can be avoided.

-```@docs
+```@docs; canonical=false
 Mooncake.@from_rrule
 ```

 ## Adding Methods To `rrule!!` And `build_primitive_rrule`

 If the above strategies do not work for you, you should first implement a method of [`Mooncake.is_primitive`](@ref) for the signature of interest:
-```@docs
+```@docs; canonical=false
 Mooncake.is_primitive
 ```
 Then implement a method of one of the following:
-```@docs
+```@docs; canonical=false
 Mooncake.rrule!!
 Mooncake.build_primitive_rrule
 ```
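
To illustrate the zero-adjoint route this page describes, a sketch under the assumption (from the macro's docstring, as I recall it) that `@zero_adjoint` takes a context and a signature; `MinimalCtx` and the helper function are illustrative:

```julia
import Mooncake
using Mooncake: MinimalCtx, @zero_adjoint

# Hypothetical logging helper: its return value carries no gradient
# information, so the adjoint w.r.t. `x` is always zero.
log_value(x::Float64) = (println(x); nothing)

@zero_adjoint MinimalCtx Tuple{typeof(log_value),Float64}
```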

ext/MooncakeCUDAExt.jl (+2 −0)

@@ -10,6 +10,7 @@ import Mooncake:
     rrule!!,
     @is_primitive,
     tangent_type,
+    primal,
     tangent,
     zero_tangent_internal,
     randn_tangent_internal,

@@ -26,6 +27,7 @@ import Mooncake:
     increment_and_get_rdata!,
     MaybeCache,
     IncCache,
+    NoRData,
     StackDict

 import Mooncake.TestUtils:

src/Mooncake.jl (+6 −28)

@@ -143,33 +143,11 @@ include("interface.jl")
 include("config.jl")
 include("developer_tools.jl")

-export primal,
-    tangent,
-    randn_tangent,
-    increment!!,
-    NoTangent,
-    Tangent,
-    MutableTangent,
-    PossiblyUninitTangent,
-    set_to_zero!!,
-    tangent_type,
-    zero_tangent,
-    _scale,
-    _add_to_primal,
-    _diff,
-    _dot,
-    zero_codual,
-    codual_type,
-    rrule!!,
-    build_rrule,
-    value_and_gradient!!,
-    value_and_pullback!!,
-    NoFData,
-    NoRData,
-    fdata_type,
-    rdata_type,
-    fdata,
-    rdata,
-    get_interpreter
+# Public, not exported
+include("public.jl")
+@public Config, value_and_pullback!!, prepare_pullback_cache
+
+# Public, exported
+export value_and_gradient!!, prepare_gradient_cache

 end
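
`src/public.jl` itself is not shown in this commit; a minimal sketch of the usual shape of such a compat macro (an assumption, not the file's actual contents):

```julia
# Assumed shape of a `@public` compat macro: on Julia 1.11+, where the
# `public` keyword exists, emit a public declaration; otherwise do nothing.
macro public(ex)
    names = ex isa Symbol ? (ex,) : Tuple(ex.args)  # `@public a` or `@public a, b, c`
    return VERSION >= v"1.11" ? esc(Expr(:public, names...)) : nothing
end
```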

src/config.jl (+1 −1)

@@ -1,7 +1,7 @@
 """
     Config(; debug_mode=false, silence_debug_messages=false)

-Configuration struct for use with ADTypes.AutoMooncake.
+Configuration struct for use with `ADTypes.AutoMooncake`.
 """
 @kwdef struct Config
     debug_mode::Bool = false
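
A usage sketch for `Config` with the ADTypes backend mechanism (the keyword constructor for `AutoMooncake` is standard ADTypes usage; the field names come from this file):

```julia
using ADTypes
import Mooncake

# Run Mooncake with its extra debug checks when selected as an AD backend.
backend = ADTypes.AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
```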

src/interface.jl (+73 −29)

@@ -175,10 +175,7 @@ _copy!!(::Number, src::Number) = src
 """
     prepare_pullback_cache(f, x...)

-WARNING: experimental functionality. Interface subject to change without warning!
-
-Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
-`Mooncake.value_and_gradient!!` for more info.
+Returns a cache used with [`value_and_pullback!!`](@ref). See that function for more info.
 """
 function prepare_pullback_cache(fx...; kwargs...)

@@ -200,18 +197,46 @@ end
 """
     value_and_pullback!!(cache::Cache, ȳ, f, x...)

-WARNING: experimental functionality. Interface subject to change without warning!
+!!! info
+    If `f(x...)` returns a scalar, you should use [`value_and_gradient!!`](@ref), not this
+    function.
+
+Computes a 2-tuple. The first element is `f(x...)`, and the second is a tuple containing the
+pullback of `f` applied to `ȳ`. The first element is the component of the pullback
+associated to any fields of `f`, the second w.r.t. the first element of `x`, etc.
+
+There are no restrictions on what `y = f(x...)` is permitted to return. However, `ȳ` must be
+an acceptable tangent for `y`. This means that, for example, it must be true that
+`tangent_type(typeof(y)) == typeof(ȳ)`.
+
+As with all functionality in Mooncake, if `f` modifies itself or `x`, `value_and_pullback!!`
+will return both to their original state as part of the process of computing the gradient.
+
+!!! info
+    `cache` must be the output of [`prepare_pullback_cache`](@ref), and (fields of) `f` and
+    `x` must be of the same size and shape as those used to construct the `cache`. This is
+    to ensure that the gradient can be written to the memory allocated when the `cache` was
+    built.
+
+!!! warning
+    `cache` owns any mutable state returned by this function, meaning that mutable
+    components of values returned by it will be mutated if you run this function again with
+    different arguments. Therefore, if you need to keep the values returned by this function
+    around over multiple calls to this function with the same `cache`, you should take a
+    copy (using `copy` or `deepcopy`) of them before calling again.
+
+# Example Usage
+```jldoctest
+f(x, y) = sum(x .* y)
+x = [2.0, 2.0]
+y = [1.0, 1.0]
+cache = Mooncake.prepare_pullback_cache(f, x, y)
+Mooncake.value_and_pullback!!(cache, 1.0, f, x, y)

-Like other methods of `value_and_pullback!!`, but makes use of the `cache` object returned
-by [`prepare_pullback_cache`](@ref) in order to avoid having to re-allocate various tangent
-objects repeatedly. You must ensure that `f` and `x` are the same types and sizes as those
-used to construct `cache`.
+# output

-Warning: `cache` owns any mutable state returned by this function, meaning that mutable
-components of values returned by it will be mutated if you run this function again with
-different arguments. Therefore, if you need to keep the values returned by this function
-around over multiple calls to this function with the same `cache`, you should take a copy
-(using `copy` or `deepcopy`) of them before calling again.
+(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
+```
 """
 function value_and_pullback!!(cache::Cache, ȳ, f::F, x::Vararg{Any,N}) where {F,N}
     tangents = tuple_map(set_to_zero!!, cache.tangents)

@@ -222,10 +247,7 @@ end
 """
     prepare_gradient_cache(f, x...)

-WARNING: experimental functionality. Interface subject to change without warning!
-
-Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
-`Mooncake.value_and_gradient!!` for more info.
+Returns a cache used with [`value_and_gradient!!`](@ref). See that function for more info.
 """
 function prepare_gradient_cache(fx...; kwargs...)
     rule = build_rrule(fx...; kwargs...)

@@ -236,20 +258,42 @@ function prepare_gradient_cache(fx...; kwargs...)
 end

 """
-    value_and_gradient!!(cache::Cache, fx::Vararg{Any, N}) where {N}
+    value_and_gradient!!(cache::Cache, f, x...)
+
+Computes a 2-tuple. The first element is `f(x...)`, and the second is a tuple containing the
+gradient of `f` w.r.t. each argument. The first element is the gradient w.r.t. any
+differentiable fields of `f`, the second w.r.t. the first element of `x`, etc.
+
+Assumes that `f` returns a `Union{Float16, Float32, Float64}`.
+
+As with all functionality in Mooncake, if `f` modifies itself or `x`, `value_and_gradient!!`
+will return both to their original state as part of the process of computing the gradient.

-WARNING: experimental functionality. Interface subject to change without warning!
+!!! info
+    `cache` must be the output of [`prepare_gradient_cache`](@ref), and (fields of) `f` and
+    `x` must be of the same size and shape as those used to construct the `cache`. This is
+    to ensure that the gradient can be written to the memory allocated when the `cache` was
+    built.

-Like other methods of `value_and_gradient!!`, but makes use of the `cache` object returned
-by [`prepare_gradient_cache`](@ref) in order to avoid having to re-allocate various tangent
-objects repeatedly. You must ensure that `f` and `x` are the same types and sizes as those
-used to construct `cache`.
+!!! warning
+    `cache` owns any mutable state returned by this function, meaning that mutable
+    components of values returned by it will be mutated if you run this function again with
+    different arguments. Therefore, if you need to keep the values returned by this function
+    around over multiple calls to this function with the same `cache`, you should take a
+    copy (using `copy` or `deepcopy`) of them before calling again.

-Warning: `cache` owns any mutable state returned by this function, meaning that mutable
-components of values returned by it will be mutated if you run this function again with
-different arguments. Therefore, if you need to keep the values returned by this function
-around over multiple calls to this function with the same `cache`, you should take a copy
-(using `copy` or `deepcopy`) of them before calling again.
+# Example Usage
+```jldoctest
+f(x, y) = sum(x .* y)
+x = [2.0, 2.0]
+y = [1.0, 1.0]
+cache = prepare_gradient_cache(f, x, y)
+value_and_gradient!!(cache, f, x, y)
+
+# output
+
+(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
+```
 """
 function value_and_gradient!!(cache::Cache, f::F, x::Vararg{Any,N}) where {F,N}
     coduals = tuple_map(CoDual, (f, x...), tuple_map(set_to_zero!!, cache.tangents))
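
The jldoctest added above covers a scalar return; for completeness, a hedged sketch of the non-scalar case that `value_and_pullback!!` exists for (the function and values here are illustrative, not from the diff):

```julia
import Mooncake
using Mooncake: NoTangent

f(x) = 2 .* x  # returns a Vector, so value_and_gradient!! does not apply
x = [1.0, 2.0]
cache = Mooncake.prepare_pullback_cache(f, x)

# ȳ must be an acceptable tangent for y = f(x): here, a Vector{Float64}.
ȳ = [1.0, 1.0]
val, (df, dx) = Mooncake.value_and_pullback!!(cache, ȳ, f, x)
# Expected: val == [2.0, 4.0], df == NoTangent(), dx == [2.0, 2.0].
```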
