-
Notifications
You must be signed in to change notification settings - Fork 14
Start forward mode AD #389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
gdalle
wants to merge
131
commits into
chalk-lab:main
Choose a base branch
from
gdalle:gd/forward
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
131 commits
Select commit
Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle deac913
First working autodiff
gdalle 9c96c8d
Docstring
gdalle 136aff6
Apply suggestions from code review
gdalle f65cc53
Moving files around
gdalle 053a8bb
Primitives already known
gdalle 6d8ec04
Merge branch 'main' into gd/forward
gdalle a3107a8
Keep pushing forward (pun intended)
gdalle 2836ac8
Still buggy, don't touch
gdalle 09d63bd
Keep instruction mapping one to one
gdalle fa679eb
Use replace_call
gdalle a68257c
Ignore code cov
gdalle 7a096ba
No Aqua piracies test
gdalle 46c3e5a
Start control flow
gdalle ad3f98a
Fix intrinsic
gdalle 9071574
Import
gdalle dcfe282
Typos
gdalle e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle dd89e57
Figure out incremental additions
gdalle 9bdb57f
Initial test case additions
willtebbutt 4bb9911
Formatting
willtebbutt 9b037e7
Add verify_dual_type
willtebbutt 6dea624
test_frule_interface runs
willtebbutt a614846
Fix ReturnNode
willtebbutt eadae95
Correctness testing runs
willtebbutt 345b3fd
Add randn_dual
willtebbutt f58c394
Improve sin and cos frules
willtebbutt c8d8895
Performance tests run
willtebbutt 578e41b
Tidy up implementation
willtebbutt b5d34b2
Standard testing infrastructure
willtebbutt 205e716
Fix typos
willtebbutt d328db0
Fix return node to return dual
gdalle 66a48c8
Handle PiNode
gdalle e455cf6
Deleted line
gdalle 8d120b2
Case 7 solved
gdalle cd7167f
Resolve merge conflict
willtebbutt c5ffae7
Fix precompile issue
willtebbutt 94aa904
Fix isa rule
willtebbutt cc7a3fa
Fix is_primitive
willtebbutt 70d7183
More test cases
gdalle aec412e
progress
gdalle 0ea1084
fixes
gdalle d8a949f
Bump patch vesion
willtebbutt 79844d2
Fix terminators
willtebbutt 49aa4ca
Merge remote-tracking branch 'upstream/wct/fix-terminator-issue' into…
gdalle 9ce99ec
More cases
gdalle 6ce2488
More cases
gdalle 8954361
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle 941a2de
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle af49eac
Tuple rule
gdalle 0b4e5fa
Merge in main
willtebbutt 8204665
Formatting
willtebbutt 70fec10
Code to view forwards-mode IR from a signature
willtebbutt 6cde147
Use widenconst to get actual argtype from ircode argtypes
willtebbutt 0eabff0
MyInstruction -> new_instruction
willtebbutt 8b391c6
Formatting
willtebbutt 5d6b826
Merge branch 'main' into gd/forward
willtebbutt a919a28
Various improvements
willtebbutt 2808a12
Rules for foreigncalls
willtebbutt cb28759
Fix pointer tests with forwards mode
willtebbutt f9d1697
Enable more tests
willtebbutt 9bc53cc
All derivation tests pass
willtebbutt d6fc35d
Initial pass over legacy array functionality
willtebbutt 6b2409c
Fix tangent usage in tests
willtebbutt d6974c1
Rules for nice BLAS functions
willtebbutt fbcc6ce
Tweak test inputs slightly
willtebbutt 732762b
Enable CI for BLAS and foreigncalls
willtebbutt fd48f02
Enable linear_algebra rules
willtebbutt f6bc752
More stuff works
willtebbutt a96611e
Make IdDict work
willtebbutt 44e78b4
Code to identify SSA uses
willtebbutt a504413
Fix failing test via special case
willtebbutt f68c79b
Remove outdated TODO note
willtebbutt 05cbb83
Merge branch 'main' into gd/forward
willtebbutt 30c5294
Fix typo
willtebbutt fe4ec4a
BLAS support nearly finished
willtebbutt f771f70
All BLAS rules passing
willtebbutt 86fa1b6
Initial work on getrf
willtebbutt 04ea669
Merge branch 'main' into gd/forward
willtebbutt e1a1260
getrf frule sketch
willtebbutt fda2ab9
Merge branch 'gd/forward' of https://github.com/gdalle/Mooncake.jl in…
willtebbutt 37baaf0
Improve getrf performance
willtebbutt c0c4167
trtrs implementation + type stability checks
willtebbutt 9a12b23
Type stability checks for BLAS rules
willtebbutt bb8feba
Note Seth's blog
willtebbutt 64d6176
getrs frule implementation
willtebbutt be57d7f
getri frule implementation
willtebbutt 2934409
potrs
willtebbutt fe289a0
Enable lapack CI
willtebbutt 39354bc
Fix pivoting
willtebbutt 8bcde33
Enable diff tests integration tests
willtebbutt 497c907
Only run extra CI on 1
willtebbutt e1dce38
More lapack fixes
willtebbutt 8739c6c
widenconst
willtebbutt 899d4c4
Replace field access with method call
willtebbutt 594ba13
Catch __vec_to_tuple edge case
willtebbutt 0510235
Display more stuff when correctness test fails
willtebbutt 4af2276
Enable more integration tests
willtebbutt 83cd097
Make output on test error sensible
willtebbutt da3d7ee
Tidy up blas implementations
willtebbutt eee18dd
Fix pointerset error
willtebbutt 9bd274d
Merge branch 'main' into gd/forward
willtebbutt 3a4f70a
Fix ^ rule
willtebbutt 5aed9b2
Implement from_chain_rule macro
willtebbutt f4f62c9
Get SpecialFunctions extension working
willtebbutt 9c11e6a
Enable SpecialFunctions in CI
willtebbutt e19cb63
logexpfunctions
willtebbutt be93bfd
Run gpu jobs on 1.11 only
willtebbutt 1d1e7e9
Restrict FD step for forward mode
willtebbutt f21b575
Enable GP tests
willtebbutt c691a92
More integration testing
willtebbutt b28961a
bijectors
willtebbutt b67f2c3
Enable battery of tests
willtebbutt 2bdb0ad
Distributions integration testing
willtebbutt d4fa5c8
Enable DI CI
willtebbutt 4902ce2
Enable reverse-mode integration tests for Lux etc
willtebbutt 7f57a06
Enable 1.10
willtebbutt 60e4d89
Fix LAPACK on 1.10
willtebbutt 41bb3c3
Implement copytrito for 1.10
willtebbutt 2140edd
formatting
willtebbutt 05bac94
Merge branch 'main' into gd/forward
willtebbutt df0cf38
Tidying up
willtebbutt dac008f
Remove type piracy
willtebbutt 48b61ec
Initial forwards-mode timings
willtebbutt 3d9f9bf
Merge in main
willtebbutt 05d3c65
Constrain JuliaInterpreter
willtebbutt df0d2d7
Basic MistyClosure support
willtebbutt ed912eb
Merge in main
willtebbutt b9c5f7e
Do not use MistyClosure internals inside reverse-mode
willtebbutt 6990348
Forwards-over-reverse mwe
willtebbutt 941e171
Remove overly strict performance check
willtebbutt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ steps: | |
setup: | ||
version: | ||
- "1" | ||
- "1.10" | ||
# - "1.10" | ||
label: | ||
- "cuda" | ||
- "nnlib" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ profile.pb.gz | |
scratch.jl | ||
docs/build/ | ||
docs/site/ | ||
playground.jl |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
DebugFRule(rule) = rule # TODO: make it non-trivial | ||
|
||
""" | ||
DebugPullback(pb, y, x) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
struct Dual{P,T} | ||
primal::P | ||
tangent::T | ||
end | ||
|
||
primal(x::Dual) = x.primal | ||
tangent(x::Dual) = x.tangent | ||
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) | ||
_copy(x::P) where {P<:Dual} = x | ||
|
||
""" | ||
extract(x::CoDual) | ||
|
||
Helper function. Returns the 2-tuple `x.x, x.dx`. | ||
""" | ||
extract(x::Dual) = primal(x), tangent(x) | ||
|
||
zero_dual(x) = Dual(x, zero_tangent(x)) | ||
randn_dual(rng::AbstractRNG, x) = Dual(x, randn_tangent(rng, x)) | ||
|
||
function dual_type(::Type{P}) where {P} | ||
P == DataType && return Dual | ||
P isa Union && return Union{dual_type(P.a),dual_type(P.b)} | ||
P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type. | ||
return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual | ||
end | ||
|
||
function dual_type(p::Type{Type{P}}) where {P} | ||
return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent} | ||
end | ||
|
||
_primal(x) = x | ||
_primal(x::Dual) = primal(x) | ||
|
||
""" | ||
verify_dual_type(x::Dual) | ||
|
||
Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`. | ||
""" | ||
verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)) | ||
|
||
@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x)) | ||
|
||
# Always sharpen the first thing if it's a type so static dispatch remains possible. | ||
function Dual(x::Type{P}, dx::NoTangent) where {P} | ||
return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A relatively minor comment:
from_chainrules
is more precise thanfrom_chain_rule
. The former clarifies that we are importing a rule fromChainRules
, while the latter mislead me since I thought it refers to the generic chain rule terminology.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, interesting. I have no strong view either way, so I'm happy to change it if you think it's
from_chainerules
is clearer.