🐛 Bug
Multitask GPs (using the LMCVariationalStrategy and MeanFieldVariationalDistribution) are seemingly incompatible with the NNVariationalStrategy. This makes it difficult to train multitask GPs with very large numbers of inducing points.
To reproduce
I modified the "Variational GPs w/ Multiple Outputs" example to use NNVariationalStrategy as the base strategy inside LMCVariationalStrategy, as follows:
** Code snippet to reproduce **
import math
import torch
import gpytorch
import tqdm
from matplotlib import pyplot as plt

train_x = torch.linspace(0, 1, 100)
train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (2 * math.pi)) + 2 * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    -torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)

print(train_x.shape, train_y.shape)
num_latents = 3
num_tasks = 4


class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.NNVariationalStrategy(self, inducing_points, variational_distribution, k=8, training_batch_size=16),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )
        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
model = MultitaskGPModel(train_x)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)

num_epochs = 150

model.train()
likelihood.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.1)

# Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

# We use more CG iterations here because the preconditioner introduced in the NeurIPS paper seems to be less
# effective for VI.
epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    optimizer.zero_grad()
    output = model(None)
    loss = -mll(output, train_y)
    epochs_iter.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
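Separately from the crash, the loop above passes all 100 rows of train_y to the ELBO. However, `model(None)` under NNVariationalStrategy only produces a mini-batch of `training_batch_size` points. GPyTorch's single-output nearest-neighbor example handles this by slicing the targets with `model.variational_strategy.current_training_indices`. With an LMC wrapper, that attribute would presumably sit on `model.variational_strategy.base_variational_strategy`; this is an assumption and has not been checked here. The sketch below only shows the slicing for `(n, num_tasks)` targets. A random index draw stands in for the strategy's own indices.

```python
import torch

num_tasks, n, training_batch_size = 4, 100, 16
train_y = torch.randn(n, num_tasks)  # placeholder targets with the report's shape

# Hypothetical stand-in for the strategy's current_training_indices
current_training_indices = torch.randperm(n)[:training_batch_size]

# Multitask targets are (n, num_tasks), so the data dimension is dim 0,
# unlike the single-output example where train_y[..., idx] indexes the last dim
y_batch = train_y[current_training_indices]
print(y_batch.shape)  # torch.Size([16, 4])
```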
** Stack trace/error message **
Traceback (most recent call last):
File "<stdin>", line 4, in <module>
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/models/approximate_gp.py", line 114, in __call__
return self.variational_strategy(inputs, prior=prior, **kwargs)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/lmc_variational_strategy.py", line 197, in __call__
latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 145, in __call__
return self.forward(
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 192, in forward
kl = self._kl_divergence(kl_indices)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 369, in _kl_divergence
kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 313, in _stochastic_kl_helper
cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/kernels/scale_kernel.py", line 109, in forward
orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/kernels/rbf_kernel.py", line 80, in forward
return RBFCovariance.apply(
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/functions/rbf_covariance.py", line 12, in forward
x1_ = x1.div(lengthscale)
RuntimeError: The size of tensor a (16) must match the size of tensor b (3) at non-singleton dimension 1
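The failing line divides the nearest-neighbor inputs by the kernel's lengthscale. For `RBFKernel(batch_shape=torch.Size([3]))` with one input dimension, that lengthscale has shape `(3, 1, 1)`, one value per latent GP. The input shapes below are a guess at what reaches RBFCovariance and were not read from GPyTorch. They show that if the 16 KL indices (`training_batch_size=16`) occupy the dimension that lines up with the 3 latents, broadcasting fails with the same message as the traceback. `broadcast_error` is a hypothetical helper.

```python
import torch


def broadcast_error(x_shape, lengthscale_shape):
    """Return the RuntimeError text from x.div(lengthscale), or None if the shapes broadcast."""
    x = torch.rand(*x_shape)
    lengthscale = torch.ones(*lengthscale_shape)
    try:
        x.div(lengthscale)
    except RuntimeError as err:
        return str(err)
    return None


lengthscale_shape = (3, 1, 1)  # RBFKernel(batch_shape=torch.Size([3])), one input dim

# Hypothetical: 16 KL indices sit where the kernel expects the 3 latents (k=8 neighbors, d=1)
print(broadcast_error((1, 16, 8, 1), lengthscale_shape))

# Hypothetical: an explicit latent dimension right before (k, d) broadcasts fine
print(broadcast_error((16, 3, 8, 1), lengthscale_shape))  # None
```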
Expected Behavior
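The forward call to NNVariationalStrategy should work in a multitask setting, i.e. when it is used as the base strategy of an LMCVariationalStrategy over a batch of num_latents latent GPs.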
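For context, LMCVariationalStrategy implements a linear model of coregionalization. Each of the `num_tasks` outputs is a weighted sum of the `num_latents` latent GPs, which is why the base strategy has to produce a batch of `num_latents` outputs. The following is a minimal sketch of that mixing on the means only, in plain PyTorch. `lmc_coefficients` here is a random stand-in, and the code is not GPyTorch's implementation.

```python
import torch

num_latents, num_tasks, n = 3, 4, 100

# One row of function values per latent GP, as a base strategy with
# batch_shape=[num_latents] would produce (means only, for illustration)
latent_means = torch.randn(num_latents, n)

# Hypothetical mixing weights: entry [q, t] is the weight of latent q in task t
lmc_coefficients = torch.randn(num_latents, num_tasks)


def mix_latents(latents: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor:
    """Linear model of coregionalization on the means: (q, n) and (q, t) -> (n, t)."""
    return latents.transpose(-1, -2) @ coeffs


task_means = mix_latents(latent_means, lmc_coefficients)
print(task_means.shape)  # torch.Size([100, 4])
```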
System information
GPyTorch Version 1.12
PyTorch Version 2.3.1
MacOS 14.5