# grad.jl (from a fork of FluxML/Zygote.jl)
macro which(ex)
@capture(ex, f_(args__)) || error("Zygote.@which f(args...)")
:(InteractiveUtils.@which adjoint(Context(), $(esc(f)), $(esc.(args)...)))
end
"""
checkpointed(f, xs...)
Use gradient checkpointing on the call `f(xs...)`. This means that
`checkpointed(f, xs...) === f(xs...)`, but when computing the derivative
intermediate results from the forward pass of `f` will not be stored. Instead the forward
pass will be repeated, when computing the derivative.
This saves memory at the cost of increasing execution time.
!!! warning
If `f` is not a pure function, `checkpointed` will likely give wrong results.
"""
checkpointed(f, xs...) = f(xs...)
function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...)
y = f(xs...)
function pullback_checkpointed(Δy)
y, pb = Zygote._pullback(ctx, f, xs...)
return (nothing, pb(Δy)...)
end
return y, pullback_checkpointed
end
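
# A minimal, hedged usage sketch. `_checkpointed_demo` is a hypothetical helper
# added here for illustration only (not part of the package API) and is never
# called at load time. For a pure `f`, wrapping the call in `checkpointed`
# changes only how the backward pass is computed, not its result:
function _checkpointed_demo()
  f = v -> sum(abs2, v)
  x = ones(3)
  g_plain = gradient(v -> 2 * f(v), x)[1]                # stores intermediates of `f`
  g_ckpt  = gradient(v -> 2 * checkpointed(f, v), x)[1]  # re-runs `f` during the backward pass
  return g_plain == g_ckpt                               # expected `true`
end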
"""
eager_update(f, update, state, xs...)
Allows training large models when the gradients cannot all fit in memory simultaneously.
A combination of gradient checkpointing and eagerly updating the model parameters, discarding the updated gradients.
Assumes that `f` is a callable struct, `state` is the optimization state (eg. from Optimisers.jl) matching `f`, and
`update` is the function that updates the parameters of `f` from the state and the gradients, called as `update(state, f, grads)`.
If eg. `model.layers[i]` is layer in a transformer, then:
```
for i in 1:length(model.layers)
h = eager_updater(model.layers[i], Optimisers.update!, opt_state.layers[i], h, other_args)
end
```
!!! warning
If different layers share trainable parameters, then `eager_update` will likely give wrong results.
"""
eager_update(f, update, state, xs...) = f(xs...)
function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update), f, update, state, xs...)
y = f(xs...)
function pullback_eager_update(Δy)
y, pb = Zygote._pullback(ctx, f, xs...)
ret = pb(Δy)
update(state, f, ret[1])
return (nothing, nothing, nothing, nothing, ret[2:end]...)
end
return y, pullback_eager_update
end
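
# Hedged sketch of the intended usage pattern; `model`, `opt_state`, `embed`,
# `unembed`, `loss` and `batch` below are placeholders, not defined in this file.
# The per-layer updates are applied while the backward pass runs, so the loop
# must sit inside a call that Zygote is differentiating, e.g.:
#
#   Zygote.withgradient(model) do m
#     h = embed(m, batch)
#     for i in 1:length(m.layers)
#       h = eager_update(m.layers[i], Optimisers.update!, opt_state.layers[i], h)
#     end
#     loss(unembed(m, h))
#   end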
"""
hessian(f, x)
Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array,
and `f(x)` is a real number. When `x` is an array, the result is a matrix
`H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument
is higher-dimensional.
This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`.
See [`hessian_reverse`](@ref) for an all-Zygote alternative.
See also [`diaghessian`](@ref) to compute only the diagonal part.
# Examples
```jldoctest; setup=:(using Zygote)
julia> hessian(x -> x[1]*x[2], randn(2))
2×2 Matrix{Float64}:
0.0 1.0
1.0 0.0
julia> hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x
4×4 Matrix{$Int}:
6 0 0 0
0 18 0 0
0 0 12 0
0 0 0 24
julia> hessian(sin, pi/2)
-1.0
```
"""
hessian(f, x) = hessian_dual(f, x)
hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2]
hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x)
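
# A minimal, hedged check of the forward-over-reverse Hessian against a case
# with a known answer. `_hessian_demo` is a hypothetical helper for illustration
# only and is never called at load time; for the quadratic form `x' * A * x`
# the Hessian is `A + A'`.
function _hessian_demo()
  A = [1.0 2.0; 3.0 4.0]
  f = x -> x' * A * x
  return hessian(f, ones(2)) ≈ A + A'  # expected `true`
end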
"""
hessian_reverse(f, x)
This should be equivalent to [`hessian(f, x)`](@ref hessian),
but implemented using reverse over reverse mode, all Zygote.
(This is usually much slower, and more likely to find errors.)
"""
hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1]
hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1]
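
# Hedged cross-check sketch, kept as a comment so nothing runs at load time:
# for simple closures the two implementations should agree,
#
#   hessian(x -> sum(x.^3), [1.0, 2.0]) ≈ hessian_reverse(x -> sum(x.^3), [1.0, 2.0])
#
# but `hessian_reverse` exercises reverse-over-reverse differentiation and may
# fail on functions that `hessian` handles.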
"""
jacobian(f, args...) -> Tuple
For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]`
where `y = f(args...)` is usually a vector.
Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output.
For scalar `x::Number ∈ args`, the result is a vector `Jx[k] = ∂y[k]/∂x`,
while for scalar `y` all results have just one row.
With any other argument type, no result is produced, even if [`gradient`](@ref) would work.
This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`.
Doing so is usually only efficient when `length(y)` is small compared to `length(a)`,
otherwise forward mode is likely to be better.
See also [`withjacobian`](@ref), [`hessian`](@ref), [`hessian_reverse`](@ref).
# Examples
```jldoctest; setup=:(using Zygote)
julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output
3×7 Matrix{$Int}:
200 0 0 0 0 0 0
0 400 0 0 0 0 0
0 0 600 0 0 0 0
julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
```
!!! warning
For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
```
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple
([4 4 4], (6, 1))
```
"""
jacobian(f, args...) = withjacobian(f, args...).grad
"""
withjacobian(f, args...)
Returns both the value `f(args...)` and the [`jacobian`](@ref) as a named tuple.
```jldoctest; setup=:(using Zygote)
julia> withjacobian(cumsum, [1,2,3])
(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))
```
"""
function withjacobian(f, args...)
y, back = pullback(_jvec∘f, args...)
out = map(args) do x
T = promote_type(eltype(x), eltype(y))
dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) :
x isa Number ? similar(y, T, length(y)) :
nothing
end
delta = _eyelike(y)
for k in LinearIndices(y)
grads = back(delta[:,k])
for (dx, grad) in zip(out, grads)
dx isa AbstractArray || continue
_gradcopy!(view(dx,k,:), grad)
end
end
(val=y, grad=out)
end
_jvec(x::AbstractArray) = vec(x)
_jvec(x::Number) = _jvec(vcat(x))
_jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))"))
_jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output"))
_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y))
function _eyelike(y::AbstractVector) # version which works on GPU
out = fill!(similar(y, length(y), length(y)), 0)
out[LinearAlgebra.diagind(out)] .= 1
out
end
_gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src)
_gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src)
_gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0
_gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? g : 0 for g in src) # e.g. Union{Nothing,Float64}
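
# A minimal, hedged illustration of the mechanism used by `withjacobian` above:
# each row of the Jacobian is one pullback call seeded with a basis vector.
# `_jacobian_row_demo` is a hypothetical helper for illustration only and is
# never called at load time.
function _jacobian_row_demo()
  f = x -> x .^ 2
  x = [1.0, 2.0, 3.0]
  _, back = pullback(f, x)
  row2 = back([0.0, 1.0, 0.0])[1]  # seed picking out the 2nd output element
  J = jacobian(f, x)[1]            # 3×3 diagonal matrix [2 0 0; 0 4 0; 0 0 6]
  return row2 == J[2, :]           # expected `true`
end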
"""
jacobian(loss, ::Params)
Like [`gradient`](@ref) with implicit parameters, this method takes a zero-argument function
and returns an `IdDict`-like object, now containing the Jacobian for each parameter.
# Examples
```jldoctest; setup=:(using Zygote)
julia> xs = [1 2; 3 4]; ys = [5,7,9];
julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)
julia> Jxy[ys]
2×3 Matrix{$Int}:
1 0 0
0 1 0
julia> Jxy[xs]
2×4 Matrix{$Int}:
2 6 4 8
2 6 4 8
```
"""
jacobian(f, pars::Params) = withjacobian(f, pars::Params).grad
function withjacobian(f, pars::Params)
y, back = pullback(_jvec∘f, pars)
out = IdDict()
for p in pars
T = Base.promote_type(eltype(p), eltype(y))
J = similar(y, T, length(y), length(p))
out[p] = J
end
delta = _eyelike(y)
for k in LinearIndices(y)
grads = back(delta[:,k])
for p in pars
out[p] isa AbstractArray || continue
_gradcopy!(view(out[p],k,:), grads[p])
end
end
(val=y, grad=Grads(out, pars))
end
"""
diaghessian(f, args...) -> Tuple
Diagonal part of the Hessian. Returns a tuple containing, for each argument `x`,
`h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`.
The original evaluation `y = f(args...)` must give a real number `y`.
For one vector argument `x`, this is equivalent to `(diag(hessian(f,x)),)`.
Like [`hessian`](@ref) it uses ForwardDiff over Zygote.
!!! warning
For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
# Examples
```jldoctest; setup=:(using Zygote, LinearAlgebra)
julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
2×2 Matrix{$Int}:
6 12
18 24
julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4]) # full Hessian is diagonal
true
julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666]) # two array arguments
([0.0 0.0; 0.0 0.0], [2.0, 8.0])
julia> diaghessian(atan, 1, 2) # two scalar arguments
(-0.16, 0.16)
julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal
2×2 Matrix{Float64}:
-0.16 -0.12
-0.12 0.16
```
"""
function diaghessian(f, args...)
ntuple(length(args)) do n
let x = args[n], valn = Val(n) # let Val improves speed, sometimes
if x isa AbstractArray
forward_diag(x -> gradient(f, _splice(x, args, valn)...)[n], x)[2]
elseif x isa Number
ForwardDiff.derivative(x -> gradient(f, _splice(x, args, valn)...)[n], x)
end
end
end
end
_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
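
# A minimal, hedged consistency check between `diaghessian` and the diagonal of
# the full `hessian`. `_diaghessian_demo` is a hypothetical helper for
# illustration only and is never called at load time; the Hessian of this `f`
# is not diagonal, so only its diagonal is compared.
function _diaghessian_demo()
  f = x -> x[1]^2 * x[2] + x[2]^3
  x = [1.0, 2.0]
  return diaghessian(f, x)[1] ≈ LinearAlgebra.diag(hessian(f, x))  # expected `true`
end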