Skip to content

Commit 2027761

Browse files
crasandersfacebook-github-bot
authored andcommitted
add finished property to manual generators (#409)
Summary: Pull Request resolved: #409 This adds a finished property to the manual generator classes so that they can keep track of whether they have generated all their points. Once this is hooked into the strategy's finishing logic, it should make writing configs simpler. Reviewed By: JasonKChow Differential Revision: D64600239 fbshipit-source-id: f8d2c18f516592a09cd545e76b52a8878b22b66b
1 parent a3d3a33 commit 2027761

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

aepsych/generators/manual_generator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import warnings
9-
from typing import Optional, Union, Dict
9+
from typing import Dict, Optional, Union
1010

1111
import numpy as np
1212
import torch
13+
from torch.quasirandom import SobolEngine
14+
1315
from aepsych.config import Config
1416
from aepsych.generators.base import AEPsychGenerator
1517
from aepsych.models.base import AEPsychMixin
1618
from aepsych.utils import _process_bounds
17-
from torch.quasirandom import SobolEngine
1819

1920

2021
class ManualGenerator(AEPsychGenerator):
@@ -95,6 +96,10 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:
9596

9697
return options
9798

99+
@property
100+
def finished(self):
101+
return self._idx >= len(self.points)
102+
98103

99104
class SampleAroundPointsGenerator(ManualGenerator):
100105
"""Generator that samples in a window around reference points in a predefined list."""
@@ -131,9 +136,9 @@ def __init__(
131136
grid = self.engine.draw(samples_per_point)
132137
grid = p_lb + (p_ub - p_lb) * grid
133138
generated.append(grid)
134-
generated = torch.Tensor(np.vstack(generated)) #type: ignore
139+
generated = torch.Tensor(np.vstack(generated)) # type: ignore
135140

136-
super().__init__(lb, ub, generated, dim, shuffle, seed) #type: ignore
141+
super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore
137142

138143
@classmethod
139144
def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:

tests/generators/test_manual_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import numpy.testing as npt
12+
1213
from aepsych.config import Config
1314
from aepsych.generators import ManualGenerator, SampleAroundPointsGenerator
1415

@@ -50,6 +51,7 @@ def test_manual_generator(self):
5051
gen = ManualGenerator.from_config(config)
5152
npt.assert_equal(gen.lb.numpy(), np.array([0, 0]))
5253
npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
54+
self.assertFalse(gen.finished)
5355

5456
p1 = list(gen.gen()[0])
5557
p2 = list(gen.gen()[0])
@@ -60,6 +62,7 @@ def test_manual_generator(self):
6062
self.assertEqual(sorted([p1, p2, p3, p4]), points)
6163
self.assertEqual(gen.max_asks, len(points))
6264
self.assertEqual(gen.seed, 123)
65+
self.assertTrue(gen.finished)
6366

6467

6568
class TestSampleAroundPointsGenerator(unittest.TestCase):
@@ -86,13 +89,16 @@ def test_sample_around_points_generator(self):
8689
npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
8790
self.assertEqual(gen.max_asks, len(points * samples_per_point))
8891
self.assertEqual(gen.seed, 123)
92+
self.assertFalse(gen.finished)
8993

9094
points = gen.gen(gen.max_asks)
9195
for i in range(len(window)):
9296
npt.assert_array_less(points[:, i], points[:, i] + window[i])
9397
npt.assert_array_less(np.array([0] * len(points)), points[:, i])
9498
npt.assert_array_less(points[:, i], np.array([1] * len(points)))
9599

100+
self.assertTrue(gen.finished)
101+
96102

97103
if __name__ == "__main__":
98104
unittest.main()

0 commit comments

Comments
 (0)