Commit 4471a73

Merge branch 'master' into drop16
2 parents 5ad23d7 + 572eb2a commit 4471a73

13 files changed, with 87 additions and 81 deletions
Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.7.0"
+version = "0.7.3"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.5"
+ZygoteRules = "0.2.7"
 julia = "1.10"

 [extras]

docs/src/limitations.md

Lines changed: 20 additions & 27 deletions
@@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutatin
 ```julia
 function f!(x)
     x .= 2 .* x
-
     return x
 end
 ```
@@ -42,43 +41,36 @@ Stacktrace:
 ...
 ```
 We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
-- setting values (`x .= ...`)
-- appending/popping values (`push!(x, v)` / `pop!(x)`)
-- calling mutating functions (`mul!(C, A, B)`)
+- setting values (`x[i] = val` or `x .= values`)
+- appending/popping values (`push!(x, v)` or `pop!(x)`)
+- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`)

 !!! warning

     Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.

 ```julia
-function g!(x, y)
-    x .= 2 .* y
-
+function g_inner!(x, y)
+    for i in eachindex(x, y)
+        x[i] = 2 * y[i]
+    end
     return x
 end
-g(y) = g!(similar(y), y)
-```
-Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.
-
-Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
-```julia
-function g!(x, y)
-    x .= 2 .* y

-    return x
+function g_outer(y)
+    z = similar(y)
+    g_inner!(z, y)
+    return z
 end
+```
+Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package.

-function g(y)
-    x = Zygote.Buffer(y) # Buffer supports syntax like similar
-    g!(x, y)
-    return copy(x) # this step makes the Buffer immutable (w/o actually copying)
-end
+How can you solve this problem?
+* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it.
+* Write a custom rule, defining `rrule(::typeof(g), y)` using what you know about `g` to derive the right expression.
+* Use another AD package instead of Zygote for part of the calculation. Replacing `g(y)` with `Zygote.forwarddiff(g, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them.

-julia> gradient(rand(3)) do y
-           sum(g(y))
-       end
-([2.0, 2.0, 2.0],)
-```
+Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended.

 ## Try-catch statements
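
The three alternatives named in the new documentation text can be sketched as follows. This is a minimal illustration, assuming Zygote and ChainRulesCore are installed; `g_inner!`, `g_outer`, and `g_better` are the names used in the diff, while the `rrule` body, the input `y`, and the `grad_*` variable names are illustrative assumptions rather than code from this commit.

```julia
using Zygote, ChainRulesCore

# Mutating helper and its allocating wrapper, as in the new docs text.
function g_inner!(x, y)
    for i in eachindex(x, y)
        x[i] = 2 * y[i]
    end
    return x
end

function g_outer(y)
    z = similar(y)
    g_inner!(z, y)
    return z
end

# Alternative 1: rewrite without mutation, using broadcasting.
g_better(y) = 2 .* y

# Alternative 2 (illustrative rule): since g_outer(y) == 2 .* y,
# the pullback scales the incoming cotangent by 2.
function ChainRulesCore.rrule(::typeof(g_outer), y)
    g_outer_pullback(Δ) = (NoTangent(), 2 .* unthunk(Δ))
    return g_outer(y), g_outer_pullback
end

y = [1.0, 2.0, 3.0]

grad_broadcast = gradient(y -> sum(g_better(y)), y)[1]
grad_rule = gradient(y -> sum(g_outer(y)), y)[1]

# Alternative 3: let ForwardDiff differentiate the mutating code.
grad_forward = gradient(y -> sum(Zygote.forwarddiff(g_outer, y)), y)[1]
```

Under these assumptions all three gradients are expected to equal `[2.0, 2.0, 2.0]`, because every output element is twice the corresponding input element.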

@@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the f
 2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
 3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)

-Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
+Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension.
+If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value.

 Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
 ```julia

src/compiler/chainrules.jl

Lines changed: 5 additions & 6 deletions
@@ -1,11 +1,10 @@
 # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
 # Zygote rules here?
-function unthunk_tangent end
-@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
-@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
-@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
-@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
-unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
+@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
+@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
+ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
 @non_differentiable unthunk_tangent(::IdDict)


src/compiler/interface.jl

Lines changed: 2 additions & 2 deletions
@@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 function gradient(f, args...)
   y, back = pullback(f, args...)
   grad = back(sensitivity(y))
-  return _project_all(args, grad)
+  return _project_all(args, unthunk_tangent(grad))
 end

 # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -218,7 +218,7 @@ function withgradient(f, args...)
   else
     back(sensitivity(y))
   end
-  results = _project_all(args, grad)
+  results = _project_all(args, unthunk_tangent(grad))
   (val=y, grad=results)
 end
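
The added `unthunk_tangent` calls appear intended to keep deferred values out of the gradient returned to the user. What "deferred" means here can be illustrated with ChainRulesCore alone: a thunk wraps a computation that runs only when `unthunk` is called. A minimal sketch, assuming ChainRulesCore is installed; the counter `calls` and the other variable names are hypothetical helpers used only to observe when the thunk body runs.

```julia
using ChainRulesCore

calls = Ref(0)  # hypothetical counter: how many times the thunk body has run

lazy = @thunk begin
    calls[] += 1
    2.0 * 3.0
end

ran_before = calls[]    # creating the thunk does not run its body
value = unthunk(lazy)   # forces the computation
ran_after = calls[]

plain = unthunk(1.5)    # values that are not thunks pass through unchanged
```

Here `ran_before` is 0 and `ran_after` is 1, which is the laziness that the `@thunk(error("Should not compute."))` test in this commit relies on.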

src/compiler/reverse.jl

Lines changed: 21 additions & 13 deletions
@@ -299,6 +299,7 @@ end

 function adjoint(pr::Primal)
   ir, sigs = adjointcfg(pr)
+  catch_blocks = falses(length(blocks(pr.ir)))
   for b in reverse(blocks(pr.ir))
     rb = block(ir, b.id)
     grads = Dict()
@@ -309,12 +310,13 @@ function adjoint(pr::Primal)
       grad(sigs[b.id][i], arguments(rb)[i])
     end

-    has_leave = false
-
     # Backprop through statements
     for v in reverse(keys(b))
       ex = b[v].expr
-      has_leave |= isexpr(ex, :leave)
+
+      if isexpr(ex, :catch)
+        catch_blocks[first(ex.args)] = true
+      end

       if haskey(pr.pullbacks, v)
         g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)),
@@ -338,16 +340,6 @@ function adjoint(pr::Primal)
       end
     end

-    # This is corresponds to a catch blocks which technically
-    # has predecessors but they are not modelled in the IRTools CFG.
-    # We put an error message at the beginning of said block.
-    if has_leave && isempty(predecessors(b)) && b.id != 1
-      _, f_stmt = first(b)
-      li = pr.ir.lines[f_stmt.line]
-      pushfirst!(rb, stmt(xcall(Base, :error,
-        "Can't differentiate function execution in catch block at $(li.file):$(li.line).")))
-    end
-
     if b.id > 1 # Backprop through (predecessor) branch arguments
       gs = grad.(arguments(b))
       for br in branches(rb)
@@ -368,6 +360,22 @@ function adjoint(pr::Primal)
       branches(rb)[1].args[1] = Δ
     end
   end
+
+  for (id, is_catch) in enumerate(catch_blocks)
+    is_catch || continue
+
+    b = block(pr.ir, id)
+    rb = block(ir, id)
+    err_message = if isempty(b)
+      "Can't differentiate function execution in catch block"
+    else
+      _, f_stmt = first(b)
+      li = pr.ir.lines[f_stmt.line]
+      "Can't differentiate function execution in catch block at $(li.file):$(li.line)."
+    end
+    pushfirst!(rb, stmt(xcall(Base, :error, err_message)))
+  end
+
   return ir
 end
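
The change above replaces a per-block heuristic (`has_leave` plus "no predecessors") with two passes: first flag every block that a `:catch` expression names, then prepend an error to each flagged block. The pattern can be sketched on a toy structure using only Base. Everything below is a hypothetical stand-in, not IRTools: `toy_blocks` is a plain vector of statement vectors, and `messages` merely records what would be inserted.

```julia
using Base.Meta: isexpr

# Toy stand-in for IR blocks (hypothetical): block i is a vector of statements,
# and Expr(:catch, n) means "the handler for this try is block n".
toy_blocks = [
    Any[Expr(:catch, 3), :(y = 2x)],  # block 1 enters a try handled by block 3
    Any[:(return y)],                 # block 2: normal exit
    Any[],                            # block 3: an empty catch block
]

# Pass 1: while walking every statement, flag the blocks named by :catch.
catch_blocks = falses(length(toy_blocks))
for b in toy_blocks, ex in b
    if isexpr(ex, :catch)
        catch_blocks[first(ex.args)] = true
    end
end

# Pass 2: visit only flagged blocks and choose a message, with a fallback
# for blocks that have no first statement to take a location from.
messages = Dict{Int,String}()
for (id, is_catch) in enumerate(catch_blocks)
    is_catch || continue
    b = toy_blocks[id]
    messages[id] = isempty(b) ? "catch block" : "catch block starting at: $(first(b))"
end
```

Separating the passes means a catch block is identified by what points at it rather than by its own contents, which is why the empty-block fallback in the second pass is needed.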

src/lib/array.jl

Lines changed: 0 additions & 11 deletions
@@ -337,17 +337,6 @@ end
 end

 # Reductions
-@adjoint function sum(xs::AbstractArray; dims = :)
-  if dims === (:)
-    sum(xs), Δ -> (Fill(Δ, size(xs)),)
-  else
-    sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,)
-  end
-end
-
-@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
-  sum(xs, dims = dims), Δ -> (nothing,)
-end

 function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
   return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)

src/lib/broadcast.jl

Lines changed: 0 additions & 5 deletions
@@ -365,11 +365,6 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
 @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
   T(xs), Δ -> (convert(Array, Δ), )

-@adjoint function sum(xs::AbstractGPUArray; dims = :)
-  placeholder = similar(xs)
-  sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
-end
-
 # Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
 # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
 function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray)

src/lib/lib.jl

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ end
 accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
 accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

+accum(x::Nothing, y::AbstractThunk) = y
+accum(x::AbstractThunk, y::Nothing) = x
+
 accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
 accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
 accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))

test/chainrules.jl

Lines changed: 28 additions & 0 deletions
@@ -428,3 +428,31 @@ end
     @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
     @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
 end
+
+@testset "Lazy" begin
+    custom_add(x, y) = x + y
+    function ChainRulesCore.rrule(::typeof(custom_add), x, y)
+        function pullback(Δ)
+            return NoTangent(), unthunk(Δ), @thunk(error("Should not compute."))
+        end
+        custom_add(x, y), pullback
+    end
+
+    x, y = 1f0, 1f0
+    Zygote.gradient(x) do x
+        sum(custom_add(x, y))
+    end
+end
+
+@testset "No thunks in the gradient" begin
+    struct CustomDense
+        w::Matrix{Float32}
+    end
+    (d::CustomDense)(x) = d.w * x
+
+    layers = [CustomDense(rand(Float32, 3, 3))]
+    x = ones(Float32, 3)
+    g = gradient(layers -> sum(layers[1](x)), layers)[1]
+    @test g[1] isa NamedTuple
+    @test g[1].w isa Array
+end

test/compiler.jl

Lines changed: 2 additions & 10 deletions
@@ -309,11 +309,7 @@ end
   @test res == 12.
   @test_throws ErrorException pull(1.)
   err = try pull(1.) catch ex; ex end
-  if VERSION >= v"1.11"
-    @test_broken occursin("Can't differentiate function execution in catch block", string(err))
-  else
-    @test occursin("Can't differentiate function execution in catch block", string(err))
-  end
+  @test occursin("Can't differentiate function execution in catch block", string(err))
 end

 @testset "try/catch/else" begin
@@ -339,9 +335,5 @@ end
   @test_throws ErrorException pull(1.)

   err = try pull(1.) catch ex; ex end
-  if VERSION >= v"1.11"
-    @test_broken occursin("Can't differentiate function execution in catch block", string(err))
-  else
-    @test occursin("Can't differentiate function execution in catch block", string(err))
-  end
+  @test occursin("Can't differentiate function execution in catch block", string(err))
 end
