@@ -175,10 +175,7 @@ _copy!!(::Number, src::Number) = src
"""
    prepare_pullback_cache(f, x...)

- WARNING: experimental functionality. Interface subject to change without warning!
-
- Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
- `Mooncake.value_and_gradient!!` for more info.
+ Returns a cache used with [`value_and_pullback!!`](@ref). See that function for more info.
"""
function prepare_pullback_cache(fx...; kwargs...)
@@ -200,18 +197,46 @@ end
"""
    value_and_pullback!!(cache::Cache, ȳ, f, x...)

- WARNING: experimental functionality. Interface subject to change without warning!
+ !!! info
+     If `f(x...)` returns a scalar, you should use [`value_and_gradient!!`](@ref), not this
+     function.
+
+ Computes a 2-tuple. The first element is `f(x...)`, and the second is a tuple containing the
+ pullback of `f` applied to `ȳ`. Its first element is the component of the pullback
+ associated with any fields of `f`, the second is w.r.t. the first element of `x`, etc.
+
+ There are no restrictions on what `y = f(x...)` is permitted to return. However, `ȳ` must be
+ an acceptable tangent for `y`. This means that, for example, it must be true that
+ `tangent_type(typeof(y)) == typeof(ȳ)`.
+
+ As with all functionality in Mooncake, if `f` modifies itself or `x`, `value_and_pullback!!`
+ will return both to their original state as part of the process of computing the pullback.
+
+ !!! info
+     `cache` must be the output of [`prepare_pullback_cache`](@ref), and (fields of) `f` and
+     `x` must be of the same size and shape as those used to construct the `cache`. This is
+     to ensure that the pullback can be written to the memory allocated when the `cache` was
+     built.
+
+ !!! warning
+     `cache` owns any mutable state returned by this function, meaning that mutable
+     components of values returned by it will be mutated if you run this function again with
+     different arguments. Therefore, if you need to keep the values returned by this function
+     around over multiple calls to this function with the same `cache`, you should take a
+     copy (using `copy` or `deepcopy`) of them before calling again.
+
+ # Example Usage
+ ```jldoctest
+ f(x, y) = sum(x .* y)
+ x = [2.0, 2.0]
+ y = [1.0, 1.0]
+ cache = Mooncake.prepare_pullback_cache(f, x, y)
+ Mooncake.value_and_pullback!!(cache, 1.0, f, x, y)
+
-
- Like other methods of `value_and_pullback!!`, but makes use of the `cache` object returned
- by [`prepare_pullback_cache`](@ref) in order to avoid having to re-allocate various tangent
- objects repeatedly. You must ensure that `f` and `x` are the same types and sizes as those
- used to construct `cache`.
+ # output
+
-
- Warning: `cache` owns any mutable state returned by this function, meaning that mutable
- components of values returned by it will be mutated if you run this function again with
- different arguments. Therefore, if you need to keep the values returned by this function
- around over multiple calls to this function with the same `cache`, you should take a copy
- (using `copy` or `deepcopy`) of them before calling again.
+ (4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
+ ```
"""
function value_and_pullback!!(cache::Cache, ȳ, f::F, x::Vararg{Any,N}) where {F,N}
    tangents = tuple_map(set_to_zero!!, cache.tangents)
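The tangent-type requirement in the new docstring is easiest to see with a non-scalar output. The following sketch is illustrative rather than part of the change (`g` and all values are made up), and only uses the calls documented above:

```julia
using Mooncake

# `g` returns a vector, so `value_and_gradient!!` does not apply and the
# pullback interface is the right tool.
g(x) = x .* x

x = [2.0, 3.0]
cache = Mooncake.prepare_pullback_cache(g, x)

# `ȳ` must be an acceptable tangent for `y = g(x)`. For a `Vector{Float64}`
# output, `tangent_type(Vector{Float64}) == Vector{Float64}`, so a plain
# vector of matching length works.
ȳ = [1.0, 1.0]
y, (dg, dx) = Mooncake.value_and_pullback!!(cache, ȳ, g, x)
# y == [4.0, 9.0]; dg == NoTangent(); dx == 2 .* x .* ȳ == [4.0, 6.0]
```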
@@ -222,10 +247,7 @@
"""
223
248
prepare_gradient_cache(f, x...)
224
249
225
- WARNING: experimental functionality. Interface subject to change without warning!
226
-
227
- Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
228
- `Mooncake.value_and_gradient!!` for more info.
250
+ Returns a cache used with [`value_and_gradient!!`](@ref). See that function for more info.
229
251
"""
230
252
function prepare_gradient_cache (fx... ; kwargs... )
231
253
rule = build_rrule (fx... ; kwargs... )
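Both docstrings state that Mooncake undoes any mutation of `f` or `x` before returning. A minimal sketch of that contract, assuming the behaviour stated in the docstrings (the mutating function `h!` is hypothetical):

```julia
using Mooncake

# `h!` mutates its argument before reducing it.
function h!(x)
    x .*= 2
    return sum(x)
end

x = [1.0, 2.0]
cache = Mooncake.prepare_gradient_cache(h!, x)
v, (dh, dx) = Mooncake.value_and_gradient!!(cache, h!, x)

# v == 6.0 and dx == [2.0, 2.0], yet x is back to [1.0, 2.0]: the mutation
# performed by `h!` has been undone as part of computing the gradient.
```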
@@ -236,20 +258,42 @@ function prepare_gradient_cache(fx...; kwargs...)
end

"""
-     value_and_gradient!!(cache::Cache, fx::Vararg{Any, N}) where {N}
+     value_and_gradient!!(cache::Cache, f, x...)
+
+ Computes a 2-tuple. The first element is `f(x...)`, and the second is a tuple containing the
+ gradient of `f` w.r.t. each argument. Its first element is the gradient w.r.t. any
+ differentiable fields of `f`, the second is w.r.t. the first element of `x`, etc.
+
+ Assumes that `f` returns a value of type `Union{Float16, Float32, Float64}`.
+
+ As with all functionality in Mooncake, if `f` modifies itself or `x`, `value_and_gradient!!`
+ will return both to their original state as part of the process of computing the gradient.

- WARNING: experimental functionality. Interface subject to change without warning!
+ !!! info
+     `cache` must be the output of [`prepare_gradient_cache`](@ref), and (fields of) `f` and
+     `x` must be of the same size and shape as those used to construct the `cache`. This is
+     to ensure that the gradient can be written to the memory allocated when the `cache` was
+     built.

- Like other methods of `value_and_gradient!!`, but makes use of the `cache` object returned
- by [`prepare_gradient_cache`](@ref) in order to avoid having to re-allocate various tangent
- objects repeatedly. You must ensure that `f` and `x` are the same types and sizes as those
- used to construct `cache`.
+ !!! warning
+     `cache` owns any mutable state returned by this function, meaning that mutable
+     components of values returned by it will be mutated if you run this function again with
+     different arguments. Therefore, if you need to keep the values returned by this function
+     around over multiple calls to this function with the same `cache`, you should take a
+     copy (using `copy` or `deepcopy`) of them before calling again.

- Warning: `cache` owns any mutable state returned by this function, meaning that mutable
- components of values returned by it will be mutated if you run this function again with
- different arguments. Therefore, if you need to keep the values returned by this function
- around over multiple calls to this function with the same `cache`, you should take a copy
- (using `copy` or `deepcopy`) of them before calling again.
+ # Example Usage
+ ```jldoctest
+ f(x, y) = sum(x .* y)
+ x = [2.0, 2.0]
+ y = [1.0, 1.0]
+ cache = Mooncake.prepare_gradient_cache(f, x, y)
+ Mooncake.value_and_gradient!!(cache, f, x, y)
+
+ # output
+
+ (4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
+ ```
"""
function value_and_gradient!!(cache::Cache, f::F, x::Vararg{Any,N}) where {F,N}
    coduals = tuple_map(CoDual, (f, x...), tuple_map(set_to_zero!!, cache.tangents))
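The `!!! warning` about cache ownership is worth demonstrating. A sketch of why the `deepcopy` matters when results must survive a second call, assuming the ownership semantics described in the warning (all values illustrative):

```julia
using Mooncake

f(x, y) = sum(x .* y)
x, y = [2.0, 2.0], [1.0, 1.0]
cache = Mooncake.prepare_gradient_cache(f, x, y)

_, grads = Mooncake.value_and_gradient!!(cache, f, x, y)
saved = deepcopy(grads)  # snapshot before the cache's buffers are reused

x .= [5.0, 5.0]
Mooncake.value_and_gradient!!(cache, f, x, y)

# `grads` aliases tangent storage owned by `cache`, so its arrays now reflect
# the second call (the gradient w.r.t. `y` reads [5.0, 5.0]); `saved` still
# holds the first call's values, [1.0, 1.0] and [2.0, 2.0].
```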