Basic rewrite of the package 2023 edition Part I: ADVI #49
Red-Portal merged 213 commits into TuringLang:master
This is to avoid having to reconstruct transformed distributions all the time. Using bijectors directly also avoids going through many abstraction layers that could break. Instead, transformed distributions can be constructed only once, when returning the VI result.
…into rewriting_advancedvi
This reverts commit 2a4514e.
- Full Monte Carlo ELBO estimation now works. I checked.
@torfjelde @Red-Portal This PR already looks good; let's try to get this PR merged in the next two weeks.
@torfjelde @yebai As discussed, I've removed the parts directly interacting with Bijectors. I think it's ready for another review pass. (Hopeful that we're close to being done for this PR!)
@torfjelde a reminder on this.
Having a proper look now :)
torfjelde left a comment
I've added a few more comments, most of which should be quick accepts.
There's one point, the argument ordering of reparam_with_entropy, that is potentially "discussable", but it should be a simple change if you accept it.
So. I think the time has come 👀
I am ready to accept the PR 🙏
But this is really great work @Red-Portal 👏 And thank you so much for your persistence and just continuously chipping away ❤️ It must have been a pain, but I do think the PR is in a much better state now than when we started, so it's been worth it :)
using ..ReverseDiff
end

# ReverseDiff without compiled tape
Can we not handle compiled tape?
Co-authored-by: Tor Erlend Fjelde <[email protected]>
…/AdvancedVI.jl into rewriting_advancedvi_optimize
Hi, this is a partial pull request for #45.
The content of Part I is as follows:
Roadmap
- LogDensityProblems interface.
- ADTypes interface.
- Functor.jl for flattening/unflattening variational parameters.
- optimize. (see Missing API method #32)
- Optimisers.jl.
- callback option (Callback function during training #5)

Notes
- LocationScale variational family (Part II) and the documentation (Part III).
- DistributionsAD.TuringDiagMvNormal variational family for running the tests.
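
For readers unfamiliar with the first roadmap item, here is a rough sketch of what a target implementing the LogDensityProblems interface might look like. The `StdNormalTarget` type is hypothetical and not part of this PR; only the three `LogDensityProblems` methods are the package's actual interface. How this PR's objective consumes such a target is not shown here.

```julia
using LogDensityProblems

# Hypothetical target (an assumption for illustration): a standard normal in `dim` dimensions.
struct StdNormalTarget
    dim::Int
end

# Log density of N(0, I) evaluated at x.
function LogDensityProblems.logdensity(p::StdNormalTarget, x)
    return -sum(abs2, x) / 2 - p.dim * log(2π) / 2
end

# Dimension of the parameter space.
LogDensityProblems.dimension(p::StdNormalTarget) = p.dim

# Declare that only zeroth-order information (the log density itself) is provided;
# gradients would then have to come from an AD backend.
function LogDensityProblems.capabilities(::Type{StdNormalTarget})
    return LogDensityProblems.LogDensityOrder{0}()
end

prob = StdNormalTarget(2)
println(LogDensityProblems.logdensity(prob, zeros(2)))
```

Any object implementing these three methods can be treated as a log-density target regardless of which modeling package produced it, which is presumably the reason for adopting the interface.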