Skip to content

Commit 1b111d8

Browse files
authored
Unthunk tangents (if any) before returning gradient (#1551)
1 parent a38a4a5 commit 1b111d8

File tree

5 files changed

+39
-9
lines changed

5 files changed

+39
-9
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Requires = "1.1"
5757
SpecialFunctions = "1.6, 2"
5858
Statistics = "1"
5959
Tracker = "0.2"
60-
ZygoteRules = "0.2.5"
60+
ZygoteRules = "0.2.7"
6161
julia = "1.6"
6262

6363
[extras]

src/compiler/chainrules.jl

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
22
# ZygoteRules here?
3-
function unthunk_tangent end
4-
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
5-
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
6-
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
7-
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
8-
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
3+
@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
4+
@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x
5+
@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
6+
@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
7+
ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
98
@non_differentiable unthunk_tangent(::IdDict)
109

1110

src/compiler/interface.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
152152
function gradient(f, args...)
153153
y, back = pullback(f, args...)
154154
grad = back(sensitivity(y))
155-
return _project_all(args, grad)
155+
return _project_all(args, unthunk_tangent(grad))
156156
end
157157

158158
# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -218,7 +218,7 @@ function withgradient(f, args...)
218218
else
219219
back(sensitivity(y))
220220
end
221-
results = _project_all(args, grad)
221+
results = _project_all(args, unthunk_tangent(grad))
222222
(val=y, grad=results)
223223
end
224224

src/lib/lib.jl

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ end
4040
accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
4141
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)
4242

43+
accum(x::Nothing, y::AbstractThunk) = y
44+
accum(x::AbstractThunk, y::Nothing) = x
45+
4346
accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
4447
accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
4548
accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))

test/chainrules.jl

+28
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,31 @@ end
428428
@test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
429429
@test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
430430
end
431+
432+
@testset "Lazy" begin
433+
custom_add(x, y) = x + y
434+
function ChainRulesCore.rrule(::typeof(custom_add), x, y)
435+
function pullback(Δ)
436+
return NoTangent(), unthunk(Δ), @thunk(error("Should not compute."))
437+
end
438+
custom_add(x, y), pullback
439+
end
440+
441+
x, y = 1f0, 1f0
442+
Zygote.gradient(x) do x
443+
sum(custom_add(x, y))
444+
end
445+
end
446+
447+
@testset "No thunks in the gradient" begin
448+
struct CustomDense
449+
w::Matrix{Float32}
450+
end
451+
(d::CustomDense)(x) = d.w * x
452+
453+
layers = [CustomDense(rand(Float32, 3, 3))]
454+
x = ones(Float32, 3)
455+
g = gradient(layers -> sum(layers[1](x)), layers)[1]
456+
@test g[1] isa NamedTuple
457+
@test g[1].w isa Array
458+
end

0 commit comments

Comments
 (0)