14 changes: 10 additions & 4 deletions botorch/models/fully_bayesian.py
@@ -958,7 +958,10 @@ def _get_dummy_mcmc_samples(
         return mcmc_samples

     def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Custom logic for loading the state dict.

@@ -980,7 +983,7 @@ def load_state_dict(
         )
         self.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)


 class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
@@ -1047,7 +1050,10 @@ def median_weight_variance(self) -> Tensor:
         return weight_variance.median(0).values.squeeze(0)

     def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Custom logic for loading the state dict.

@@ -1077,4 +1083,4 @@ def load_state_dict(
             mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
         self.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
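
The fully Bayesian overrides above forward the new flag after their dummy-sample setup. A minimal round-trip sketch of that path, assuming toy data and deliberately small NUTS settings (all data and settings below are illustrative, not part of this diff):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)
sd = model.state_dict()

# A fresh model has no MCMC-sized modules yet: load_state_dict first registers
# dummy samples sized from the state dict, then defers to the parent class,
# now forwarding ``assign`` as well.
reloaded = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
reloaded.load_state_dict(sd, assign=False)  # keep the fresh model's tensor properties
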
19 changes: 15 additions & 4 deletions botorch/models/gpytorch.py
@@ -328,6 +328,7 @@ def load_state_dict(
         state_dict: Mapping[str, Any],
         strict: bool = True,
         keep_transforms: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Load the model state.

@@ -337,9 +338,14 @@
             keep_transforms: A boolean indicating whether to keep the input and outcome
                 transforms. Doing so is useful when loading a model that was trained on
                 a full set of data, and is later loaded with a subset of the data.
+            assign: When set to ``False``, the properties of the tensors in the current
+                module are preserved whereas setting it to ``True`` preserves
+                properties of the Tensors in the state dict. The only
+                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`
+                for which the value from the module is preserved. Default: ``False``.
         """
         if not keep_transforms:
-            super().load_state_dict(state_dict, strict)
+            super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
             return

         should_outcome_transform = (
@@ -368,10 +374,12 @@ def load_state_dict(
                 BotorchWarning,
                 stacklevel=3,
             )
-            super().load_state_dict(state_dict, strict)
+            super().load_state_dict(
+                state_dict=state_dict, strict=strict, assign=assign
+            )
             return

-        super().load_state_dict(state_dict, strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)

         if getattr(self, "input_transform", None) is not None:
             self.input_transform.eval()
@@ -763,8 +771,11 @@ def load_state_dict(
         self,
         state_dict: Mapping[str, Any],
         strict: bool = True,
+        assign: bool = False,
     ) -> None:
-        return ModelList.load_state_dict(self, state_dict, strict)
+        return ModelList.load_state_dict(
+            self, state_dict=state_dict, strict=strict, assign=assign
+        )

     # pyre-fixme[14]: Inconsistent override in return types
     def posterior(
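
The docstring added in botorch/models/gpytorch.py is the user-facing contract for the flag. A minimal sketch of those semantics with a plain SingleTaskGP (toy data, no transforms; the variable names are illustrative only):

import torch
from botorch.models import SingleTaskGP

X = torch.rand(5, 2, dtype=torch.double)
Y = X.sum(dim=-1, keepdim=True)
sd = SingleTaskGP(train_X=X, train_Y=Y).state_dict()  # float64 tensors

# assign=False (default): the float32 module keeps its own tensor properties;
# values from the state dict are copied in with a cast.
m_keep = SingleTaskGP(train_X=X.float(), train_Y=Y.float())
m_keep.load_state_dict(sd, assign=False)
print(next(m_keep.parameters()).dtype)  # torch.float32

# assign=True: the state dict's tensors are assigned directly, so their float64
# properties win (only requires_grad still comes from the module's parameters).
m_adopt = SingleTaskGP(train_X=X.float(), train_Y=Y.float())
m_adopt.load_state_dict(sd, assign=True)
print(next(m_adopt.parameters()).dtype)  # torch.float64
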
3 changes: 2 additions & 1 deletion botorch/models/model.py
@@ -581,6 +581,7 @@ def load_state_dict(
         state_dict: Mapping[str, Any],
         strict: bool = True,
         keep_transforms: bool = True,
+        assign: bool = False,
     ) -> None:
         """Initialize the fully Bayesian models before loading the state dict."""
         for i, m in enumerate(self.models):
@@ -589,7 +590,7 @@
                 for k, v in state_dict.items()
                 if k.startswith(f"models.{i}.")
             }
-            m.load_state_dict(filtered_dict, strict=strict)
+            m.load_state_dict(filtered_dict, strict=strict, assign=assign)

     def fantasize(
         self,
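
The model-list hook routes each submodel its own slice of the state dict and now forwards ``assign`` unchanged. A toy illustration of that filtering; the key expression of the comprehension is collapsed in the diff above, so the ``k.replace(...)`` below is an assumed reconstruction of that hidden line:

# Entries prefixed "models.{i}." belong to the i-th submodel.
state_dict = {
    "models.0.mean_module.raw_constant": 1.0,
    "models.1.mean_module.raw_constant": 2.0,
}
i = 0
filtered_dict = {
    k.replace(f"models.{i}.", ""): v  # assumed: the collapsed line strips the prefix
    for k, v in state_dict.items()
    if k.startswith(f"models.{i}.")  # visible in the hunk: only submodel i's keys
}
assert filtered_dict == {"mean_module.raw_constant": 1.0}
# Each submodel then loads its slice:
# m.load_state_dict(filtered_dict, strict=strict, assign=assign)
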
63 changes: 63 additions & 0 deletions test/models/test_gpytorch.py
@@ -1043,6 +1043,69 @@ def test_load_state_dict_with_transforms(self):
             )
         )

+    def test_load_state_dict_assign_parameter(self):
+        """Test that the assign parameter correctly controls tensor property preservation.
+
+        With assign=False (default): properties of the current model's tensors are preserved.
+        With assign=True: properties of the state dict's tensors are preserved.
+        """
+        # Create base model with double precision
+        tkwargs_double = {"device": self.device, "dtype": torch.double}
+        train_X_double = torch.rand(5, 2, **tkwargs_double)
+        train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True)
+
+        base_model = SingleTaskGP(
+            train_X=train_X_double,
+            train_Y=train_Y_double,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+        state_dict_double = base_model.state_dict()
+
+        # Create a new model with float32 precision (different dtype)
+        tkwargs_float = {"device": self.device, "dtype": torch.float}
+        train_X_float = torch.rand(5, 2, **tkwargs_float)
+        train_Y_float = torch.sin(train_X_float).sum(dim=1, keepdim=True)
+
+        # Test assign=False (default behavior)
+        model_assign_false = SingleTaskGP(
+            train_X=train_X_float,
+            train_Y=train_Y_float,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+
+        # Load double precision state dict with assign=False
+        model_assign_false.load_state_dict(
+            state_dict_double, keep_transforms=True, assign=False
+        )
+
+        # With assign=False, the model should keep its original float32 dtype
+        self.assertEqual(model_assign_false.train_inputs[0].dtype, torch.float)
+
+        # Test assign=True
+        model_assign_true = SingleTaskGP(
+            train_X=train_X_float,
+            train_Y=train_Y_float,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+
+        # Load double precision state dict with assign=True
+        model_assign_true.load_state_dict(
+            state_dict_double, keep_transforms=True, assign=True
+        )
+
+        # With assign=True, the model should adopt the state dict's double dtype
+        self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double)
+        self.assertEqual(
+            model_assign_true.train_inputs[0].dtype,
+            state_dict_double["train_inputs.0"].dtype,
+        )
+
+        # Verify the two models have different dtypes
+        self.assertNotEqual(
+            model_assign_false.train_inputs[0].dtype,
+            model_assign_true.train_inputs[0].dtype,
+        )
+
     def test_load_state_dict_no_transforms(self):
         tkwargs = {"device": self.device, "dtype": torch.double}