Merged
…asic reverse diff and manual mat vec product.
eb8680 (Contributor) approved these changes on Jul 6, 2024 and left a comment:
Nice sleuthing! Great to see such big performance gains.
Addresses a memory issue stemming from `vmap` over `torch.func.jvp` in `MonteCarloInfluenceEstimator`. Instead, uses reverse-mode autodiff for the Jacobian of the functional (largely because the parameter dimensionality will typically far exceed the dimensionality of the functional) and then manually right-multiplies `param_eif` (the Fisher matrix × data log probability). The right multiplication is performed agnostically with respect to both pytree structures and tensor shapes (emulating `torch.func.jvp`, with slightly more generality, in fact). Memory use is orders of magnitude lower, to the point of not being noticeable.
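The strategy above (reverse-mode Jacobian once, then a pytree- and shape-agnostic right multiplication against a batch of tangents) can be sketched roughly as follows. The functional, parameter names, and batch here are hypothetical stand-ins, not the actual `MonteCarloInfluenceEstimator` code:

```python
import torch
from torch.func import jacrev

# Hypothetical scalar functional over a dict-structured parameter pytree.
def functional(params):
    return (params["w"] ** 2).sum() + params["b"].sum()

params = {"w": torch.randn(5), "b": torch.randn(3)}

# Reverse-mode Jacobian of the functional: cheap when the output
# dimensionality is far smaller than the parameter dimensionality.
jac = jacrev(functional)(params)  # pytree with the same structure as `params`

# Hypothetical batch of tangents (standing in for rows of `param_eif`),
# sharing the pytree structure of `params` with a leading batch dim.
batch = {"w": torch.randn(4, 5), "b": torch.randn(4, 3)}

def right_multiply(jac_tree, tangent_tree):
    """Contract each Jacobian leaf against the trailing (parameter) dims of
    the matching tangent leaf and sum across leaves, emulating a batched
    torch.func.jvp for a scalar-valued functional."""
    out = 0.0
    for key in jac_tree:
        j = jac_tree[key]          # shape: param_shape
        t = tangent_tree[key]      # shape: (batch, *param_shape)
        dims = list(range(1, t.ndim))
        out = out + (j * t).sum(dim=dims)
    return out                     # shape: (batch,)

jvp_batched = right_multiply(jac, batch)
```

Because the Jacobian is materialized once and the contraction is a plain elementwise-multiply-and-sum, no forward-mode tape or `vmap` state is held per batch element.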
One possible difference (and a possible cause of the original problem): the `vmap` over `jvp` was potentially estimating and computing the Jacobian separately for each batch in `param_eif`. This is very redundant, but it also meant each batch saw different randomness in the Jacobian estimate, thereby propagating some notion of variability in the Jacobian estimate to the user. This implementation estimates/computes the Jacobian only once for all batches in `param_eif`. This may or may not be desirable, but it's important to note that doing so separately for each batch comes at very high computational cost.

Adds tests for the alternative jvp implementation, including a test of memory consumption.