This might also explain #2306. It is a potential bug; since the contributing guidelines ask for the issue type to be specified, I declared it as a bug right away. Please feel free to change this if it is not correct.
Using the gpytorch.settings.verbose_linalg(state=True) context revealed that, when switching from single- to multi-task GPs, the linalg output changes from "CG" to "symeig". This is unexpected to me, because from the paper gardner2018gpytorch I do not see a reason why the "mBCG" algorithm (which I assume is what the output below calls "CG") should not be applicable in the multi-task case. Of course I could be missing something; in that case, please be so kind as to point me to it.
To reproduce
Primarily, I adapted the GPyTorch Regression Tutorial (GPU) from the documentation.
I wanted to make it easy to switch back and forth between single- and multi-task GPs, tensor shapes, and CPU/GPU, so I wrapped the reproduction code in a function run_gpytorch(...), which is called with the desired kwargs. Its arguments are:
num_samples:int specifies the number of data points generated for training the GP. I came across gpytorch.settings.max_cholesky_size(value) in the documentation, and num_samples makes it easy to change the size of the matrices.
dims_in:int specifies the dimensionality of the inputs to the GP, i.e. the number of features of the input data.
dims_out:int specifies the dimensionality of the outputs from the GP. It is also equal to num_tasks in the gpytorch.means.MultitaskMean, ...MultitaskKernel, and ...MultitaskGaussianLikelihood classes.
device:str specifies whether the GP is trained and tested on the CPU ('cpu') or the GPU ('gpu').
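As a quick sanity check of how these arguments map to tensor shapes, here is a minimal sketch in plain Python. The helper name expected_shapes is hypothetical and not part of the reproduction code; it only mirrors the reshaping rules and the num_samples/2.67 test-set size used in the snippet.

```python
import math


def expected_shapes(dims_in: int, dims_out: int, num_samples: int) -> dict:
    """Return the tensor shapes the reproduction code is expected to print.

    Hypothetical helper: it mirrors the rules in run_gpytorch(...) without
    needing torch or gpytorch.
    """
    # The test set has floor(num_samples / 2.67) points
    num_test = math.floor(num_samples / 2.67)

    # 1-D inputs are flattened to (n,); otherwise they stay (n, dims_in)
    train_x = (num_samples,) if dims_in == 1 else (num_samples, dims_in)
    test_x = (num_test,) if dims_in == 1 else (num_test, dims_in)

    # Multi-task targets get one column per task
    train_y = (num_samples,) if dims_out == 1 else (num_samples, dims_out)

    return {"train_x": train_x, "test_x": test_x, "train_y": train_y}


print(expected_shapes(dims_in=1, dims_out=2, num_samples=1000))
print(expected_shapes(dims_in=1, dims_out=1, num_samples=1000))
```

The two calls correspond to the multi-task and single-task runs reported in this issue.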
** Code snippet to reproduce **
    import torch
    import gpytorch
    import numpy as np


    def run_gpytorch(dims_in: int, dims_out: int, num_samples: int, device: str = 'cpu'):
        # device must be either 'cpu' or 'gpu'

        # Set context for mBCG and linalg debugging
        with gpytorch.settings.verbose_linalg(state=True), \
                gpytorch.settings.fast_computations(covar_root_decomposition=True, log_prob=True, solves=True):

            # Generate inputs for training and testing
            samples_train = np.linspace(start=[0.]*dims_in, stop=[1.]*dims_in, num=num_samples)
            samples_test = np.linspace(start=[0.]*dims_in, stop=[1.]*dims_in, num=np.floor(num_samples/2.67).astype(int))

            # For dims_in=1 reshaping is necessary, because gpytorch.models.ExactGP
            # expects inputs as 1-D arrays (n,) [not (n,1)]
            if dims_in == 1:
                samples_train = samples_train.reshape(-1,)
                samples_test = samples_test.reshape(-1,)

            train_x = torch.tensor(samples_train).to(torch.float)
            test_x = torch.tensor(samples_test).to(torch.float)

            # Generate outputs for training (that the model should learn to predict)
            if (dims_in > 1) and (dims_out > 1):
                train_y = torch.stack([torch.sin(2*torch.pi*train_x[:, 0])]*dims_out, 1).to(torch.float)
            elif (dims_in == 1) and (dims_out > 1):
                train_y = torch.stack([torch.sin(2*torch.pi*train_x)]*dims_out, 1).to(torch.float)
            elif (dims_in > 1) and (dims_out == 1):
                train_y = torch.sin(2*torch.pi*train_x[:, 0]).to(torch.float)
            elif (dims_in == 1) and (dims_out == 1):
                train_y = torch.sin(2*torch.pi*train_x).to(torch.float)

            print(f'Shape of train_x: {train_x.shape}')
            print(f'Shape of test_x: {test_x.shape}')
            print(f'Shape of train_y: {train_y.shape}'+'\n')

            # Define class for single-/multitask GP model
            if dims_out == 1:
                print(f'Using single-task GP as dims_out = {dims_out}'+'\n')

                class ExactGPModel(gpytorch.models.ExactGP):
                    def __init__(self, train_x, train_y, likelihood):
                        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
                        self.mean_module = gpytorch.means.ConstantMean()
                        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

                    def forward(self, x):
                        mean_x = self.mean_module(x)
                        covar_x = self.covar_module(x)
                        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

                # Instantiate single-task likelihood and GP model
                likelihood = gpytorch.likelihoods.GaussianLikelihood()
                model = ExactGPModel(train_x, train_y, likelihood)
            else:
                print(f'Using multi-task GP as dims_out = {dims_out}'+'\n')

                class MultitaskGPModel(gpytorch.models.ExactGP):
                    def __init__(self, train_x, train_y, likelihood):
                        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
                        self.mean_module = gpytorch.means.MultitaskMean(
                            gpytorch.means.ConstantMean(), num_tasks=dims_out
                        )
                        self.covar_module = gpytorch.kernels.MultitaskKernel(
                            gpytorch.kernels.RBFKernel(), num_tasks=dims_out
                        )

                    def forward(self, x):
                        mean_x = self.mean_module(x)
                        covar_x = self.covar_module(x)
                        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

                # Instantiate multi-task likelihood and GP model
                likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=dims_out)
                model = MultitaskGPModel(train_x, train_y, likelihood)

            # Move data, GP model, and likelihood to GPU if desired
            if device == 'gpu':
                print('Move all structures to GPU since device=gpu'+'\n')
                train_x = train_x.cuda()
                test_x = test_x.cuda()
                train_y = train_y.cuda()
                model = model.cuda()
                likelihood = likelihood.cuda()
            else:
                print('Do not move structures to GPU since device=cpu'+'\n')

            # Switch to training mode
            model.train()
            likelihood.train()

            # Define optimizer and loss function
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

            # Train GP using training data
            print('Start training')
            training_iterations = 1
            for i in range(training_iterations):
                optimizer.zero_grad()
                output = model(train_x)
                loss = -mll(output, train_y)
                loss.backward()
                optimizer.step()
                print(' Finished training iteration %i/%i' % (i+1, training_iterations))
            print('Finished training'+'\n')

            # Switch to evaluation mode, and probe trained GP using testing data
            model.eval()
            likelihood.eval()
            print('Start testing')
            with torch.no_grad():
                observed_model = model(test_x)
                print(' Testing: Finished evaluation')
                observed_pred = likelihood(observed_model)
                print(' Testing: Finished likelihood')
                mean = observed_pred.mean
                lower, upper = observed_pred.confidence_region()
            print('Finished testing'+'\n')
        return


    # Create dict of inputs to run_gpytorch(...)
    kwargs_profile = {'dims_in': 1, 'dims_out': 2, 'num_samples': 1000, 'device': 'gpu'}
    run_gpytorch(**kwargs_profile)
Outputs
The code above generates the following output, which reveals the use of symeig:
Shape of train_x: torch.Size([1000])
Shape of test_x: torch.Size([374])
Shape of train_y: torch.Size([1000, 2])
Using multi-task GP as dims_out = 2
Move all structures to GPU since device=gpu
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
Start training
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
c:\Users\{user}\micromamba\envs\gpytorch_mwe\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\tensor\python_tensor.cpp:80.)
summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
c:\Users\{user}\micromamba\envs\gpytorch_mwe\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated. Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\tensor_new.cpp:623.)
summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
Finished training iteration 1/1
Finished training
Start testing
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
Testing: Finished evaluation
Testing: Finished likelihood
Finished testing
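One possible reading of the matrix sizes in this log, stated as an assumption and not verified against the linear_operator source: [2, 2] matches the number of tasks and [1000, 1000] matches the number of training points, which is what a Kronecker-structured multi-task covariance would produce. For such a covariance with a single homoskedastic noise term, the standard identity below reduces solves and log-determinants to eigendecompositions of the two factors. The symbols $K_X$ (data covariance), $K_T$ (task covariance), and $\sigma^2$ (noise variance) are introduced here for illustration only.

$$
K_X = Q_X \Lambda_X Q_X^\top, \qquad K_T = Q_T \Lambda_T Q_T^\top
$$

$$
\left(K_X \otimes K_T + \sigma^2 I\right)^{-1} = \left(Q_X \otimes Q_T\right)\left(\Lambda_X \otimes \Lambda_T + \sigma^2 I\right)^{-1}\left(Q_X \otimes Q_T\right)^\top
$$

$$
\log\det\left(K_X \otimes K_T + \sigma^2 I\right) = \sum_{i}\sum_{j} \log\left(\lambda_i^{X}\,\lambda_j^{T} + \sigma^2\right)
$$

If this path is what is being taken, it would explain paired symeig calls on a [1000, 1000] and a [2, 2] matrix instead of a CG solve on the full 2000 x 2000 system.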
Expected Behavior
I expected the lines in the output starting with LinAlg (Verbose) ... to say something like LinAlg (Verbose) - DEBUG - Running CG ... instead of LinAlg (Verbose) - DEBUG - Running symeig ...
This is because testing a single-task GP, run by (notice dims_out changed to 1)

    kwargs_profile = {'dims_in': 1, 'dims_out': 1, 'num_samples': 1000, 'device': 'gpu'}
    run_gpytorch(**kwargs_profile)

generates the following output, which reveals the use of CG:
Shape of train_x: torch.Size([1000])
Shape of test_x: torch.Size([374])
Shape of train_y: torch.Size([1000])
Using single-task GP as dims_out = 1
Move all structures to GPU since device=gpu
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 11]) RHS for 1000 iterations (tol=1). Output: torch.Size([1000, 11]).
Start training
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([10, 11, 11]).
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 1]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([1000, 1]).
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 374]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([1000, 374]).
Finished training iteration 1/1
Finished training
Start testing
Testing: Finished evaluation
Testing: Finished likelihood
Finished testing
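To compare the two runs without reading the logs line by line, the routine names can be tallied. The following is a minimal sketch using only the standard library; the helper count_linalg_routines and the regular expression are assumptions based on the log format shown above, not part of GPyTorch.

```python
import re
from collections import Counter

# Matches lines such as
# "LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2])."
LINALG_PATTERN = re.compile(r"LinAlg \(Verbose\) - DEBUG - Running (\w+)")


def count_linalg_routines(log_text: str) -> Counter:
    """Count how often each routine name follows 'Running' in the log text."""
    return Counter(LINALG_PATTERN.findall(log_text))


# Excerpts copied from the two outputs in this issue
multi_task_log = """\
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
"""
single_task_log = """\
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 11]) RHS for 1000 iterations (tol=1). Output: torch.Size([1000, 11]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([10, 11, 11]).
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 1]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([1000, 1]).
"""

print("multi-task: ", dict(count_linalg_routines(multi_task_log)))
print("single-task:", dict(count_linalg_routines(single_task_log)))
```

In the full logs above, the multi-task run contains only symeig entries, while the single-task run contains three CG entries and one symeig entry.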
System information
GPyTorch Version (run print(gpytorch.__version__)) --> 1.11
PyTorch Version (run print(torch.__version__)) --> 2.3.0
Python Version (run print(sys.__version__)) --> 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:20:11) [MSC v.1938 64 bit (AMD64)]
micromamba Version (run micromamba info) --> libmamba version: 1.5.8
VSCode Version (Help >> About ) --> Version: 1.89.0
Computer OS --> Windows 11
Additional context
I was profiling GPyTorch the other day (using cProfile) and noticed that, in the multi-task GPs, the linalg solver called by GPyTorch was torch._C._linalg.linalg_eigh. That led to the investigation above. If you are interested, I can also provide the profiling information.
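For reference, a profile of this kind can be collected roughly as follows. This is a minimal sketch using only the standard library; workload is a hypothetical stand-in for a call such as run_gpytorch(**kwargs_profile), and profile_call is a helper name introduced here for illustration.

```python
import cProfile
import io
import pstats


def workload():
    # Hypothetical stand-in for run_gpytorch(**kwargs_profile)
    return sum(i * i for i in range(100_000))


def profile_call(func, top: int = 10) -> str:
    """Run func under cProfile and return the top entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    func()
    profiler.disable()

    # Write the report into a string instead of stdout
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return stream.getvalue()


report = profile_call(workload)
print(report)
```

Sorting by cumulative time is one reasonable choice for spotting which low-level routine dominates a run.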
I am using a Jupyter notebook inside VSCode and a mamba environment created from the following YAML file using micromamba create -f .\{file_name}.yml, which resulted in the specs below, exported using micromamba env export > {other_file_name}.yml.