[RFC] Refactor Input Transforms #1176
Conversation
Summary: Currently, we apply the input transforms in `train` mode at the `forward` call, and in `eval` mode at the `posterior` call. We also use a `transform_train_inputs` call at the `eval`/`train` calls to make sure that at `eval` time the `train_inputs` are stored as transformed (since they don't pass through `posterior`). This design supports `ExactGP` models, and supports specifying where to apply which input transform via the flags (so that one-to-many transforms are only applied to test inputs). However, it does not work well with Approximate GP models, since this setup does not transform the inducing points at `eval` time.

This refactor splits out one-to-many transforms as `InputAugmentationTransform`, allowing us to revert to simply applying `transform_inputs` in the `forward` pass (at all times). We still need to apply one-to-many transforms (now called `InputAugmentationTransform`) in `posterior`, so we introduce an `augment_inputs` method.

(Inspired by the public-private APIs of Ax.) To minimize the transform-related knowledge expected from developers, this introduces a `Model.forward` call that applies `transform_inputs` and calls `self._forward`. `<AnyGivenModel>._forward` is the usual `forward` call that computes the prior, except that it no longer has to worry about transforms.

Similarly, for the `posterior`, this makes `Model.posterior` into a simple wrapper around `Model._posterior`: the wrapper applies the `augment_inputs` call and the `posterior_transform`. Again, `<AnyGivenModel>._posterior` becomes the usual posterior call that no longer has to worry about the input or posterior transforms (it still has to deal with the outcome transform in the current implementation, though we can fix this by bringing back the `fantasize` flag).

This diff presents a minimal implementation around the `SingleTaskGP` model.

Differential Revision: D35129407
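For concreteness, here is a minimal sketch of the proposed public/private split, assuming the method names from the summary (`transform_inputs`, `augment_inputs`, `_forward`, `_posterior`); the surrounding scaffolding and helper attributes are illustrative assumptions, not the actual code in this diff:

```python
# Illustrative sketch only. Method names follow the summary above; the
# scaffolding and helper attributes are assumptions, not this diff's code.
from abc import ABC, abstractmethod


class Model(ABC):
    def forward(self, X):
        # One-to-one transforms are applied here, at all times, so train
        # inputs, test inputs, and inducing points all pass through them.
        return self._forward(self.transform_inputs(X))

    @abstractmethod
    def _forward(self, X):
        """The usual prior computation; no transform handling needed."""

    def posterior(self, X, posterior_transform=None, **kwargs):
        # One-to-many transforms (InputAugmentationTransform) only apply to
        # test inputs, so they live in this wrapper rather than in forward.
        X = self.augment_inputs(X)
        posterior = self._posterior(X, **kwargs)
        if posterior_transform is not None:
            posterior = posterior_transform(posterior)
        return posterior

    @abstractmethod
    def _posterior(self, X, **kwargs):
        """The usual posterior computation; no input/posterior transforms."""

    def transform_inputs(self, X):
        # Assumed helper: apply the one-to-one input transform, if one is set.
        transform = getattr(self, "input_transform", None)
        return X if transform is None else transform(X)

    def augment_inputs(self, X):
        # Assumed helper: apply the one-to-many augmentation transform, if set.
        transform = getattr(self, "input_augmentation_transform", None)
        return X if transform is None else transform(X)
```

A concrete model such as `SingleTaskGP` would then only implement `_forward` and `_posterior`, leaving all transform handling to the base class.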
This pull request was exported from Phabricator. Differential Revision: D35129407
cc @wjmaddox. For context, this came out of a discussion around the input transforms and variational strategy / inducing points. The current "apply only in `posterior` in `eval` mode" skips over the inducing points when evaluating the posterior (we pre-transform the `train_inputs` at `eval` time, but the inducing points never get transformed).
This looks great! Yeah, I really struggled with input transforms with variational GPs (I don't think the version in BoTorch really handles them super well right now) and had to place them in the `forward` call for my own research code. This also seems like a pretty sensible structure for dichotomizing one-to-one transforms and one-to-many transforms.
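For illustration, here is a minimal sketch (assumed names; not this PR's code or the research code mentioned above) of that forward-call placement: GPyTorch's `VariationalStrategy` calls the model's `forward` on the concatenation of the inducing points and the query points, so a transform applied there covers both consistently.

```python
# Illustrative sketch: an approximate GP that applies its input transform in
# `forward`. VariationalStrategy evaluates `forward` on inducing points and
# query points jointly, so neither is skipped at eval time.
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class TransformedApproximateGP(ApproximateGP):
    def __init__(self, inducing_points, input_transform):
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
        )
        super().__init__(variational_strategy)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())
        # Assumed to be any callable that behaves identically in train/eval.
        self.input_transform = input_transform

    def forward(self, x):
        # Transforming here (rather than in posterior) means the inducing
        # points are transformed too, since they also pass through forward.
        x = self.input_transform(x)
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))
```

For example, `TransformedApproximateGP(torch.rand(16, 2), lambda x: 2 * x - 1)` builds a model whose inducing points and query points both go through the same rescaling.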
Closed in favor of #1372 |