Fix for Robust Estimation Memory Issue #548
Merged
Addresses a memory issue stemming from `vmap` over `torch.func.jvp` in `MonteCarloInfluenceEstimator`. Instead, this uses reverse-mode autodiff for the Jacobian of the functional (largely because the parameter dimensionality will typically far exceed the dimensionality of the functional) and then manually right-multiplies it by `param_eif` (the Fisher matrix X data log probability). The right multiplication is agnostic to both pytree structure and tensor shapes (emulating `torch.func.jvp`, and in fact slightly more general).

Memory use is orders of magnitude lower, to the point of not being noticeable.
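For illustration, the approach described above can be sketched roughly as follows. This is not the code from this PR: `functional`, `right_multiply_jacobian`, and the parameter names `w` and `b` are hypothetical, and the sketch assumes the parameters are a flat dict of tensors and that every leaf of `param_eif` carries a leading batch dimension.

```python
import math

import torch
from torch.func import jacrev


def functional(params):
    # Hypothetical 2-dimensional functional of a parameter dict (assumption).
    return torch.stack(
        [(params["w"] ** 2).sum(), params["b"].sum() * params["w"].mean()]
    )


def right_multiply_jacobian(fn, params, param_eif):
    # Reverse-mode Jacobian, computed once. For a dict input, jacrev returns
    # a dict whose leaves have shape output_shape + parameter_shape.
    jac = jacrev(fn)(params)
    total = None
    for name, p in params.items():
        j = jac[name]
        out_shape = tuple(j.shape[: j.dim() - p.dim()])
        # Flatten to (n_out, n_param) and (batch, n_param), then contract
        # over the parameter axis, independent of the original tensor shapes.
        j_mat = j.reshape(math.prod(out_shape), p.numel())
        v_mat = param_eif[name].reshape(-1, p.numel())
        contrib = (v_mat @ j_mat.T).reshape(-1, *out_shape)
        # Sum the contributions of all parameter leaves.
        total = contrib if total is None else total + contrib
    return total


torch.manual_seed(0)
params = {"w": torch.randn(3, 2), "b": torch.randn(2)}
# Leading dimension of size 4 is the batch of parameter-space directions.
param_eif = {"w": torch.randn(4, 3, 2), "b": torch.randn(4, 2)}
result = right_multiply_jacobian(functional, params, param_eif)
```

Each row of `result` should match what `torch.func.jvp` returns for the corresponding row of `param_eif`, but the Jacobian is materialized only once, and its size is the output dimension times the parameter count rather than growing with the batch.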
One possible difference (and possible cause of the original problem): the `vmap` over `jvp` was potentially estimating and computing the Jacobian separately for each batch in `param_eif`. This is highly redundant, but it also meant each batch saw different randomness in the Jacobian estimate, thereby propagating some notion of variability in the Jacobian estimate to the user. This implementation estimates and computes the Jacobian only once for all batches in `param_eif`. This may or may not be desirable, but it's important to note that doing so separately for each batch comes at a very high computational cost.

Adds tests for the alternative jvp implementation, including a test of memory consumption.