Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eager parameter updating #1541

Merged
merged 8 commits into from
Dec 29, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/lib/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,42 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
return y, pullback_checkpointed
end


"""

eager_update(f, update, state, xs...)

Allows training large models when the gradients cannot all fit in memory simultaneously.

A combination of gradient checkpointing and eagerly updating the model parameters, discarding the updated gradients.
Assumes that `f` is a callable struct, `state` is the optimization state (eg. from Optimisers.jl) matching `f`, and
`update` is the function that updates the parameters of `f` from the state and the gradients, called as `update(state, f, grads)`.

If eg. `model.layers[i]` is layer in a transformer, then:

```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```
```julia

for i in 1:length(model.layers)
h = eager_updater(model.layers[i], Optimisers.update!, opt_state.layers[i], h, other_args)
end
```

!!! warning
If different layers share trainable parameters, then `eager_update` will likely give wrong results.
"""
eager_update(f, update, state, xs...) = f(state, xs...)

function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update), f, update, state, xs...)
y = f(xs...)
function pullback_eager_update(Δy)
y, pb = Zygote._pullback(ctx, f, xs...)
ret = pb(Δy)
update(state, f, ret[1])
return (nothing, nothing, nothing, nothing, ret[2:end]...)
end
return y, pullback_eager_update
end


"""
hessian(f, x)

Expand Down
Loading