-
Notifications
You must be signed in to change notification settings - Fork 14
Start forward mode AD #389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Guillaume Dalle <[email protected]>
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think? |
I think this could work. You could just replace the @inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end The optimisation pass will lower this to the what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle. |
I think we might need to be slightly more subtle. If an argument to the |
Yes. I think my propose code handles this though, or am I missing something? |
In the spirit of higher-order AD, we may encounter |
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
I still need to dig into the different node types we might encounter (and I still don't understand |
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for |
I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler. |
Were the difficulties around renumbering etc not resolved by not |
No they weren't. I experimented with |
Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot(%5, #3) i.e. jump to block 3 if not %new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3) Does this not cause the same kind of problems? |
Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look |
Do you know what I should do about expressions of type |
Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728 The way to remove an instruction from an |
I think this works for
MWE (requires this branch of Mooncake): const CC = Core.Compiler
using Mooncake
using MistyClosures
f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any) # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
inst = ir[CC.SSAValue(k)][:stmt]
if inst isa Core.GotoIfNot
Mooncake.replace_call!(ir,CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
end
end
ir julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #3 if not %2 │
2 ─ %4 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %4 │
3 ─ %6 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %6 │
julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (+)(1, 2)::Any │
└── goto #3 if not %3 │
2 ─ %5 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %5 │
3 ─ %7 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %7 |
Just requires implementing forwards-mode for edit: also for |
@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat} | ||
@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat} | ||
@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat} | ||
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A relatively minor comment: from_chainrules
is more precise than from_chain_rule
. The former clarifies that we are importing a rule from ChainRules
, while the latter mislead me since I thought it refers to the generic chain rule terminology.
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat} | |
@from_chainrules DefaultCtx Tuple{typeof(airyai),IEEEFloat} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, interesting. I have no strong view either way, so I'm happy to change it if you think it's from_chainerules
is clearer.
@gdalle is there going to be an easy way to test Hessian-vector products using DI before we have the various things sorted in ADTypes / DI that we need to have sorted in order to make forwards-mode Mooncake via DI work. I'm just wondering to try and include Hessian-vector products computed via DI as test cases in this PR, or to punt it until a later date. |
Yeah no it's not gonna be completely straightforward. But doing the ADTypes changes is a matter of minutes |
Codecov ReportAttention: Patch coverage is
📢 Thoughts on this report? Let us know! |
Hey @willtebbutt! Anything I can do to help bring this over the finish line? |
I'm going to do a pass over the remaining items (listed at the top) tonight, and will report back. Probably the most helpful thing really will be reviewing -- it's such a large PR that it will be hard for anyone who hasn't been involved in its development. |
I think it would make sense to keep the two separate, for the case where defining rules is not strictly necessary but useful to enhance performance. For example, someone may want to define a reverse rule for an optimization solver to avoid backpropagating through every iteration, but decide that the forward-mode behavior is good enough. |
I completely agree. I'm going to edit the todo item to reflect this. |
Gentle bump on this one :) @willtebbutt do you need an ADTypes object / some DI infrastructure for HVPs? |
Apologies @gdalle -- I've made some progress locally on this, but have yet to push changes. I don't think I need an ADType object yet. My feeling is that it's probably best to do this as part of some follow up work. My approach in this PR is to:
I think I'm most of the way with 1, but getting 2 to work will be the real test. |
Is it such a big deal if nested differentiation is only part of a later push? Adding forward mode in addition to reverse mode (not necessarily on top of it) would already be useful for the community. |
To my mind, much of the utility of forwards-mode derives from its ability to be combined with reverse-mode to compute HVPs. Usually I'd be in favour of incrementally adding stuff over the course of a few PRs, but this is such an important feature that I'd rather not merge any forwards-mode stuff until we can apply it over reverse-mode. |
That's because you're an optimization kind of person, but you need to look further ;) |
That's a fair point. Okay, my proposal is this: if I've not managed to get a basic forwards-over-reverse example working by the end of the week, we punt it to a subsequent PR. I agree that it would be very nice to get this merged sooner rather than later. |
Sounds good! Let me know when this is review-ready |
Just chiming in that I ended up here in this thread today specifically because I had a SciML OrdinaryDiffeEq use-case where I wanted Forward-Mode AD, but ForwardDiff.jl would be difficult to use, and Enzyme.jl complained too much. Forward-over-reverse is of course always nice, but not the only usecase for forward mode! Exciting to see progress being made. |
It would be great to make a fresh push to finish off this PR, @willtebbutt. Despite the forward-over-reverse functionality, there are a few other issues to address for a robust design and codebase, as listed above. |
# Construct a callable which performs reverse-mode, and apply forwards-mode over it. | ||
rule = Mooncake.build_rrule(Tuple{typeof(quadratic), Float64}) | ||
TestUtils.test_rule( | ||
StableRNG(123), low_level_gradient, rule, quadratic, 5.0; | ||
interface_only=false, | ||
is_primitive=false, | ||
perf_flag=:none, | ||
unsafe_perturb=true, | ||
forward=true, | ||
) | ||
|
||
# Manually test that this correectly computes the second derivative. | ||
frule = Mooncake.build_frule( | ||
Mooncake.get_interpreter(), | ||
Tuple{typeof(low_level_gradient), typeof(rule), typeof(quadratic), Float64} | ||
) | ||
result = frule( | ||
zero_dual(low_level_gradient), | ||
zero_dual(rule), | ||
zero_dual(quadratic), | ||
Mooncake.Dual(5.0, 1.0), | ||
) | ||
@test tangent(result) == 2.0 | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pleased to say that this works -- we can successfully compute the second derivative using forwards-over-reverse.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Should I get started on adding AutoForwardMooncake
to ADTypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That way we could run DI tests with this branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That might actually be good at this point -- it would be good to have forward-mode + forward-over-reverse tests with DI before releasing, as that's how users will interact with it anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add interface functions to this PR for DI?
value_and_pushforward!!
prepare_pushforward_cache
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gentle bump on this :)
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.
Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):
Todo:
make FunctionWrappers work correctlynot going to do this in this PRis_primitive
separately for forwards and reverse pass.Once the above are complete, I'll request reviews.