Skip to content

Start forward mode AD #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 147 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
8d120b2
Case 7 solved
gdalle Jan 27, 2025
cd7167f
Resolve merge conflict
willtebbutt Jan 27, 2025
c5ffae7
Fix precompile issue
willtebbutt Jan 27, 2025
94aa904
Fix isa rule
willtebbutt Jan 27, 2025
cc7a3fa
Fix is_primitive
willtebbutt Jan 27, 2025
70d7183
More test cases
gdalle Feb 4, 2025
aec412e
progress
gdalle Feb 6, 2025
0ea1084
fixes
gdalle Feb 7, 2025
d8a949f
Bump patch vesion
willtebbutt Feb 12, 2025
79844d2
Fix terminators
willtebbutt Feb 12, 2025
49aa4ca
Merge remote-tracking branch 'upstream/wct/fix-terminator-issue' into…
gdalle Feb 12, 2025
9ce99ec
More cases
gdalle Feb 12, 2025
6ce2488
More cases
gdalle Feb 12, 2025
8954361
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
941a2de
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
af49eac
Tuple rule
gdalle Feb 14, 2025
0b4e5fa
Merge in main
willtebbutt Mar 14, 2025
8204665
Formatting
willtebbutt Mar 14, 2025
70fec10
Code to view forwards-mode IR from a signature
willtebbutt Mar 14, 2025
6cde147
Use widenconst to get actual argtype from ircode argtypes
willtebbutt Mar 14, 2025
0eabff0
MyInstruction -> new_instruction
willtebbutt Mar 14, 2025
8b391c6
Formatting
willtebbutt Mar 15, 2025
5d6b826
Merge branch 'main' into gd/forward
willtebbutt Mar 17, 2025
a919a28
Various improvements
willtebbutt Mar 18, 2025
2808a12
Rules for foreigncalls
willtebbutt Mar 19, 2025
cb28759
Fix pointer tests with forwards mode
willtebbutt Mar 19, 2025
f9d1697
Enable more tests
willtebbutt Mar 19, 2025
9bc53cc
All derivation tests pass
willtebbutt Mar 19, 2025
d6fc35d
Initial pass over legacy array functionality
willtebbutt Mar 20, 2025
6b2409c
Fix tangent usage in tests
willtebbutt Mar 20, 2025
d6974c1
Rules for nice BLAS functions
willtebbutt Mar 20, 2025
fbcc6ce
Tweak test inputs slightly
willtebbutt Mar 20, 2025
732762b
Enable CI for BLAS and foreigncalls
willtebbutt Mar 20, 2025
fd48f02
Enable linear_algebra rules
willtebbutt Mar 20, 2025
f6bc752
More stuff works
willtebbutt Mar 21, 2025
a96611e
Make IdDict work
willtebbutt Mar 21, 2025
44e78b4
Code to identify SSA uses
willtebbutt Mar 21, 2025
a504413
Fix failing test via special case
willtebbutt Mar 21, 2025
f68c79b
Remove outdated TODO note
willtebbutt Mar 21, 2025
05cbb83
Merge branch 'main' into gd/forward
willtebbutt Mar 21, 2025
30c5294
Fix typo
willtebbutt Mar 21, 2025
fe4ec4a
BLAS support nearly finished
willtebbutt Mar 23, 2025
f771f70
All BLAS rules passing
willtebbutt Mar 24, 2025
86fa1b6
Initial work on getrf
willtebbutt Mar 25, 2025
04ea669
Merge branch 'main' into gd/forward
willtebbutt Mar 25, 2025
e1a1260
getrf frule sketch
willtebbutt Mar 25, 2025
fda2ab9
Merge branch 'gd/forward' of https://github.com/gdalle/Mooncake.jl in…
willtebbutt Mar 25, 2025
37baaf0
Improve getrf performance
willtebbutt Mar 26, 2025
c0c4167
trtrs implementation + type stability checks
willtebbutt Mar 26, 2025
9a12b23
Type stability checks for BLAS rules
willtebbutt Mar 26, 2025
bb8feba
Note Seth's blog
willtebbutt Mar 26, 2025
64d6176
getrs frule implementation
willtebbutt Mar 27, 2025
be57d7f
getri frule implementation
willtebbutt Mar 27, 2025
2934409
potrs
willtebbutt Mar 27, 2025
fe289a0
Enable lapack CI
willtebbutt Mar 27, 2025
39354bc
Fix pivoting
willtebbutt Mar 28, 2025
8bcde33
Enable diff tests integration tests
willtebbutt Mar 28, 2025
497c907
Only run extra CI on 1
willtebbutt Mar 28, 2025
e1dce38
More lapack fixes
willtebbutt Mar 28, 2025
8739c6c
widenconst
willtebbutt Mar 28, 2025
899d4c4
Replace field access with method call
willtebbutt Mar 28, 2025
594ba13
Catch __vec_to_tuple edge case
willtebbutt Mar 28, 2025
0510235
Display more stuff when correctness test fails
willtebbutt Mar 28, 2025
4af2276
Enable more integration tests
willtebbutt Mar 28, 2025
83cd097
Make output on test error sensible
willtebbutt Mar 28, 2025
da3d7ee
Tidy up blas implementations
willtebbutt Mar 28, 2025
eee18dd
Fix pointerset error
willtebbutt Mar 28, 2025
9bd274d
Merge branch 'main' into gd/forward
willtebbutt Mar 28, 2025
3a4f70a
Fix ^ rule
willtebbutt Mar 28, 2025
5aed9b2
Implement from_chain_rule macro
willtebbutt Mar 29, 2025
f4f62c9
Get SpecialFunctions extension working
willtebbutt Mar 29, 2025
9c11e6a
Enable SpecialFunctions in CI
willtebbutt Mar 29, 2025
e19cb63
logexpfunctions
willtebbutt Mar 29, 2025
be93bfd
Run gpu jobs on 1.11 only
willtebbutt Mar 29, 2025
1d1e7e9
Restrict FD step for forward mode
willtebbutt Mar 29, 2025
f21b575
Enable GP tests
willtebbutt Mar 29, 2025
c691a92
More integration testing
willtebbutt Mar 29, 2025
b28961a
bijectors
willtebbutt Mar 29, 2025
b67f2c3
Enable battery of tests
willtebbutt Mar 29, 2025
2bdb0ad
Distributions integration testing
willtebbutt Mar 29, 2025
d4fa5c8
Enable DI CI
willtebbutt Mar 29, 2025
4902ce2
Enable reverse-mode integration tests for Lux etc
willtebbutt Mar 29, 2025
7f57a06
Enable 1.10
willtebbutt Mar 31, 2025
60e4d89
Fix LAPACK on 1.10
willtebbutt Mar 31, 2025
41bb3c3
Implement copytrito for 1.10
willtebbutt Mar 31, 2025
2140edd
formatting
willtebbutt Mar 31, 2025
05bac94
Merge branch 'main' into gd/forward
willtebbutt Mar 31, 2025
df0cf38
Tidying up
willtebbutt Mar 31, 2025
dac008f
Remove type piracy
willtebbutt Mar 31, 2025
48b61ec
Initial forwards-mode timings
willtebbutt Mar 31, 2025
3d9f9bf
Merge in main
willtebbutt May 11, 2025
05d3c65
Constrain JuliaInterpreter
willtebbutt May 26, 2025
df0d2d7
Basic MistyClosure support
willtebbutt May 26, 2025
ed912eb
Merge in main
willtebbutt May 26, 2025
b9c5f7e
Do not use MistyClosure internals inside reverse-mode
willtebbutt Jun 4, 2025
6990348
Forwards-over-reverse mwe
willtebbutt Jun 4, 2025
941e171
Remove overly strict performance check
willtebbutt Jun 4, 2025
180b43e
Docstring and improved field naming
willtebbutt Jun 4, 2025
2bbf98d
Separate forward-mode and reverse-mode primitives
willtebbutt Jun 16, 2025
f9151ed
Fix docs and rrule creation
willtebbutt Jun 16, 2025
2d52a2c
Fix low_level_maths
willtebbutt Jun 16, 2025
dd023e7
Fix SpecialFunctions tests cases
willtebbutt Jun 16, 2025
31b1733
Fix more testing
willtebbutt Jun 16, 2025
bfa2476
Fix formatting
willtebbutt Jun 16, 2025
27cd7c1
Make symbols available in tests
willtebbutt Jun 16, 2025
8079137
Fix GP test suite
willtebbutt Jun 16, 2025
f37626f
Fix SpecialFunctions test suite
willtebbutt Jun 16, 2025
2df8d36
Merge branch 'main' into gd/forward
willtebbutt Jun 16, 2025
2275845
Fix performance
willtebbutt Jun 16, 2025
720e410
Fix array tests
willtebbutt Jun 16, 2025
0d16652
Fix formatting
willtebbutt Jun 16, 2025
3aaac95
Fix forward-mode benchmarking
willtebbutt Jun 16, 2025
75f8a76
Fix benchmarking
willtebbutt Jun 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Base:
arrayset, TwicePrecision, twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode,
PiNode, SSAValue, Argument, OpaqueClosure, compilerbarrier
Expand All @@ -34,6 +35,13 @@ using FunctionWrappers: FunctionWrapper
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand Down Expand Up @@ -61,6 +69,7 @@ function rrule!! end

include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -72,12 +81,15 @@ include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "bbcode.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_forward_mode_ad.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
include("test_utils.jl")
include("test_resources.jl")

include(joinpath("frules", "basic.jl"))

include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
include(joinpath("rrules", "blas.jl"))
include(joinpath("rrules", "builtins.jl"))
Expand Down Expand Up @@ -118,9 +130,13 @@ export
_add_to_primal,
_diff,
_dot,
Dual,
zero_dual,
zero_codual,
codual_type,
frule!!,
rrule!!,
build_frule,
build_rrule,
value_and_gradient!!,
value_and_pullback!!,
Expand Down
1 change: 1 addition & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DebugFRule(rule) = rule # TODO: make it non-trivial

Check warning on line 1 in src/debug_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_mode.jl#L1

Added line #L1 was not covered by tests

"""
DebugPullback(pb, y, x)
Expand Down
13 changes: 13 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct Dual{P,T}
primal::P
tangent::T
end

primal(x::Dual) = x.primal
tangent(x::Dual) = x.tangent
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:Dual} = x

Check warning on line 9 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L6-L9

Added lines #L6 - L9 were not covered by tests

zero_dual(x) = Dual(x, zero_tangent(x))

Check warning on line 11 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L11

Added line #L11 was not covered by tests

dual_type(::Type{P}) where {P} = Dual{P,tangent_type(P)}

Check warning on line 13 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L13

Added line #L13 was not covered by tests
11 changes: 11 additions & 0 deletions src/frules/basic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...)

Check warning on line 1 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L1

Added line #L1 was not covered by tests

@is_primitive MinimalCtx Tuple{typeof(sin),Number}
function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number})
return Dual(sin(primal(x)), cos(primal(x)) * tangent(x))

Check warning on line 5 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
end

@is_primitive MinimalCtx Tuple{typeof(cos),Number}
function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number})
return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x))

Check warning on line 10 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L9-L10

Added lines #L9 - L10 were not covered by tests
end
147 changes: 147 additions & 0 deletions src/interpreter/s2s_forward_mode_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
function build_frule(args...; debug_mode=false)
interp = get_interpreter()
sig = _typeof(TestUtils.__get_primals(args))
return build_frule(interp, sig; debug_mode)

Check warning on line 4 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L1-L4

Added lines #L1 - L4 were not covered by tests
end

function build_frule(

Check warning on line 7 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L7

Added line #L7 was not covered by tests
interp::MooncakeInterpreter{C},
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
) where {C}
# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
throw(

Check warning on line 16 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
ArgumentError(
"World age associated to interp is behind current world age. Please " *
"a new interpreter for the current world age.",
),
)
end

# If we're compiling in debug mode, let the user know by default.
if !silence_debug_messages && debug_mode
@info "Compiling rule for $sig_or_mi in debug mode. Disable for best performance."

Check warning on line 26 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L25-L26

Added lines #L25 - L26 were not covered by tests
end

# If we have a hand-coded rule, just use that.
_is_primitive(C, sig_or_mi) && return (debug_mode ? DebugFRule(frule!!) : frule!!)

Check warning on line 30 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L30

Added line #L30 was not covered by tests


# We don't have a hand-coded rule, so derived one.
lock(MOONCAKE_INFERENCE_LOCK)
try

Check warning on line 35 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
# If we've already derived the OpaqueClosures and info, do not re-derive, just
# create a copy and pass in new shared data.
oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode))
if haskey(interp.oc_cache, oc_cache_key)
return _copy(interp.oc_cache[oc_cache_key])

Check warning on line 40 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L38-L40

Added lines #L38 - L40 were not covered by tests
else
# Derive forward-pass IR, and shove in a `MistyClosure`.
forward_ir = generate_forward_ir(interp, sig_or_mi; debug_mode)
fwd_oc = MistyClosure(forward_ir; do_compile=true)
raw_rule = DerivedFRule(fwd_oc)
rule = debug_mode ? DebugFRule(raw_rule) : raw_rule
interp.oc_cache[oc_cache_key] = rule
return rule

Check warning on line 48 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L43-L48

Added lines #L43 - L48 were not covered by tests
end
catch e
rethrow(e)

Check warning on line 51 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L51

Added line #L51 was not covered by tests
finally
unlock(MOONCAKE_INFERENCE_LOCK)

Check warning on line 53 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L53

Added line #L53 was not covered by tests
end
end

function generate_forward_ir(

Check warning on line 57 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L57

Added line #L57 was not covered by tests
interp::MooncakeInterpreter,
sig_or_mi;
debug_mode=false,
do_inline=true,
)
# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
seed_id!()

Check warning on line 65 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L65

Added line #L65 was not covered by tests

# Grab code associated to the primal.
primal_ir, _ = lookup_ir(interp, sig_or_mi)

Check warning on line 68 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L68

Added line #L68 was not covered by tests

# Normalise the IR.
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
ir = normalise!(primal_ir, spnames)

Check warning on line 72 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L71-L72

Added lines #L71 - L72 were not covered by tests

fwd_ir = dualize_ir(ir)
opt_fwd_ir = optimise_ir!(fwd_ir; do_inline)
return opt_fwd_ir

Check warning on line 76 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L74-L76

Added lines #L74 - L76 were not covered by tests
end

function dualize_ir(ir::IRCode)
new_stmts_stmt = map(make_fwd_ad_stmt, ir.stmts.stmt)
new_stmts_type = map(dual_type, ir.stmts.type)
new_stmts_info = ir.stmts.info
new_stmts_line = ir.stmts.line
new_stmts_flag = ir.stmts.flag
new_stmts = CC.InstructionStream(

Check warning on line 85 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L79-L85

Added lines #L79 - L85 were not covered by tests
new_stmts_stmt,
new_stmts_type,
new_stmts_info,
new_stmts_line,
new_stmts_flag,
)
new_cfg = ir.cfg
new_linetable = ir.linetable
rule_type = Any
new_argtypes = convert(Vector{Any}, vcat(rule_type, map(make_fwd_argtype, ir.argtypes)))
new_meta = ir.meta
new_sptypes = ir.sptypes
return IRCode(new_stmts, new_cfg, new_linetable, new_argtypes, new_meta, new_sptypes)

Check warning on line 98 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L92-L98

Added lines #L92 - L98 were not covered by tests
end

make_fwd_argtype(::Type{P}) where {P} = dual_type(P)
make_fwd_argtype(c::Core.Const) = Dual # TODO: refine to type of const

Check warning on line 102 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L101-L102

Added lines #L101 - L102 were not covered by tests

function make_fwd_ad_stmt(stmt::Expr)
interp = get_interpreter() # TODO: pass it around
C = context_type(interp)
if isexpr(stmt, :invoke) || isexpr(stmt, :call)
mi = stmt.args[1]::Core.MethodInstance
sig = mi.specTypes
if is_primitive(C, sig)
shifted_args = map(stmt.args) do a
if a isa Core.Argument
Core.Argument(a.n + 1)

Check warning on line 113 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L104-L113

Added lines #L104 - L113 were not covered by tests
else
a

Check warning on line 115 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L115

Added line #L115 was not covered by tests
end
end
new_stmt = Expr(

Check warning on line 118 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L118

Added line #L118 was not covered by tests
:call,
:($frule!!),
stmt.args[2],
shifted_args[3:end]...
)
return new_stmt

Check warning on line 124 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L124

Added line #L124 was not covered by tests
else
throw(ArgumentError("Recursing into non-primitive calls is not yet supported in forward mode"))

Check warning on line 126 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L126

Added line #L126 was not covered by tests
end
return stmt

Check warning on line 128 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L128

Added line #L128 was not covered by tests
else
throw(ArgumentError("Expressions of type `:$(stmt.head)` are not yet supported in forward mode"))

Check warning on line 130 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L130

Added line #L130 was not covered by tests
end
return stmt

Check warning on line 132 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L132

Added line #L132 was not covered by tests
end

function make_fwd_ad_stmt(stmt::ReturnNode)
return stmt

Check warning on line 136 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
end

struct DerivedFRule{Tfwd_oc}
fwd_oc::Tfwd_oc
end

_copy(rule::DerivedFRule) = deepcopy(rule)

Check warning on line 143 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L143

Added line #L143 was not covered by tests

@inline function (fwd::DerivedFRule)(args::Vararg{Dual,N}) where {N}
return fwd.fwd_oc.oc(args...)

Check warning on line 146 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ function test_rrule_performance(
end
end

__get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs)
__get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs)

@doc"""
test_rule(
Expand Down
27 changes: 27 additions & 0 deletions test/forward.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Mooncake
using Test

x, dx = 2.0, 3.0
xdual = Dual(x, dx)

@testset "Manual frule" begin
sin_rule = build_frule(sin, x)
ydual = sin_rule(zero_dual(sin), xdual)

@test primal(ydual) == sin(x)
@test tangent(ydual) == dx * cos(x)
end

function func(x)
y = sin(x)
z = cos(y)
return z
end

@testset "Automatic frule" begin
func_rule = build_frule(func, x)
ydual = func_rule(zero_dual(func), xdual)

@test primal(ydual) == cos(sin(x))
@test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include("front_matter.jl")
include("config.jl")
include("developer_tools.jl")
include("test_utils.jl")
include("forward.jl")
elseif test_group == "rrules/avoiding_non_differentiable_code"
include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
elseif test_group == "rrules/blas"
Expand Down
Loading