File tree Expand file tree Collapse file tree 2 files changed +45
-0
lines changed
Expand file tree Collapse file tree 2 files changed +45
-0
lines changed Original file line number Diff line number Diff line change 2121
2222
2323class TestQStatisticalDistanceActiveLearning (BotorchTestCase ):
24+ def test_invalid_distance_metric (self ):
25+ """Test that invalid distance_metric raises ValueError."""
26+ torch .manual_seed (1 )
27+ tkwargs = {"device" : self .device , "dtype" : torch .double }
28+ input_dim = 2
29+
30+ model = get_fully_bayesian_model (
31+ train_X = torch .rand (4 , input_dim , ** tkwargs ),
32+ train_Y = torch .rand (4 , 1 , ** tkwargs ),
33+ num_models = 3 ,
34+ ** tkwargs ,
35+ )
36+
37+ with self .assertRaises (ValueError ):
38+ qStatisticalDistanceActiveLearning (
39+ model = model ,
40+ distance_metric = "invalid_metric" ,
41+ )
42+
2443 def test_q_statistical_distance_active_learning (self ):
2544 torch .manual_seed (1 )
2645 tkwargs = {"device" : self .device }
Original file line number Diff line number Diff line change @@ -37,6 +37,32 @@ def initialize(self):
3737 self .n_context ,
3838 )
3939
40+ def test_default_hidden_dims (self ):
41+ """Test that default hidden dimensions are used when not provided."""
42+ x_dim = 2
43+ y_dim = 1
44+ r_dim = 8
45+ z_dim = 8
46+ n_context = 20
47+
48+ # Create model without specifying hidden dimensions (use defaults)
49+ model = NeuralProcessModel (
50+ train_X = torch .rand (100 , x_dim ),
51+ train_Y = torch .rand (100 , y_dim ),
52+ r_hidden_dims = None ,
53+ z_hidden_dims = None ,
54+ decoder_hidden_dims = None ,
55+ x_dim = x_dim ,
56+ y_dim = y_dim ,
57+ r_dim = r_dim ,
58+ z_dim = z_dim ,
59+ n_context = n_context ,
60+ )
61+
62+ # Test that the model works with default dimensions
63+ output = model (model .train_X , model .train_Y )
64+ self .assertEqual (output .loc .shape , (80 , y_dim ))
65+
4066 def test_r_encoder (self ):
4167 self .initialize ()
4268 input = torch .rand (100 , self .x_dim + self .y_dim )
You can’t perform that action at this time.
0 commit comments