@@ -95,7 +95,20 @@ def feasible_like(self, prototype):
9595 raise NotImplementedError
9696
9797
98- class _Boolean (Constraint ):
98+ class _SingletonConstraint (Constraint ):
99+ """
100+ A constraint type which has only one canonical instance, like constraints.real,
101+ and unlike constraints.interval.
102+ """
103+
104+ def __new__ (cls ):
105+ if (not hasattr (cls , "_instance" )) or (type (cls ._instance ) is not cls ):
106+ # Do not use the singleton instance of a superclass of cls.
107+ cls ._instance = super (_SingletonConstraint , cls ).__new__ (cls )
108+ return cls ._instance
109+
110+
111+ class _Boolean (_SingletonConstraint ):
99112 is_discrete = True
100113
101114 def __call__ (self , x ):
@@ -105,7 +118,7 @@ def feasible_like(self, prototype):
105118 return jax .numpy .zeros_like (prototype )
106119
107120
108- class _CorrCholesky (Constraint ):
121+ class _CorrCholesky (_SingletonConstraint ):
109122 event_dim = 2
110123
111124 def __call__ (self , x ):
@@ -126,7 +139,7 @@ def feasible_like(self, prototype):
126139 )
127140
128141
129- class _CorrMatrix (Constraint ):
142+ class _CorrMatrix (_SingletonConstraint ):
130143 event_dim = 2
131144
132145 def __call__ (self , x ):
@@ -231,6 +244,11 @@ def feasible_like(self, prototype):
231244 return jax .numpy .broadcast_to (self .lower_bound + 1 , jax .numpy .shape (prototype ))
232245
233246
247+ class _Positive (_GreaterThan , _SingletonConstraint ):
248+ def __init__ (self ):
249+ super ().__init__ (0.0 )
250+
251+
234252class _IndependentConstraint (Constraint ):
235253 """
236254 Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
@@ -284,6 +302,16 @@ def feasible_like(self, prototype):
284302 return self .base_constraint .feasible_like (prototype )
285303
286304
305+ class _RealVector (_IndependentConstraint , _SingletonConstraint ):
306+ def __init__ (self ):
307+ super ().__init__ (_Real (), 1 )
308+
309+
310+ class _RealMatrix (_IndependentConstraint , _SingletonConstraint ):
311+ def __init__ (self ):
312+ super ().__init__ (_Real (), 2 )
313+
314+
287315class _LessThan (Constraint ):
288316 def __init__ (self , upper_bound ):
289317 self .upper_bound = upper_bound
@@ -339,6 +367,16 @@ def feasible_like(self, prototype):
339367 return jax .numpy .broadcast_to (self .lower_bound , jax .numpy .shape (prototype ))
340368
341369
370+ class _IntegerPositive (_IntegerGreaterThan , _SingletonConstraint ):
371+ def __init__ (self ):
372+ super ().__init__ (1 )
373+
374+
375+ class _IntegerNonnegative (_IntegerGreaterThan , _SingletonConstraint ):
376+ def __init__ (self ):
377+ super ().__init__ (0 )
378+
379+
342380class _Interval (Constraint ):
343381 def __init__ (self , lower_bound , upper_bound ):
344382 self .lower_bound = lower_bound
@@ -367,6 +405,16 @@ def __eq__(self, other):
367405 )
368406
369407
408+ class _Circular (_Interval , _SingletonConstraint ):
409+ def __init__ (self ):
410+ super ().__init__ (- math .pi , math .pi )
411+
412+
413+ class _UnitInterval (_Interval , _SingletonConstraint ):
414+ def __init__ (self ):
415+ super ().__init__ (0.0 , 1.0 )
416+
417+
370418class _OpenInterval (_Interval ):
371419 def __call__ (self , x ):
372420 return (x > self .lower_bound ) & (x < self .upper_bound )
@@ -379,12 +427,7 @@ def __repr__(self):
379427 return fmt_string
380428
381429
382- class _Circular (_Interval ):
383- def __init__ (self ):
384- super ().__init__ (- math .pi , math .pi )
385-
386-
387- class _LowerCholesky (Constraint ):
430+ class _LowerCholesky (_SingletonConstraint ):
388431 event_dim = 2
389432
390433 def __call__ (self , x ):
@@ -420,7 +463,7 @@ def feasible_like(self, prototype):
420463 return jax .numpy .broadcast_to (value , prototype .shape )
421464
422465
423- class _L1Ball (Constraint ):
466+ class _L1Ball (_SingletonConstraint ):
424467 """
425468 Constrain to the L1 ball of any dimension.
426469 """
@@ -437,7 +480,7 @@ def feasible_like(self, prototype):
437480 return jax .numpy .zeros_like (prototype )
438481
439482
440- class _OrderedVector (Constraint ):
483+ class _OrderedVector (_SingletonConstraint ):
441484 event_dim = 1
442485
443486 def __call__ (self , x ):
@@ -449,7 +492,7 @@ def feasible_like(self, prototype):
449492 )
450493
451494
452- class _PositiveDefinite (Constraint ):
495+ class _PositiveDefinite (_SingletonConstraint ):
453496 event_dim = 2
454497
455498 def __call__ (self , x ):
@@ -466,7 +509,7 @@ def feasible_like(self, prototype):
466509 )
467510
468511
469- class _PositiveOrderedVector (Constraint ):
512+ class _PositiveOrderedVector (_SingletonConstraint ):
470513 """
471514 Constrains to a positive real-valued tensor where the elements are monotonically
472515 increasing along the `event_shape` dimension.
@@ -483,7 +526,7 @@ def feasible_like(self, prototype):
483526 )
484527
485528
486- class _Real (Constraint ):
529+ class _Real (_SingletonConstraint ):
487530 def __call__ (self , x ):
488531 # XXX: consider to relax this condition to [-inf, inf] interval
489532 return (x == x ) & (x != float ("inf" )) & (x != float ("-inf" ))
@@ -492,7 +535,7 @@ def feasible_like(self, prototype):
492535 return jax .numpy .zeros_like (prototype )
493536
494537
495- class _Simplex (Constraint ):
538+ class _Simplex (_SingletonConstraint ):
496539 event_dim = 1
497540
498541 def __call__ (self , x ):
@@ -503,7 +546,7 @@ def feasible_like(self, prototype):
503546 return jax .numpy .full_like (prototype , 1 / prototype .shape [- 1 ])
504547
505548
506- class _SoftplusPositive (_GreaterThan ):
549+ class _SoftplusPositive (_GreaterThan , _SingletonConstraint ):
507550 def __init__ (self ):
508551 super ().__init__ (lower_bound = 0.0 )
509552
@@ -522,7 +565,7 @@ class _ScaledUnitLowerCholesky(_LowerCholesky):
522565 pass
523566
524567
525- class _Sphere (Constraint ):
568+ class _Sphere (_SingletonConstraint ):
526569 """
527570 Constrain to the Euclidean sphere of any dimension.
528571 """
@@ -559,18 +602,18 @@ def feasible_like(self, prototype):
559602lower_cholesky = _LowerCholesky ()
560603scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky ()
561604multinomial = _Multinomial
562- nonnegative_integer = _IntegerGreaterThan ( 0 )
605+ nonnegative_integer = _IntegerNonnegative ( )
563606ordered_vector = _OrderedVector ()
564- positive = _GreaterThan ( 0.0 )
607+ positive = _Positive ( )
565608positive_definite = _PositiveDefinite ()
566- positive_integer = _IntegerGreaterThan ( 1 )
609+ positive_integer = _IntegerPositive ( )
567610positive_ordered_vector = _PositiveOrderedVector ()
568611real = _Real ()
569- real_vector = independent ( real , 1 )
570- real_matrix = independent ( real , 2 )
612+ real_vector = _RealVector ( )
613+ real_matrix = _RealMatrix ( )
571614simplex = _Simplex ()
572615softplus_lower_cholesky = _SoftplusLowerCholesky ()
573616softplus_positive = _SoftplusPositive ()
574617sphere = _Sphere ()
575- unit_interval = _Interval ( 0.0 , 1.0 )
618+ unit_interval = _UnitInterval ( )
576619open_interval = _OpenInterval
0 commit comments