Fix for #2674 - Corrected sizes for alpha in RQKernel when using Deep GPs #2677
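For context on the quantity being fixed (a standard definition added here, not text from the thread): the rational quadratic kernel computes

$$k(x, x') = \left(1 + \frac{\|x - x'\|^2}{2\,\alpha\,\ell^2}\right)^{-\alpha},$$

so `alpha` must broadcast correctly against the pairwise distance matrix. As the diff comment below notes, deep GPs add an extra leading sampling dimension to that matrix, which the original broadcasting logic in `postprocess_rq` did not account for.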
base: main
Conversation
Adjust alpha handling in postprocess_rq function
kayween left a comment:
Hi @mrlj-hash, thank you for the PR! I left some comments. Could you also update the PR description to include a brief summary (e.g., what the bug was and how it is fixed)?
```python
        actual_value = torch.tensor(3.0).view_as(kernel.alpha)
        self.assertLess(torch.norm(kernel.alpha - actual_value), 1e-5)

    def test_last_layer_alpha(self):
```
This test only verifies that the kernel matrix can be computed; it does not verify that the computation is correct. Can we improve the test case by checking the kernel matrix against the ground truth (like the other test cases in this file)?
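A minimal sketch of what such a ground-truth check could look like (assumed GPyTorch API usage — `initialize`, `to_dense` — not the PR's actual test code):

```python
import torch
from gpytorch.kernels import RQKernel

# Sketch: compare RQKernel's output against the closed form
#   k(x, x') = (1 + ||x - x'||^2 / (2 * alpha * l^2))^(-alpha)
kernel = RQKernel().initialize(lengthscale=1.0, alpha=3.0)
kernel.eval()

x = torch.tensor([[0.0], [1.0], [2.0]])           # three 1-D inputs
sq_dist = (x - x.T).pow(2)                        # pairwise squared distances
expected = (1 + sq_dist / (2 * 3.0)).pow(-3.0)    # ground truth, lengthscale 1

actual = kernel(x, x).to_dense()
assert torch.allclose(actual, expected, atol=1e-5)
```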
Can we make these tests deterministic, like the other test cases in this file (i.e., no torch.randn)? Note that determinism is not guaranteed across devices, even with a random seed fixed via torch.manual_seed.
You could do something like:
```python
x1 = torch.arange(24).view(2, 2, 3, 2)
x2 = torch.arange(16).view(2, 2, 2, 2)
```
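Plugged into a full check, the suggestion might look like this (a sketch under an assumed recent-GPyTorch API, not the PR's code; torch.cdist supplies a batched ground truth, and the float cast is needed because torch.arange defaults to an integer dtype):

```python
import torch
from gpytorch.kernels import RQKernel

kernel = RQKernel().initialize(lengthscale=1.0, alpha=2.0)
kernel.eval()

# Deterministic batched inputs, following the suggestion above:
x1 = torch.arange(24, dtype=torch.float).view(2, 2, 3, 2)
x2 = torch.arange(16, dtype=torch.float).view(2, 2, 2, 2)

# Batched ground truth from the closed form via torch.cdist:
sq_dist = torch.cdist(x1, x2).pow(2)              # shape (2, 2, 3, 2)
expected = (1 + sq_dist / (2 * 2.0)).pow(-2.0)

actual = kernel(x1, x2).to_dense()
assert torch.allclose(actual, expected, atol=1e-4)
```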
gpytorch/kernels/rq_kernel.py (outdated)
```python
alpha = alpha.unsqueeze(-1)

# for loop above overruns by 1 in deep GPs due to additional sampling dimension
if len(alpha) > 1 and alpha.shape[0] != dist_mat.shape[0]:
```
My main concern with this if-statement is that it seems a bit ad hoc. It checks whether the first dimension fails to match and, if so, does an unsqueeze.
If alpha.shape[0] equals dist_mat.shape[0] merely by coincidence, the unsqueeze is skipped even in situations where it is actually required, so alpha and dist_mat can be broadcast along the wrong axes and the kernel matrix computed incorrectly, with no error raised.
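To make the concern concrete, here is a hypothetical shape scenario (illustrative tensors only, not the PR's code), where the guard is fooled because the kernel batch size coincides with the number of deep-GP samples:

```python
import torch

# Kernel batch of size 2, so alpha carries two values, with trailing singleton
# dims already appended by the loop above:
alpha = torch.tensor([1.0, 2.0]).view(2, 1, 1, 1)

# Deep-GP distance matrix: (samples, kernel batch, n, n), where the number of
# samples happens to equal the kernel batch size:
dist_mat = torch.ones(2, 2, 3, 3)

# The guard `alpha.shape[0] != dist_mat.shape[0]` is False here (2 == 2), so no
# corrective unsqueeze happens. Broadcasting then pairs alpha's kernel-batch
# axis with the *sampling* axis: alpha[0] is applied to all of sample 0 and
# alpha[1] to all of sample 1, across both kernel batches. No error is raised.
out = (1 + dist_mat / (2 * alpha)).pow(-alpha)
print(out.shape)  # torch.Size([2, 2, 3, 3]) -- shape looks right, values are not
```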
Implement manual conditional unsqueezing of alpha
Co-authored-by: Kaiwen Wu <[email protected]>
Need last_dim_is_batch parameter to do the conditional unsqueezing of alpha
Ensure consistent whitespace
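A shape-driven alternative, sketched under the assumption that alpha initially has the kernel's batch shape (a hypothetical helper, not the code these commits actually add): derive the number of extra leading dimensions from the kernel's own batch_shape instead of comparing sizes that can collide by coincidence.

```python
import torch

def reshape_alpha(alpha: torch.Tensor, dist_mat: torch.Tensor, num_batch_dims: int) -> torch.Tensor:
    """Hypothetical helper: make alpha broadcast unambiguously against dist_mat.

    Assumes alpha has shape equal to the kernel's batch shape (num_batch_dims
    dimensions) and dist_mat has shape (*extra, *batch_shape, n, m), where
    *extra holds any leading dims added by sampling (e.g., in deep GPs).
    """
    alpha = alpha.view(*alpha.shape, 1, 1)         # cover the (n, m) matrix axes
    extra = dist_mat.dim() - (num_batch_dims + 2)  # leading sampling dims
    for _ in range(extra):
        alpha = alpha.unsqueeze(0)                 # prepend singleton dims
    return alpha

# With batch shape (2,) and one sampling dim of size 2, alpha becomes
# (1, 2, 1, 1) and lines up with dist_mat's kernel-batch axis, never its
# sampling axis -- even when the two sizes coincide:
alpha = reshape_alpha(torch.tensor([1.0, 2.0]), torch.ones(2, 2, 3, 3), num_batch_dims=1)
print(alpha.shape)  # torch.Size([1, 2, 1, 1])
```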
@mrlj-hash It's recommended to install a linter locally (e.g., flake8) or a pre-commit hook to aid your development; see CONTRIBUTING.md. (That way, the minor lint errors flagged in the GitHub workflow can be fixed more easily.)
No description provided.