Skip to content

Start forward mode AD #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 131 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
131 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
8d120b2
Case 7 solved
gdalle Jan 27, 2025
cd7167f
Resolve merge conflict
willtebbutt Jan 27, 2025
c5ffae7
Fix precompile issue
willtebbutt Jan 27, 2025
94aa904
Fix isa rule
willtebbutt Jan 27, 2025
cc7a3fa
Fix is_primitive
willtebbutt Jan 27, 2025
70d7183
More test cases
gdalle Feb 4, 2025
aec412e
progress
gdalle Feb 6, 2025
0ea1084
fixes
gdalle Feb 7, 2025
d8a949f
Bump patch vesion
willtebbutt Feb 12, 2025
79844d2
Fix terminators
willtebbutt Feb 12, 2025
49aa4ca
Merge remote-tracking branch 'upstream/wct/fix-terminator-issue' into…
gdalle Feb 12, 2025
9ce99ec
More cases
gdalle Feb 12, 2025
6ce2488
More cases
gdalle Feb 12, 2025
8954361
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
941a2de
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
af49eac
Tuple rule
gdalle Feb 14, 2025
0b4e5fa
Merge in main
willtebbutt Mar 14, 2025
8204665
Formatting
willtebbutt Mar 14, 2025
70fec10
Code to view forwards-mode IR from a signature
willtebbutt Mar 14, 2025
6cde147
Use widenconst to get actual argtype from ircode argtypes
willtebbutt Mar 14, 2025
0eabff0
MyInstruction -> new_instruction
willtebbutt Mar 14, 2025
8b391c6
Formatting
willtebbutt Mar 15, 2025
5d6b826
Merge branch 'main' into gd/forward
willtebbutt Mar 17, 2025
a919a28
Various improvements
willtebbutt Mar 18, 2025
2808a12
Rules for foreigncalls
willtebbutt Mar 19, 2025
cb28759
Fix pointer tests with forwards mode
willtebbutt Mar 19, 2025
f9d1697
Enable more tests
willtebbutt Mar 19, 2025
9bc53cc
All derivation tests pass
willtebbutt Mar 19, 2025
d6fc35d
Initial pass over legacy array functionality
willtebbutt Mar 20, 2025
6b2409c
Fix tangent usage in tests
willtebbutt Mar 20, 2025
d6974c1
Rules for nice BLAS functions
willtebbutt Mar 20, 2025
fbcc6ce
Tweak test inputs slightly
willtebbutt Mar 20, 2025
732762b
Enable CI for BLAS and foreigncalls
willtebbutt Mar 20, 2025
fd48f02
Enable linear_algebra rules
willtebbutt Mar 20, 2025
f6bc752
More stuff works
willtebbutt Mar 21, 2025
a96611e
Make IdDict work
willtebbutt Mar 21, 2025
44e78b4
Code to identify SSA uses
willtebbutt Mar 21, 2025
a504413
Fix failing test via special case
willtebbutt Mar 21, 2025
f68c79b
Remove outdated TODO note
willtebbutt Mar 21, 2025
05cbb83
Merge branch 'main' into gd/forward
willtebbutt Mar 21, 2025
30c5294
Fix typo
willtebbutt Mar 21, 2025
fe4ec4a
BLAS support nearly finished
willtebbutt Mar 23, 2025
f771f70
All BLAS rules passing
willtebbutt Mar 24, 2025
86fa1b6
Initial work on getrf
willtebbutt Mar 25, 2025
04ea669
Merge branch 'main' into gd/forward
willtebbutt Mar 25, 2025
e1a1260
getrf frule sketch
willtebbutt Mar 25, 2025
fda2ab9
Merge branch 'gd/forward' of https://github.com/gdalle/Mooncake.jl in…
willtebbutt Mar 25, 2025
37baaf0
Improve getrf performance
willtebbutt Mar 26, 2025
c0c4167
trtrs implementation + type stability checks
willtebbutt Mar 26, 2025
9a12b23
Type stability checks for BLAS rules
willtebbutt Mar 26, 2025
bb8feba
Note Seth's blog
willtebbutt Mar 26, 2025
64d6176
getrs frule implementation
willtebbutt Mar 27, 2025
be57d7f
getri frule implementation
willtebbutt Mar 27, 2025
2934409
potrs
willtebbutt Mar 27, 2025
fe289a0
Enable lapack CI
willtebbutt Mar 27, 2025
39354bc
Fix pivoting
willtebbutt Mar 28, 2025
8bcde33
Enable diff tests integration tests
willtebbutt Mar 28, 2025
497c907
Only run extra CI on 1
willtebbutt Mar 28, 2025
e1dce38
More lapack fixes
willtebbutt Mar 28, 2025
8739c6c
widenconst
willtebbutt Mar 28, 2025
899d4c4
Replace field access with method call
willtebbutt Mar 28, 2025
594ba13
Catch __vec_to_tuple edge case
willtebbutt Mar 28, 2025
0510235
Display more stuff when correctness test fails
willtebbutt Mar 28, 2025
4af2276
Enable more integration tests
willtebbutt Mar 28, 2025
83cd097
Make output on test error sensible
willtebbutt Mar 28, 2025
da3d7ee
Tidy up blas implementations
willtebbutt Mar 28, 2025
eee18dd
Fix pointerset error
willtebbutt Mar 28, 2025
9bd274d
Merge branch 'main' into gd/forward
willtebbutt Mar 28, 2025
3a4f70a
Fix ^ rule
willtebbutt Mar 28, 2025
5aed9b2
Implement from_chain_rule macro
willtebbutt Mar 29, 2025
f4f62c9
Get SpecialFunctions extension working
willtebbutt Mar 29, 2025
9c11e6a
Enable SpecialFunctions in CI
willtebbutt Mar 29, 2025
e19cb63
logexpfunctions
willtebbutt Mar 29, 2025
be93bfd
Run gpu jobs on 1.11 only
willtebbutt Mar 29, 2025
1d1e7e9
Restrict FD step for forward mode
willtebbutt Mar 29, 2025
f21b575
Enable GP tests
willtebbutt Mar 29, 2025
c691a92
More integration testing
willtebbutt Mar 29, 2025
b28961a
bijectors
willtebbutt Mar 29, 2025
b67f2c3
Enable battery of tests
willtebbutt Mar 29, 2025
2bdb0ad
Distributions integration testing
willtebbutt Mar 29, 2025
d4fa5c8
Enable DI CI
willtebbutt Mar 29, 2025
4902ce2
Enable reverse-mode integration tests for Lux etc
willtebbutt Mar 29, 2025
7f57a06
Enable 1.10
willtebbutt Mar 31, 2025
60e4d89
Fix LAPACK on 1.10
willtebbutt Mar 31, 2025
41bb3c3
Implement copytrito for 1.10
willtebbutt Mar 31, 2025
2140edd
formatting
willtebbutt Mar 31, 2025
05bac94
Merge branch 'main' into gd/forward
willtebbutt Mar 31, 2025
df0cf38
Tidying up
willtebbutt Mar 31, 2025
dac008f
Remove type piracy
willtebbutt Mar 31, 2025
48b61ec
Initial forwards-mode timings
willtebbutt Mar 31, 2025
3d9f9bf
Merge in main
willtebbutt May 11, 2025
05d3c65
Constrain JuliaInterpreter
willtebbutt May 26, 2025
df0d2d7
Basic MistyClosure support
willtebbutt May 26, 2025
ed912eb
Merge in main
willtebbutt May 26, 2025
b9c5f7e
Do not use MistyClosure internals inside reverse-mode
willtebbutt Jun 4, 2025
6990348
Forwards-over-reverse mwe
willtebbutt Jun 4, 2025
941e171
Remove overly strict performance check
willtebbutt Jun 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ steps:
setup:
version:
- "1"
- "1.10"
# - "1.10"
label:
- "cuda"
- "nnlib"
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
'rrules/builtins',
'rrules/fastmath',
'rrules/foreigncall',
'rrules/functionwrappers',
# 'rrules/functionwrappers',
'rrules/iddict',
'rrules/lapack',
'rrules/linear_algebra',
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ profile.pb.gz
scratch.jl
docs/build/
docs/site/
playground.jl
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ FunctionWrappers = "1.1.3"
GPUArraysCore = "0.1, 0.2"
Graphs = "1"
InteractiveUtils = "1"
JET = "0.9, 0.10"
JET = "0.9"
JuliaFormatter = "1.0, 2.1"
JuliaInterpreter = "0.9"
LinearAlgebra = "1"
LuxLib = "1"
MistyClosures = "2"
Expand All @@ -75,9 +76,10 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "Pkg", "StableRNGs", "Test"]
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "JuliaInterpreter", "Pkg", "StableRNGs", "Test"]
23 changes: 22 additions & 1 deletion bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@ using AbstractGPs,
Zygote

using Mooncake:
Dual,
CoDual,
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases,
TestUtils,
_typeof,
primal,
tangent,
zero_dual,
zero_codual

using Mooncake.TestUtils: _deepcopy

to_benchmark(__frule!!::R, dx::Vararg{Dual,N}) where {R,N} = __frule!!(dx...)

function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N}
dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
out, pb!! = __rrule!!(dx_f...)
Expand Down Expand Up @@ -207,6 +211,20 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
evals=1,
)

# Benchmark AD via Mooncake.
@info "Mooncake (Forward)"
rule = Mooncake.build_frule(args...)
duals = map(x -> x isa Dual ? x : zero_dual(x), args)
to_benchmark(rule, duals...)
include_other_frameworks && GC.gc(true)
suite["mooncake_fwd"] = Chairmarks.benchmark(
() -> (rule, duals),
identity,
a -> to_benchmark(a[1], a[2]...),
_ -> true;
evals=1,
)

if include_other_frameworks
if should_run_benchmark(Val(:zygote), args...)
@info "Zygote"
Expand Down Expand Up @@ -259,6 +277,7 @@ function combine_results(result, tag, _range, default_range)
d = result[2]
primal_time = minimum(d["primal"]).time
mooncake_time = minimum(d["mooncake"]).time
mooncake_fwd_time = minimum(d["mooncake_fwd"]).time
zygote_time = in("zygote", keys(d)) ? minimum(d["zygote"]).time : missing
rd_time = in("rd", keys(d)) ? minimum(d["rd"]).time : missing
ez_time = in("enzyme", keys(d)) ? minimum(d["enzyme"]).time : missing
Expand All @@ -268,6 +287,8 @@ function combine_results(result, tag, _range, default_range)
primal_time=primal_time,
mooncake_time=mooncake_time,
Mooncake=mooncake_time / primal_time,
mooncake_fwd_time=mooncake_fwd_time,
MooncakeFwd=mooncake_fwd_time / primal_time,
zygote_time=zygote_time,
Zygote=zygote_time / primal_time,
rd_time=rd_time,
Expand Down Expand Up @@ -349,7 +370,7 @@ end

function create_inter_ad_benchmarks()
results = benchmark_inter_framework_rules()
tools = [:Mooncake, :Zygote, :ReverseDiff, :Enzyme]
tools = [:Mooncake, :MooncakeFwd, :Zygote, :ReverseDiff, :Enzyme]
df = DataFrame(results)[:, [:tag, :primal_time, tools...]]

# Plot graph of results.
Expand Down
74 changes: 37 additions & 37 deletions ext/MooncakeSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,44 @@ module MooncakeSpecialFunctionsExt
using SpecialFunctions, Mooncake
using Base: IEEEFloat

import Mooncake: @from_rrule, DefaultCtx, @zero_adjoint
import Mooncake: DefaultCtx, @from_chain_rule, @zero_derivative

@from_rrule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airyaix),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airybi),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airybiprime),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(besselj0),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(besselj1),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(bessely0),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(bessely1),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(dawson),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(digamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfc),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logerfc),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfcinv),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfcx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logerfcx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfi),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfinv),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(gamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(invdigamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(trigamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(loggamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expintx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expinti),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(sinint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A relatively minor comment: from_chainrules is more precise than from_chain_rule. The former clarifies that we are importing a rule from ChainRules, while the latter mislead me since I thought it refers to the generic chain rule terminology.

Suggested change
@from_chain_rule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airyai),IEEEFloat}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, interesting. I have no strong view either way, so I'm happy to change it if you think it's from_chainerules is clearer.

@from_chain_rule DefaultCtx Tuple{typeof(airyaix),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(airybi),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(airybiprime),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(besselj0),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(besselj1),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(bessely0),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(bessely1),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(dawson),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(digamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erf),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erfc),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(logerfc),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erfcinv),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erfcx),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(logerfcx),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erfi),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(erfinv),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(gamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(invdigamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(trigamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(loggamma),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(expint),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(expintx),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(expinti),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(sinint),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(cosint),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(ellipk),IEEEFloat}
@from_chain_rule DefaultCtx Tuple{typeof(ellipe),IEEEFloat}

@zero_adjoint DefaultCtx Tuple{typeof(logfactorial),Integer}
@zero_derivative DefaultCtx Tuple{typeof(logfactorial),Integer}

end
11 changes: 11 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Base:
twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics,
bitcast,
Expand All @@ -42,6 +43,13 @@ using FunctionWrappers: FunctionWrapper
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand Down Expand Up @@ -92,6 +100,7 @@ build_primitive_rrule(::Type{<:Tuple}) = rrule!!

include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -106,6 +115,7 @@ include(joinpath("interpreter", "patch_for_319.jl"))
include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_forward_mode_ad.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
Expand All @@ -123,6 +133,7 @@ include(joinpath("rrules", "lapack.jl"))
include(joinpath("rrules", "linear_algebra.jl"))
include(joinpath("rrules", "low_level_maths.jl"))
include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "misty_closures.jl"))
include(joinpath("rrules", "new.jl"))
include(joinpath("rrules", "random.jl"))
include(joinpath("rrules", "tasks.jl"))
Expand Down
7 changes: 7 additions & 0 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:CoDual} = x

"""
extract(x::CoDual)

Helper function. Returns the 2-tuple `x.x, x.dx`.
"""
extract(x::CoDual) = primal(x), tangent(x)

Check warning on line 25 in src/codual.jl

View check run for this annotation

Codecov / codecov/patch

src/codual.jl#L25

Added line #L25 was not covered by tests

"""
zero_codual(x)

Expand Down
1 change: 1 addition & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DebugFRule(rule) = rule # TODO: make it non-trivial

Check warning on line 1 in src/debug_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_mode.jl#L1

Added line #L1 was not covered by tests

"""
DebugPullback(pb, y, x)
Expand Down
12 changes: 12 additions & 0 deletions src/developer_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,15 @@
)::IRCode
return generate_ir(interp, sig; debug_mode, do_inline).rvs_ir
end

"""

"""
function dual_ir(

Check warning on line 111 in src/developer_tools.jl

View check run for this annotation

Codecov / codecov/patch

src/developer_tools.jl#L111

Added line #L111 was not covered by tests
sig::Type{<:Tuple};
interp=get_interpreter(),
debug_mode::Bool=false,
do_inline::Bool=true,
)
return generate_dual_ir(interp, sig; debug_mode, do_inline)

Check warning on line 117 in src/developer_tools.jl

View check run for this annotation

Codecov / codecov/patch

src/developer_tools.jl#L117

Added line #L117 was not covered by tests
end
47 changes: 47 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
struct Dual{P,T}
primal::P
tangent::T
end

primal(x::Dual) = x.primal
tangent(x::Dual) = x.tangent
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:Dual} = x

Check warning on line 9 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L8-L9

Added lines #L8 - L9 were not covered by tests

"""
extract(x::CoDual)

Helper function. Returns the 2-tuple `x.x, x.dx`.
"""
extract(x::Dual) = primal(x), tangent(x)

zero_dual(x) = Dual(x, zero_tangent(x))
randn_dual(rng::AbstractRNG, x) = Dual(x, randn_tangent(rng, x))

function dual_type(::Type{P}) where {P}
P == DataType && return Dual
P isa Union && return Union{dual_type(P.a),dual_type(P.b)}
P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual
end

function dual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent}
end

_primal(x) = x

Check warning on line 32 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L32

Added line #L32 was not covered by tests
_primal(x::Dual) = primal(x)

"""
verify_dual_type(x::Dual)

Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`.
"""
verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x))

@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x))

# Always sharpen the first thing if it's a type so static dispatch remains possible.
function Dual(x::Type{P}, dx::NoTangent) where {P}
return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx)
end
2 changes: 2 additions & 0 deletions src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
if stmt isa Expr && stmt.head == :boundscheck && length(stmt.args) == 1
def = SSAValue(n)
val = only(stmt.args)
# TODO: this could just be `statements[n] = val` (Valentin C says)
for (m, stmt) in enumerate(statements)
statements[m] = replace_uses_with!(stmt, def, val)
end
Expand Down Expand Up @@ -349,6 +350,7 @@

@is_primitive MinimalCtx Tuple{typeof(gc_preserve),Vararg{Any,N}} where {N}

frule!!(::Dual{typeof(gc_preserve)}, ::Vararg{Dual,N}) where {N} = zero_dual(nothing)

Check warning on line 353 in src/interpreter/ir_normalisation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/ir_normalisation.jl#L353

Added line #L353 was not covered by tests
function rrule!!(f::CoDual{typeof(gc_preserve)}, xs::CoDual...)
pb = NoPullback(f, xs...)
gc_preserve_pb!!(::NoRData) = GC.@preserve xs pb(NoRData())
Expand Down
Loading
Loading