Skip to content

Commit c46b0db

Browse files
authored
Adopt singleton semantics for globally defined constraint instances (#1507)
* pickle constraints.real by reference * test constraints.real pickling * enforce singleton behavior for all singleton classes * ensure pickling by reference for parametrized constraints * remove singleton check against non-singleton constraint * singleton subclasses for global parametrized constraints * revert changes of #1519 * address old review comments * separate out bijector optimizations * minor fixes * remove unused import
1 parent f7b7080 commit c46b0db

File tree

3 files changed

+154
-23
lines changed

3 files changed

+154
-23
lines changed

numpyro/distributions/constraints.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,20 @@ def feasible_like(self, prototype):
9595
raise NotImplementedError
9696

9797

98-
class _Boolean(Constraint):
98+
class _SingletonConstraint(Constraint):
99+
"""
100+
A constraint type which has only one canonical instance, like constraints.real,
101+
and unlike constraints.interval.
102+
"""
103+
104+
def __new__(cls):
105+
if (not hasattr(cls, "_instance")) or (type(cls._instance) is not cls):
106+
# Do not use the singleton instance of a superclass of cls.
107+
cls._instance = super(_SingletonConstraint, cls).__new__(cls)
108+
return cls._instance
109+
110+
111+
class _Boolean(_SingletonConstraint):
99112
is_discrete = True
100113

101114
def __call__(self, x):
@@ -105,7 +118,7 @@ def feasible_like(self, prototype):
105118
return jax.numpy.zeros_like(prototype)
106119

107120

108-
class _CorrCholesky(Constraint):
121+
class _CorrCholesky(_SingletonConstraint):
109122
event_dim = 2
110123

111124
def __call__(self, x):
@@ -126,7 +139,7 @@ def feasible_like(self, prototype):
126139
)
127140

128141

129-
class _CorrMatrix(Constraint):
142+
class _CorrMatrix(_SingletonConstraint):
130143
event_dim = 2
131144

132145
def __call__(self, x):
@@ -231,6 +244,11 @@ def feasible_like(self, prototype):
231244
return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype))
232245

233246

247+
class _Positive(_GreaterThan, _SingletonConstraint):
248+
def __init__(self):
249+
super().__init__(0.0)
250+
251+
234252
class _IndependentConstraint(Constraint):
235253
"""
236254
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
@@ -284,6 +302,16 @@ def feasible_like(self, prototype):
284302
return self.base_constraint.feasible_like(prototype)
285303

286304

305+
class _RealVector(_IndependentConstraint, _SingletonConstraint):
306+
def __init__(self):
307+
super().__init__(_Real(), 1)
308+
309+
310+
class _RealMatrix(_IndependentConstraint, _SingletonConstraint):
311+
def __init__(self):
312+
super().__init__(_Real(), 2)
313+
314+
287315
class _LessThan(Constraint):
288316
def __init__(self, upper_bound):
289317
self.upper_bound = upper_bound
@@ -339,6 +367,16 @@ def feasible_like(self, prototype):
339367
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))
340368

341369

370+
class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint):
371+
def __init__(self):
372+
super().__init__(1)
373+
374+
375+
class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint):
376+
def __init__(self):
377+
super().__init__(0)
378+
379+
342380
class _Interval(Constraint):
343381
def __init__(self, lower_bound, upper_bound):
344382
self.lower_bound = lower_bound
@@ -367,6 +405,16 @@ def __eq__(self, other):
367405
)
368406

369407

408+
class _Circular(_Interval, _SingletonConstraint):
409+
def __init__(self):
410+
super().__init__(-math.pi, math.pi)
411+
412+
413+
class _UnitInterval(_Interval, _SingletonConstraint):
414+
def __init__(self):
415+
super().__init__(0.0, 1.0)
416+
417+
370418
class _OpenInterval(_Interval):
371419
def __call__(self, x):
372420
return (x > self.lower_bound) & (x < self.upper_bound)
@@ -379,12 +427,7 @@ def __repr__(self):
379427
return fmt_string
380428

381429

382-
class _Circular(_Interval):
383-
def __init__(self):
384-
super().__init__(-math.pi, math.pi)
385-
386-
387-
class _LowerCholesky(Constraint):
430+
class _LowerCholesky(_SingletonConstraint):
388431
event_dim = 2
389432

390433
def __call__(self, x):
@@ -420,7 +463,7 @@ def feasible_like(self, prototype):
420463
return jax.numpy.broadcast_to(value, prototype.shape)
421464

422465

423-
class _L1Ball(Constraint):
466+
class _L1Ball(_SingletonConstraint):
424467
"""
425468
Constrain to the L1 ball of any dimension.
426469
"""
@@ -437,7 +480,7 @@ def feasible_like(self, prototype):
437480
return jax.numpy.zeros_like(prototype)
438481

439482

440-
class _OrderedVector(Constraint):
483+
class _OrderedVector(_SingletonConstraint):
441484
event_dim = 1
442485

443486
def __call__(self, x):
@@ -449,7 +492,7 @@ def feasible_like(self, prototype):
449492
)
450493

451494

452-
class _PositiveDefinite(Constraint):
495+
class _PositiveDefinite(_SingletonConstraint):
453496
event_dim = 2
454497

455498
def __call__(self, x):
@@ -466,7 +509,7 @@ def feasible_like(self, prototype):
466509
)
467510

468511

469-
class _PositiveOrderedVector(Constraint):
512+
class _PositiveOrderedVector(_SingletonConstraint):
470513
"""
471514
Constrains to a positive real-valued tensor where the elements are monotonically
472515
increasing along the `event_shape` dimension.
@@ -483,7 +526,7 @@ def feasible_like(self, prototype):
483526
)
484527

485528

486-
class _Real(Constraint):
529+
class _Real(_SingletonConstraint):
487530
def __call__(self, x):
488531
# XXX: consider to relax this condition to [-inf, inf] interval
489532
return (x == x) & (x != float("inf")) & (x != float("-inf"))
@@ -492,7 +535,7 @@ def feasible_like(self, prototype):
492535
return jax.numpy.zeros_like(prototype)
493536

494537

495-
class _Simplex(Constraint):
538+
class _Simplex(_SingletonConstraint):
496539
event_dim = 1
497540

498541
def __call__(self, x):
@@ -503,7 +546,7 @@ def feasible_like(self, prototype):
503546
return jax.numpy.full_like(prototype, 1 / prototype.shape[-1])
504547

505548

506-
class _SoftplusPositive(_GreaterThan):
549+
class _SoftplusPositive(_GreaterThan, _SingletonConstraint):
507550
def __init__(self):
508551
super().__init__(lower_bound=0.0)
509552

@@ -522,7 +565,7 @@ class _ScaledUnitLowerCholesky(_LowerCholesky):
522565
pass
523566

524567

525-
class _Sphere(Constraint):
568+
class _Sphere(_SingletonConstraint):
526569
"""
527570
Constrain to the Euclidean sphere of any dimension.
528571
"""
@@ -559,18 +602,18 @@ def feasible_like(self, prototype):
559602
lower_cholesky = _LowerCholesky()
560603
scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky()
561604
multinomial = _Multinomial
562-
nonnegative_integer = _IntegerGreaterThan(0)
605+
nonnegative_integer = _IntegerNonnegative()
563606
ordered_vector = _OrderedVector()
564-
positive = _GreaterThan(0.0)
607+
positive = _Positive()
565608
positive_definite = _PositiveDefinite()
566-
positive_integer = _IntegerGreaterThan(1)
609+
positive_integer = _IntegerPositive()
567610
positive_ordered_vector = _PositiveOrderedVector()
568611
real = _Real()
569-
real_vector = independent(real, 1)
570-
real_matrix = independent(real, 2)
612+
real_vector = _RealVector()
613+
real_matrix = _RealMatrix()
571614
simplex = _Simplex()
572615
softplus_lower_cholesky = _SoftplusLowerCholesky()
573616
softplus_positive = _SoftplusPositive()
574617
sphere = _Sphere()
575-
unit_interval = _Interval(0.0, 1.0)
618+
unit_interval = _UnitInterval()
576619
open_interval = _OpenInterval

numpyro/distributions/transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,11 @@ def _transform_to_corr_matrix(constraint):
10651065
)
10661066

10671067

1068+
@biject_to.register(type(constraints.positive))
1069+
def _transform_to_positive(constraint):
1070+
return ExpTransform()
1071+
1072+
10681073
@biject_to.register(constraints.greater_than)
10691074
def _transform_to_greater_than(constraint):
10701075
return ComposeTransform(
@@ -1085,13 +1090,21 @@ def _transform_to_less_than(constraint):
10851090
)
10861091

10871092

1093+
@biject_to.register(type(constraints.real_matrix))
1094+
@biject_to.register(type(constraints.real_vector))
10881095
@biject_to.register(constraints.independent)
10891096
def _biject_to_independent(constraint):
10901097
return IndependentTransform(
10911098
biject_to(constraint.base_constraint), constraint.reinterpreted_batch_ndims
10921099
)
10931100

10941101

1102+
@biject_to.register(type(constraints.unit_interval))
1103+
def _transform_to_unit_interval(constraint):
1104+
return SigmoidTransform()
1105+
1106+
1107+
@biject_to.register(type(constraints.circular))
10951108
@biject_to.register(constraints.open_interval)
10961109
@biject_to.register(constraints.interval)
10971110
def _transform_to_interval(constraint):

test/test_pickle.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,31 @@
1313

1414
import numpyro
1515
import numpyro.distributions as dist
16+
from numpyro.distributions.constraints import (
17+
boolean,
18+
circular,
19+
corr_cholesky,
20+
corr_matrix,
21+
greater_than,
22+
interval,
23+
l1_ball,
24+
lower_cholesky,
25+
nonnegative_integer,
26+
ordered_vector,
27+
positive,
28+
positive_definite,
29+
positive_integer,
30+
positive_ordered_vector,
31+
real,
32+
real_matrix,
33+
real_vector,
34+
scaled_unit_lower_cholesky,
35+
simplex,
36+
softplus_lower_cholesky,
37+
softplus_positive,
38+
sphere,
39+
unit_interval,
40+
)
1641
from numpyro.infer import (
1742
HMC,
1843
HMCECS,
@@ -93,3 +118,53 @@ def test_pickle_autoguide(guide_class):
93118
)
94119
samples = predictive(random.PRNGKey(1), None, 1)
95120
assert set(samples.keys()) == {"param", "x"}
121+
122+
123+
def test_pickle_singleton_constraint():
124+
# some numpyro constraint classes such as constraints._Real, are only accessible
125+
# through their public singleton instance, (such as constraint.real). This test
126+
# ensures that pickling and unpickling singleton instances does not re-create
127+
# additional instances, which is the default behavior of pickle, and which would
128+
# break singleton semantics.
129+
singleton_constraints = (
130+
boolean,
131+
circular,
132+
corr_cholesky,
133+
corr_matrix,
134+
l1_ball,
135+
lower_cholesky,
136+
nonnegative_integer,
137+
ordered_vector,
138+
positive,
139+
positive_definite,
140+
positive_integer,
141+
positive_ordered_vector,
142+
real,
143+
real_matrix,
144+
real_vector,
145+
scaled_unit_lower_cholesky,
146+
simplex,
147+
softplus_lower_cholesky,
148+
softplus_positive,
149+
sphere,
150+
unit_interval,
151+
)
152+
for cnstr in singleton_constraints:
153+
roundtripped_cnstr = pickle.loads(pickle.dumps(cnstr))
154+
# make sure that the unpickled constraint is the original singleton constraint
155+
assert roundtripped_cnstr is cnstr
156+
157+
# Test that it remains possible to pickle newly-created, non-singleton constraints.
158+
# because these constraints are neither singleton nor exposed as top-level variables
159+
# of the numpyro.distributions.constraints module, these objects are not pickled by
160+
# reference, but by value.
161+
int_cstr = interval(1.0, 2.0)
162+
roundtripped_int_cstr = pickle.loads(pickle.dumps(int_cstr))
163+
assert type(roundtripped_int_cstr) is type(int_cstr)
164+
assert int_cstr.lower_bound == roundtripped_int_cstr.lower_bound
165+
assert int_cstr.upper_bound == roundtripped_int_cstr.upper_bound
166+
167+
gt_cstr = greater_than(1.0)
168+
roundtripped_gt_cstr = pickle.loads(pickle.dumps(gt_cstr))
169+
assert type(roundtripped_gt_cstr) is type(gt_cstr)
170+
assert gt_cstr.lower_bound == roundtripped_gt_cstr.lower_bound

0 commit comments

Comments
 (0)