Avoid unnecessary memory allocation for covariance downdate in SGPR prediction strategy #2559

Merged · 2 commits · Aug 7, 2024 · Changes from all commits
2 changes: 1 addition & 1 deletion gpytorch/models/exact_prediction_strategies.py
@@ -865,5 +865,5 @@ def exact_predictive_covar(self, test_test_covar, test_train_covar):
"This is likely a bug in GPyTorch."
)

res = test_test_covar - (L @ (covar_cache @ L.transpose(-1, -2)))
res = test_test_covar - MatmulLinearOperator(L, covar_cache @ L.mT)
return res
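Why this helps: the old expression eagerly materializes the full `n_test x n_test` downdate matrix, whereas wrapping the two factors in a `MatmulLinearOperator` stores only the skinny factors and defers the large product until a dense representation is actually requested (`Tensor.mT` is just shorthand for `transpose(-1, -2)`). A minimal sketch of the difference, assuming `MatmulLinearOperator` from the `linear_operator` package accepts plain tensors as the diff suggests; shapes and values are illustrative:

import torch
from linear_operator.operators import MatmulLinearOperator

n, k = 1000, 50
L = torch.randn(n, k, dtype=torch.float64)   # tall-skinny factor, e.g. from the SGPR cache
covar_cache = torch.randn(k, k, dtype=torch.float64)

# Eager: allocates the full n x n downdate immediately.
dense = L @ (covar_cache @ L.mT)

# Lazy: stores only the n x k and k x n factors; the n x n product is
# formed only if and when a dense representation is requested.
lazy = MatmulLinearOperator(L, covar_cache @ L.mT)
assert lazy.shape == dense.shape
assert torch.allclose(lazy.to_dense(), dense)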
25 changes: 25 additions & 0 deletions gpytorch/priors/prior.py
@@ -1,10 +1,18 @@
#!/usr/bin/env python3

from abc import ABC
+from typing import Any, Mapping

+from torch.distributions import TransformedDistribution
from torch.nn import Module

from ..distributions import Distribution
+from .utils import _load_transformed_to_base_dist


+TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
+'_transformed' attributes modified; these are just copies of the base attribute. \
+Please modify the base attribute (e.g. {}) instead."""


class Prior(Distribution, Module, ABC):
@@ -25,3 +33,20 @@ def log_prob(self, x):
        :rtype: torch.Tensor
        """
        return super(Prior, self).log_prob(self.transform(x))
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs):
+        Module.load_state_dict(self, state_dict, *args, **kwargs)
+        if isinstance(self, TransformedDistribution):
+            _load_transformed_to_base_dist(self)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if hasattr(self, name) and "_transformed_" in name:
+            base_attr_name = name.replace("_transformed_", "")
+            raise AttributeError(TRANSFORMED_ERROR_MSG.format(base_attr_name))
+
+        elif hasattr(self, f"_transformed_{name}"):
+            self.base_dist.__setattr__(name, value)
+            super().__setattr__(f"_transformed_{name}", value)
+
+        else:
+            return super().__setattr__(name, value)
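The `_transformed_` indirection and the `__setattr__` redirect exist because, for TransformedDistributions such as LogNormal, `loc` and `scale` are read-only class properties that delegate to the base distribution, so they can be neither deleted nor re-registered as buffers under their own names. A quick illustration of that constraint (not part of the PR):

import torch
from torch.distributions import LogNormal, TransformedDistribution

ln = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
assert isinstance(ln, TransformedDistribution)
assert ln.loc is ln.base_dist.loc   # `loc` just delegates to the base Normal

try:
    delattr(ln, "loc")              # properties cannot be removed per-instance
except AttributeError:
    pass                            # hence the `_transformed_` buffer scheme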
3 changes: 3 additions & 0 deletions gpytorch/priors/torch_priors.py
@@ -40,6 +40,7 @@ class HalfNormalPrior(Prior, HalfNormal):
    def __init__(self, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        HalfNormal.__init__(self, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("scale",))
        self._transform = transform

    def expand(self, batch_shape):
@@ -54,6 +55,7 @@ class LogNormalPrior(Prior, LogNormal):
    def __init__(self, loc, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("loc", "scale"))
        self._transform = transform

    def expand(self, batch_shape):
@@ -84,6 +86,7 @@ class HalfCauchyPrior(Prior, HalfCauchy):
    def __init__(self, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("scale",))
        self._transform = transform

    def expand(self, batch_shape):
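With `scale` (and `loc`) bufferized, these priors now survive a full serialization round trip, which is the point of the change. A small usage sketch, not part of the diff (the in-memory buffer stands in for a checkpoint file):

import io

import torch

from gpytorch.priors import HalfCauchyPrior

prior = HalfCauchyPrior(scale=1.3)

buf = io.BytesIO()
torch.save(prior.state_dict(), buf)
buf.seek(0)

restored = HalfCauchyPrior(scale=1.0)
restored.load_state_dict(torch.load(buf))
assert restored.scale == 1.3   # copied back onto base_dist on load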
33 changes: 29 additions & 4 deletions gpytorch/priors/utils.py
@@ -1,11 +1,36 @@
#!/usr/bin/env python3

+from torch.distributions import TransformedDistribution


def _bufferize_attributes(module, attributes):
-    attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
-    for attr, value in attr_clones.items():
-        delattr(module, attr)
-        module.register_buffer(attr, value)
+    r"""
+    Registers the given attributes of the prior as torch buffers to enable
+    saving/loading to/from state_dicts.
+    For TransformedDistributions, a ``_transformed_``-prefixed copy of each
+    attribute is registered instead, since the original parameters are
+    read-only properties of the base distribution and cannot be bufferized.
+    """
+    if isinstance(module, TransformedDistribution):
+        for attr in attributes:
+            module.register_buffer(f"_transformed_{attr}", getattr(module, attr))
+    else:
+        attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
+        for attr, value in attr_clones.items():
+            delattr(module, attr)
+            module.register_buffer(attr, value)


+def _load_transformed_to_base_dist(module):
+    r"""Copies the ``_transformed_`` buffers of a torch TransformedDistribution
+    back onto the parameters of its base distribution. Together with
+    ``_bufferize_attributes``, this enables the distribution's parameters to be
+    saved to and loaded from state_dicts, as the original parameters cannot be.
+    """
+    transf_str = "_transformed_"
+    transformed_attrs = [attr for attr in dir(module) if transf_str in attr]
+    for transf_attr in transformed_attrs:
+        base_attr_name = transf_attr.replace(transf_str, "")
+        setattr(module.base_dist, base_attr_name, getattr(module, transf_attr))


def _del_attributes(module, attributes, raise_on_error=False):
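The two branches of `_bufferize_attributes` show up directly in the resulting state_dict keys, as the tests below also verify; a short sketch:

from gpytorch.priors import GammaPrior, LogNormalPrior

# Gamma is a plain distribution: its parameters are cloned and
# re-registered as buffers under their own names.
print(list(GammaPrior(1.1, 2.0).state_dict()))
# ['concentration', 'rate']

# LogNormal is a TransformedDistribution: its read-only parameters are
# mirrored under "_transformed_" names instead.
print(list(LogNormalPrior(2.1, 1.2).state_dict()))
# ['_transformed_loc', '_transformed_scale']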
70 changes: 70 additions & 0 deletions test/priors/test_prior.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

import unittest

from torch import Tensor

from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior


TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
'_transformed' attributes modified; these are just copies of the base attribute. \
Please modify the base attribute (e.g. {}) instead."""


class TestPrior(unittest.TestCase):
    def test_state_dict(self):
        normal = NormalPrior(0.1, 1).state_dict()
        self.assertTrue("loc" in normal)
        self.assertTrue("scale" in normal)
        self.assertEqual(normal["loc"], 0.1)

        gamma = GammaPrior(1.1, 2).state_dict()
        self.assertTrue("concentration" in gamma)
        self.assertTrue("rate" in gamma)
        self.assertEqual(gamma["concentration"], 1.1)

        ln = LogNormalPrior(2.1, 1.2).state_dict()
        self.assertTrue("_transformed_loc" in ln)
        self.assertTrue("_transformed_scale" in ln)
        self.assertEqual(ln["_transformed_loc"], 2.1)

        hc = HalfCauchyPrior(1.3).state_dict()
        self.assertTrue("_transformed_scale" in hc)

    def test_load_state_dict(self):
        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
        gm1 = GammaPrior(concentration=0.5, rate=0.1)
        gm2 = GammaPrior(concentration=2.5, rate=2.1)
        hc1 = HalfCauchyPrior(scale=1.1)
        hc2 = HalfCauchyPrior(scale=101.1)

        ln2.load_state_dict(ln1.state_dict())
        self.assertEqual(ln2.loc, ln1.loc)
        self.assertEqual(ln2.scale, ln1.scale)

        gm2.load_state_dict(gm1.state_dict())
        self.assertEqual(gm2.concentration, gm1.concentration)
        self.assertEqual(gm2.rate, gm1.rate)

        hc2.load_state_dict(hc1.state_dict())
        self.assertEqual(hc2.scale, hc1.scale)

    def test_transformed_attributes(self):
        norm = NormalPrior(loc=2.5, scale=2.1)
        ln = LogNormalPrior(loc=2.5, scale=2.1)
        hc = HalfCauchyPrior(scale=2.2)

        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
            getattr(norm, "_transformed_loc")

        self.assertEqual(ln._transformed_loc, 2.5)
        norm.loc = Tensor([1.01])
        ln.loc = Tensor([1.01])
        self.assertEqual(ln._transformed_loc, 1.01)
        with self.assertRaises(AttributeError):
            ln._transformed_loc = 1.1

        with self.assertRaises(AttributeError):
            hc._transformed_scale = 1.01
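The tests above only assert that an AttributeError is raised; concretely, writing to a `_transformed_` attribute surfaces the formatted TRANSFORMED_ERROR_MSG. A sketch of what that looks like:

from gpytorch.priors import LogNormalPrior

ln = LogNormalPrior(loc=0.5, scale=0.1)
ln._transformed_loc = 1.0
# AttributeError: Priors of TransformedDistributions should not have their
# '_transformed' attributes modified; these are just copies of the base
# attribute. Please modify the base attribute (e.g. loc) instead.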
61 changes: 61 additions & 0 deletions test/priors/test_utils.py
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

import unittest

from torch import Tensor

from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior


class TestPrior(unittest.TestCase):
    def test_state_dict(self):
        normal = NormalPrior(0.1, 1).state_dict()
        self.assertTrue("loc" in normal)
        self.assertTrue("scale" in normal)
        self.assertEqual(normal["loc"], 0.1)

        gamma = GammaPrior(1.1, 2).state_dict()
        self.assertTrue("concentration" in gamma)
        self.assertTrue("rate" in gamma)
        self.assertEqual(gamma["concentration"], 1.1)

        ln = LogNormalPrior(2.1, 1.2).state_dict()
        self.assertTrue("_transformed_loc" in ln)
        self.assertTrue("_transformed_scale" in ln)
        self.assertEqual(ln["_transformed_loc"], 2.1)

        hc = HalfCauchyPrior(1.3).state_dict()
        self.assertTrue("_transformed_scale" in hc)

    def test_load_state_dict(self):
        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
        gm1 = GammaPrior(concentration=0.5, rate=0.1)
        gm2 = GammaPrior(concentration=2.5, rate=2.1)
        hc1 = HalfCauchyPrior(scale=1.1)
        hc2 = HalfCauchyPrior(scale=101.1)

        ln2.load_state_dict(ln1.state_dict())
        self.assertEqual(ln2.loc, ln1.loc)
        self.assertEqual(ln2.scale, ln1.scale)

        gm2.load_state_dict(gm1.state_dict())
        self.assertEqual(gm2.concentration, gm1.concentration)
        self.assertEqual(gm2.rate, gm1.rate)

        hc2.load_state_dict(hc1.state_dict())
        self.assertEqual(hc2.scale, hc1.scale)

    def test_transformed_attributes(self):
        norm = NormalPrior(loc=2.5, scale=2.1)
        ln = LogNormalPrior(loc=2.5, scale=2.1)
        hc = HalfCauchyPrior(scale=2.2)

        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
            getattr(norm, "_transformed_loc")

        self.assertEqual(ln._transformed_loc, 2.5)
        norm.loc = Tensor([1.01])
        ln.loc = Tensor([1.01])
        self.assertEqual(ln._transformed_loc, 1.01)
        self.assertEqual(hc._transformed_scale, 2.2)