Skip to content

Commit dff75b0

Browse files
Carl Hvarfner and meta-codesync[bot]
authored and committed
Enable unobserved task support in MultiTaskGP (meta-pytorch#3145)
Summary: Pull Request resolved: meta-pytorch#3145 Permits an MTGP to predict on an unobserved task, addressing these issues: meta-pytorch#2360 meta-pytorch#3085 To do this, we assume that the unobserved task is maximally correlated with the target tasks (equally with each, by averaging the elements). Exact heuristic on correlation is definitely up for discussion, but this seems like a decent default assumption. Will come in handy for TL initialization. Differential Revision: D90769576
1 parent 811ef3f commit dff75b0

File tree

4 files changed

+167
-12
lines changed

4 files changed

+167
-12
lines changed

botorch/models/fully_bayesian_multitask.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
r"""Multi-task Gaussian Process Regression models with fully Bayesian inference."""
88

99
from collections.abc import Mapping
10-
from typing import Any, NoReturn, TypeVar
10+
from typing import Any, NoReturn, Self, TypeVar
1111

1212
import pyro
1313
import torch
@@ -19,7 +19,10 @@
1919
reshape_and_detach,
2020
SaasPyroModel,
2121
)
22-
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
22+
from botorch.models.gpytorch import (
23+
BatchedMultiOutputGPyTorchModel,
24+
MultiTaskGPyTorchModel,
25+
)
2326
from botorch.models.multitask import MultiTaskGP
2427
from botorch.models.transforms.input import InputTransform
2528
from botorch.models.transforms.outcome import OutcomeTransform
@@ -55,6 +58,7 @@ def set_inputs(
5558
train_Yvar: Tensor | None,
5659
task_feature: int,
5760
task_rank: int | None = None,
61+
all_tasks: list[int] | None = None,
5862
) -> None:
5963
"""Set the training data.
6064
@@ -73,7 +77,11 @@ def set_inputs(
7377
task_feature = task_feature % train_X.shape[-1]
7478
super().set_inputs(train_X, train_Y, train_Yvar)
7579
# obtain a list of task indicies
76-
all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
80+
all_tasks = (
81+
train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
82+
if all_tasks is None
83+
else all_tasks
84+
)
7785
self.task_feature = task_feature
7886
self.num_tasks = len(all_tasks)
7987
self.task_rank = task_rank or self.num_tasks
@@ -242,7 +250,10 @@ def __init__(
242250
outputs for. If omitted, return outputs for all task indices.
243251
rank: The num of learned task embeddings to be used in the task kernel.
244252
If omitted, use a full rank (i.e. number of tasks) kernel.
245-
all_tasks: NOT SUPPORTED!
253+
all_tasks: A list of all task indices. If omitted, all tasks will be
254+
inferred from the task feature column of the training data. Used to
255+
inform the model about the total number of tasks, including any
256+
unobserved tasks.
246257
outcome_transform: An outcome transform that is applied to the
247258
training data during instantiation and to the posterior during
248259
inference (that is, the ``Posterior`` obtained by calling
@@ -310,6 +321,7 @@ def __init__(
310321
train_Yvar=train_Yvar,
311322
task_feature=task_feature,
312323
task_rank=self._rank,
324+
all_tasks=all_tasks,
313325
)
314326
self.pyro_model: MultitaskSaasPyroModel = pyro_model
315327
if outcome_transform is not None:
@@ -383,6 +395,20 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
383395
_,
384396
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
385397

398+
def eval(self) -> Self:
399+
r"""Puts the model in eval mode.
400+
401+
Circumvents the need to call MultiTaskGP.eval(), which computes the
402+
task_covar_matrix for non-observed tasks. This is not needed for fully
403+
Bayesian models, since the non-observed tasks' covar factors are instead
404+
sampled.
405+
406+
Returns:
407+
The model itself.
408+
"""
409+
self._check_if_fitted()
410+
return MultiTaskGPyTorchModel.eval(self)
411+
386412
def posterior(
387413
self,
388414
X: Tensor,

botorch/models/multitask.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from __future__ import annotations
3131

3232
import math
33-
from typing import Any
33+
from typing import Any, Self
3434

3535
import torch
3636
from botorch.acquisition.objective import PosteriorTransform
@@ -238,7 +238,24 @@ def __init__(
238238
"This is not allowed as it will lead to errors during model training."
239239
)
240240
all_tasks = all_tasks or all_tasks_inferred
241-
self.num_tasks = len(all_tasks_inferred)
241+
# Compute observed and unobserved task indices when all_tasks includes
242+
# unobserved tasks
243+
sorted_all_tasks = sorted(all_tasks)
244+
if set(all_tasks) != set(all_tasks_inferred):
245+
observed_set = set(all_tasks_inferred)
246+
self._observed_task_indices = torch.tensor(
247+
[i for i, t in enumerate(sorted_all_tasks) if t in observed_set],
248+
dtype=torch.long,
249+
)
250+
self._unobserved_task_indices = torch.tensor(
251+
[i for i, t in enumerate(sorted_all_tasks) if t not in observed_set],
252+
dtype=torch.long,
253+
)
254+
else:
255+
# All tasks are observed - set observed indices to all tasks
256+
self._observed_task_indices = torch.arange(len(all_tasks), dtype=torch.long)
257+
self._unobserved_task_indices = torch.tensor([], dtype=torch.long)
258+
self.num_tasks = len(all_tasks)
242259
if outcome_transform == DEFAULT:
243260
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
244261
if outcome_transform is not None:
@@ -321,7 +338,7 @@ def __init__(
321338
default_task_value=None if output_tasks is None else output_tasks[0],
322339
)
323340
self.register_buffer("_task_mapper", task_mapper)
324-
self._expected_task_values = set(all_tasks_inferred)
341+
self._expected_task_values = set(all_tasks)
325342
if input_transform is not None:
326343
self.input_transform = input_transform
327344
if outcome_transform is not None:
@@ -407,6 +424,28 @@ def forward(self, x: Tensor) -> MultivariateNormal:
407424
covar_x = self.covar_module(x_covar)
408425
return MultivariateNormal(mean_x, covar_x)
409426

427+
def eval(self) -> Self:
428+
r"""Puts the model in ``eval`` mode.
429+
430+
When unobserved tasks are present (i.e., ``all_tasks`` includes tasks not in
431+
the training data), this method sets the covariance factor for unobserved tasks
432+
to the mean of the observed tasks' covariance factors. This provides a
433+
reasonable initialization for prediction on unobserved tasks.
434+
"""
435+
if len(self._unobserved_task_indices) > 0:
436+
task_covar_module = self.covar_module.kernels[1]
437+
# Get the current covar_factor (transformed from raw_covar_factor)
438+
covar_factor = task_covar_module.covar_factor
439+
# Compute mean of observed tasks' covar_factor rows
440+
observed_covar_factor = covar_factor[self._observed_task_indices]
441+
mean_covar_factor = observed_covar_factor.mean(dim=0)
442+
# Create new covar_factor with unobserved tasks set to mean
443+
new_covar_factor = covar_factor.clone()
444+
new_covar_factor[self._unobserved_task_indices] = mean_covar_factor
445+
# Set the new covar_factor (this applies inverse_transform internally)
446+
task_covar_module._set_covar_factor(new_covar_factor)
447+
return super().eval()
448+
410449
@classmethod
411450
def get_all_tasks(
412451
cls,

test/models/test_fully_bayesian_multitask.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ def test_fit_model_with_task_mapper(self) -> None:
529529
self.assertTrue(
530530
torch.equal(model._task_mapper, torch.tensor([0, 1, 1], **tkwargs))
531531
)
532+
# Verify the pyro_model has the correct number of tasks (3, not 2)
533+
self.assertEqual(model.pyro_model.num_tasks, 3)
532534
self.test_fit_model(
533535
use_outcome_transform=True,
534536
all_tasks=all_tasks,

test/models/test_multitask.py

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,12 @@ def test_all_tasks_input(self) -> None:
445445
model = MultiTaskGP(
446446
train_X=train_X, train_Y=train_Y, task_feature=0, all_tasks=[0, 1, 2, 3]
447447
)
448-
self.assertEqual(model.num_tasks, 2)
448+
self.assertEqual(model.num_tasks, 4)
449449
# Check that PositiveIndexKernel knows of all tasks.
450-
self.assertEqual(model.covar_module.kernels[1].raw_covar_factor.shape[0], 2)
450+
self.assertEqual(model.covar_module.kernels[1].raw_covar_factor.shape[0], 4)
451+
# Check that observed and unobserved task indices are computed correctly.
452+
self.assertEqual(model._observed_task_indices.tolist(), [0, 1])
453+
self.assertEqual(model._unobserved_task_indices.tolist(), [2, 3])
451454

452455
def test_MultiTaskGP_construct_inputs(self) -> None:
453456
for dtype, fixed_noise, skip_task_features_in_datasets in zip(
@@ -540,13 +543,98 @@ def test_validatation_of_task_values(self) -> None:
540543
validate_task_values=True,
541544
)
542545

546+
# Task 2 is in all_tasks, so it should be valid even with validation enabled
547+
self.assertTrue(
548+
torch.equal(
549+
torch.tensor([1], **tkwargs),
550+
model._map_tasks(task_values=torch.tensor([2], **tkwargs)),
551+
)
552+
)
553+
554+
# Task 3 is NOT in all_tasks, so it should raise an error
543555
with self.assertRaisesRegex(
544556
ValueError,
545557
"Received invalid raw task values. Expected raw value to be in"
546-
r" \{0, 1\}, but got unexpected task"
547-
r" values: \{2\}.",
558+
r" \{0, 1, 2\}, but got unexpected task"
559+
r" values: \{3\}.",
548560
):
549-
model._map_tasks(task_values=torch.tensor([2], **tkwargs))
561+
model._map_tasks(task_values=torch.tensor([3], **tkwargs))
562+
563+
def test_multitask_gp_unobserved_tasks(self) -> None:
564+
"""Test MultiTaskGP with unobserved tasks.
565+
566+
This test verifies that:
567+
1. Creating a model with all_tasks including unobserved tasks works
568+
2. In train mode, unobserved task covar_factor is at random initialization
569+
3. In eval mode, unobserved task covar_factor is set to mean of observed
570+
4. Predictions work for the unobserved task
571+
"""
572+
tkwargs = {"device": self.device, "dtype": torch.double}
573+
574+
# Create data for tasks 0 and 2 only (task 1 is unobserved)
575+
_, (train_X, train_Y, _) = gen_multi_task_dataset(task_values=[0, 2], **tkwargs)
576+
577+
# Create model with all_tasks=[0, 1, 2] including unobserved task 1
578+
model = MultiTaskGP(
579+
train_X=train_X,
580+
train_Y=train_Y,
581+
task_feature=0,
582+
all_tasks=[0, 1, 2],
583+
)
584+
model.to(**tkwargs)
585+
586+
# Verify model.num_tasks == 3
587+
self.assertEqual(model.num_tasks, 3)
588+
589+
# Verify observed and unobserved task indices are correctly set
590+
self.assertEqual(model._observed_task_indices.tolist(), [0, 2])
591+
self.assertEqual(model._unobserved_task_indices.tolist(), [1])
592+
593+
# Get the task covariance module
594+
task_covar_module = model.covar_module.kernels[1]
595+
596+
# In train mode, get the covar_factor for unobserved task (index 1)
597+
model.train()
598+
train_covar_factor = task_covar_module.covar_factor.clone()
599+
unobserved_train_covar = train_covar_factor[1]
600+
observed_train_covar = train_covar_factor[[0, 2]]
601+
mean_observed_train = observed_train_covar.mean(dim=0)
602+
603+
# Unobserved task covar_factor should be at random init in train mode
604+
# (very unlikely to be exactly equal to mean of observed)
605+
self.assertFalse(
606+
torch.allclose(unobserved_train_covar, mean_observed_train, atol=1e-6)
607+
)
608+
609+
# Switch to eval mode
610+
model.eval()
611+
612+
# In eval mode, get the covar_factor for unobserved task
613+
eval_covar_factor = task_covar_module.covar_factor.clone()
614+
unobserved_eval_covar = eval_covar_factor[1]
615+
observed_eval_covar = eval_covar_factor[[0, 2]]
616+
mean_observed_eval = observed_eval_covar.mean(dim=0)
617+
618+
# Unobserved task covar_factor should equal mean of observed in eval mode
619+
self.assertTrue(
620+
torch.allclose(unobserved_eval_covar, mean_observed_eval, atol=1e-6)
621+
)
622+
623+
# Verify predictions work for the unobserved task
624+
# Create test input for unobserved task (task 1)
625+
test_X = torch.rand(3, 2, **tkwargs)
626+
test_X[:, 0] = 1.0 # Set task feature to 1 (unobserved task)
627+
628+
with torch.no_grad():
629+
posterior = model.posterior(X=test_X)
630+
631+
# Verify posterior has expected shape
632+
self.assertEqual(posterior.mean.shape, torch.Size([3, 1]))
633+
self.assertEqual(posterior.variance.shape, torch.Size([3, 1]))
634+
635+
# Verify we can sample from the posterior
636+
samples = posterior.rsample(sample_shape=torch.Size([2]))
637+
self.assertEqual(samples.shape, torch.Size([2, 3, 1]))
550638

551639

552640
class TestKroneckerMultiTaskGP(BotorchTestCase):

0 commit comments

Comments (0)