Skip to content

Start forward mode AD #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 131 commits into
base: main
Choose a base branch
from
Draft

Start forward mode AD #389

wants to merge 131 commits into from

Conversation

gdalle
Copy link
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.

Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):

Todo:

  • make FunctionWrappers work correctly not going to do this in this PR
  • add support for MistyClosures
  • add tests for Hessian vector products
  • define is_primitive separately for forwards and reverse pass.
  • do a complete pass to review design -- are there any high-level things we ought to modify?
  • improve DRY-ness of code, particularly in testing infrastructure in particular.

Once the above are complete, I'll request reviews.

Copy link

codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 94.04070% with 82 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/s2s_forward_mode_ad.jl 88.77% 22 Missing ⚠️
src/test_utils.jl 86.66% 16 Missing ⚠️
src/rrules/foreigncall.jl 75.75% 8 Missing ⚠️
src/rrules/memory.jl 87.69% 8 Missing ⚠️
src/utils.jl 76.92% 6 Missing ⚠️
src/rrules/tasks.jl 64.28% 5 Missing ⚠️
src/dual.jl 85.71% 3 Missing ⚠️
src/rrules/builtins.jl 97.82% 3 Missing ⚠️
src/developer_tools.jl 0.00% 2 Missing ⚠️
src/interpreter/s2s_reverse_mode_ad.jl 71.42% 2 Missing ⚠️
... and 5 more
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <ø> (ø)
src/interpreter/ir_utils.jl 89.68% <100.00%> (+2.81%) ⬆️
src/rrules/array_legacy.jl 100.00% <100.00%> (ø)
src/rrules/avoiding_non_differentiable_code.jl 100.00% <100.00%> (ø)
src/rrules/blas.jl 99.64% <100.00%> (+0.84%) ⬆️
src/rrules/fastmath.jl 100.00% <100.00%> (ø)
src/rrules/lapack.jl 100.00% <100.00%> (+0.56%) ⬆️
src/rrules/linear_algebra.jl 100.00% <100.00%> (ø)
src/rrules/low_level_maths.jl 100.00% <100.00%> (ø)
src/rrules/new.jl 91.30% <100.00%> (+2.84%) ⬆️
... and 20 more

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Copy link
Collaborator

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to the what we were thinking about writing out in the IR anyway.

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 26, 2024

Yes. I think my propose code handles this though, or am I missing something?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt
Copy link
Collaborator

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Copy link
Collaborator

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.

@willtebbutt
Copy link
Collaborator

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir,CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
    end
end
ir
julia> initial_ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >%2 = Base.or_int(%1, false)::Bool                                                                                 ││╻   <
  └──      goto #3 if not %2                                                                                            │   
  2%4 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %43%6 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %6                                                                                                    │   
                                                                                                                            

julia> ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >
  │        Base.or_int(%1, false)::Bool                                                                                 ││╻   <%3 = (+)(1, 2)::Any                                                                                               │   
  └──      goto #3 if not %3                                                                                            │   
  2%5 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %53%7 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %7      

@willtebbutt
Copy link
Collaborator

willtebbutt commented Mar 31, 2025

Just requires implementing forwards-mode for FunctionWrappers, then will be ready for review.

edit: also for MistyClosures, and do test this works by computing some Hessian-vector products!

@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A relatively minor comment: from_chainrules is more precise than from_chain_rule. The former clarifies that we are importing a rule from ChainRules, while the latter mislead me since I thought it refers to the generic chain rule terminology.

Suggested change
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airyai),IEEEFloat}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, interesting. I have no strong view either way, so I'm happy to change it if you think it's from_chainerules is clearer.

@willtebbutt
Copy link
Collaborator

willtebbutt commented Apr 1, 2025

@gdalle is there going to be an easy way to test Hessian-vector products using DI before we have the various things sorted in ADTypes / DI that we need to have sorted in order to make forwards-mode Mooncake via DI work. I'm just wondering to try and include Hessian-vector products computed via DI as test cases in this PR, or to punt it until a later date.

@gdalle
Copy link
Collaborator Author

gdalle commented Apr 1, 2025

Yeah no it's not gonna be completely straightforward. But doing the ADTypes changes is a matter of minutes

@codecov-commenter
Copy link

codecov-commenter commented May 11, 2025

@gdalle
Copy link
Collaborator Author

gdalle commented May 11, 2025

Hey @willtebbutt! Anything I can do to help bring this over the finish line?

@willtebbutt
Copy link
Collaborator

I'm going to do a pass over the remaining items (listed at the top) tonight, and will report back. Probably the most helpful thing really will be reviewing -- it's such a large PR that it will be hard for anyone who hasn't been involved in its development.

@gdalle
Copy link
Collaborator Author

gdalle commented May 12, 2025

In particular, should we tie together the definition of is_primitive for forwards-mode and reverse-mode, or permit something to be primitive in just one?

I think it would make sense to keep the two separate, for the case where defining rules is not strictly necessary but useful to enhance performance. For example, someone may want to define a reverse rule for an optimization solver to avoid backpropagating through every iteration, but decide that the forward-mode behavior is good enough.

@willtebbutt
Copy link
Collaborator

I think it would make sense to keep the two separate, for the case where defining rules is not strictly necessary but useful to enhance performance. For example, someone may want to define a reverse rule for an optimization solver to avoid backpropagating through every iteration, but decide that the forward-mode behavior is good enough.

I completely agree. I'm going to edit the todo item to reflect this.

@gdalle
Copy link
Collaborator Author

gdalle commented May 20, 2025

Gentle bump on this one :) @willtebbutt do you need an ADTypes object / some DI infrastructure for HVPs?

@willtebbutt
Copy link
Collaborator

Apologies @gdalle -- I've made some progress locally on this, but have yet to push changes. I don't think I need an ADType object yet. My feeling is that it's probably best to do this as part of some follow up work. My approach in this PR is to:

  1. check that forwards mode can differentiate simple MistyClosures, and then
  2. check that forwards mode can differentiate the specific MistyClosures produced by reverse-mode AD.

I think I'm most of the way with 1, but getting 2 to work will be the real test.

@gdalle
Copy link
Collaborator Author

gdalle commented May 20, 2025

Is it such a big deal if nested differentiation is only part of a later push? Adding forward mode in addition to reverse mode (not necessarily on top of it) would already be useful for the community.
Unless you think that getting nested differentiation to work will require significant changes to the design, which may lead to breaking the user-facing aspects?

@willtebbutt
Copy link
Collaborator

To my mind, much of the utility of forwards-mode derives from its ability to be combined with reverse-mode to compute HVPs. Usually I'd be in favour of incrementally adding stuff over the course of a few PRs, but this is such an important feature that I'd rather not merge any forwards-mode stuff until we can apply it over reverse-mode.

@gdalle
Copy link
Collaborator Author

gdalle commented May 20, 2025

To my mind, much of the utility of forwards-mode derives from its ability to be combined with reverse-mode to compute HVPs.

That's because you're an optimization kind of person, but you need to look further ;)
For instance, SciML mostly cares about efficient forward-mode (sparse) Jacobians inside OrdinaryDiffEq and friends. It's sometimes hard to get Enzyme to work there, so the current options are mostly ForwardDiff and FiniteDiff, mediated through DifferentiationInterface. A forward mode in Mooncake would be an interesting fourth option in such cases.

@willtebbutt
Copy link
Collaborator

That's a fair point. Okay, my proposal is this: if I've not managed to get a basic forwards-over-reverse example working by the end of the week, we punt it to a subsequent PR. I agree that it would be very nice to get this merged sooner rather than later.

@gdalle
Copy link
Collaborator Author

gdalle commented May 20, 2025

Sounds good! Let me know when this is review-ready

@MasonProtter
Copy link

MasonProtter commented May 30, 2025

Just chiming in that I ended up here in this thread today specifically because I had a SciML OrdinaryDiffeEq use-case where I wanted Forward-Mode AD, but ForwardDiff.jl would be difficult to use, and Enzyme.jl complained too much.

Forward-over-reverse is of course always nice, but not the only usecase for forward mode!

Exciting to see progress being made.

@yebai
Copy link
Member

yebai commented May 30, 2025

It would be great to make a fresh push to finish off this PR, @willtebbutt. Despite the forward-over-reverse functionality, there are a few other issues to address for a robust design and codebase, as listed above.

Comment on lines +66 to +89
# Construct a callable which performs reverse-mode, and apply forwards-mode over it.
rule = Mooncake.build_rrule(Tuple{typeof(quadratic), Float64})
TestUtils.test_rule(
StableRNG(123), low_level_gradient, rule, quadratic, 5.0;
interface_only=false,
is_primitive=false,
perf_flag=:none,
unsafe_perturb=true,
forward=true,
)

# Manually test that this correectly computes the second derivative.
frule = Mooncake.build_frule(
Mooncake.get_interpreter(),
Tuple{typeof(low_level_gradient), typeof(rule), typeof(quadratic), Float64}
)
result = frule(
zero_dual(low_level_gradient),
zero_dual(rule),
zero_dual(quadratic),
Mooncake.Dual(5.0, 1.0),
)
@test tangent(result) == 2.0
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pleased to say that this works -- we can successfully compute the second derivative using forwards-over-reverse.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Should I get started on adding AutoForwardMooncake to ADTypes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That way we could run DI tests with this branch

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That might actually be good at this point -- it would be good to have forward-mode + forward-over-reverse tests with DI before releasing, as that's how users will interact with it anyway.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add interface functions to this PR for DI?

  • value_and_pushforward!!
  • prepare_pushforward_cache

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gentle bump on this :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants