6262import numpy as np
6363
6464import jax .numpy
65+ import jax .numpy as jnp
66+ from jax .tree_util import register_pytree_node
6567
6668
6769class Constraint (object ):
@@ -75,6 +77,10 @@ class Constraint(object):
7577 is_discrete = False
7678 event_dim = 0
7779
80+ def __init_subclass__ (cls , ** kwargs ):
81+ super ().__init_subclass__ (** kwargs )
82+ register_pytree_node (cls , cls .tree_flatten , cls .tree_unflatten )
83+
7884 def __call__ (self , x ):
7985 raise NotImplementedError
8086
@@ -94,8 +100,24 @@ def feasible_like(self, prototype):
94100 """
95101 raise NotImplementedError
96102
103+ @classmethod
104+ def tree_unflatten (cls , aux_data , params ):
105+ params_keys , aux_data = aux_data
106+ self = cls .__new__ (cls )
107+ for k , v in zip (params_keys , params ):
108+ setattr (self , k , v )
109+
110+ for k , v in aux_data .items ():
111+ setattr (self , k , v )
112+ return self
113+
114+
115+ class ParameterFreeConstraint (Constraint ):
116+ def tree_flatten (self ):
117+ return (), ((), dict ())
118+
97119
98- class _SingletonConstraint (Constraint ):
120+ class _SingletonConstraint (ParameterFreeConstraint ):
99121 """
100122 A constraint type which has only one canonical instance, like constraints.real,
101123 and unlike constraints.interval.
@@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement
202224 event_dim = self ._event_dim
203225 return _Dependent (is_discrete = is_discrete , event_dim = event_dim )
204226
227+ def __eq__ (self , other ):
228+ return (
229+ type (self ) is type (other )
230+ and self ._is_discrete == other ._is_discrete
231+ and self ._event_dim == other ._event_dim
232+ )
233+
234+ def tree_flatten (self ):
235+ return (), (
236+ (),
237+ dict (_is_discrete = self ._is_discrete , _event_dim = self ._event_dim ),
238+ )
239+
205240
206241class dependent_property (property , _Dependent ):
242+ # XXX: this should not need to be pytree-able since it simply wraps a method
243+ # and thus is automatically present once the method's object is created
207244 def __init__ (
208245 self , fn = None , * , is_discrete = NotImplemented , event_dim = NotImplemented
209246 ):
@@ -243,8 +280,16 @@ def __repr__(self):
243280 def feasible_like (self , prototype ):
244281 return jax .numpy .broadcast_to (self .lower_bound + 1 , jax .numpy .shape (prototype ))
245282
283+ def tree_flatten (self ):
284+ return (self .lower_bound ,), (("lower_bound" ,), dict ())
285+
286+ def __eq__ (self , other ):
287+ if not isinstance (other , _GreaterThan ):
288+ return False
289+ return jnp .array_equal (self .lower_bound , other .lower_bound )
246290
247- class _Positive (_GreaterThan , _SingletonConstraint ):
291+
292+ class _Positive (_SingletonConstraint , _GreaterThan ):
248293 def __init__ (self ):
249294 super ().__init__ (0.0 )
250295
@@ -301,6 +346,20 @@ def __repr__(self):
301346 def feasible_like (self , prototype ):
302347 return self .base_constraint .feasible_like (prototype )
303348
349+ def tree_flatten (self ):
350+ return (self .base_constraint ,), (
351+ ("base_constraint" ,),
352+ {"reinterpreted_batch_ndims" : self .reinterpreted_batch_ndims },
353+ )
354+
355+ def __eq__ (self , other ):
356+ if not isinstance (other , _IndependentConstraint ):
357+ return False
358+
359+ return (self .base_constraint == other .base_constraint ) & (
360+ self .reinterpreted_batch_ndims == other .reinterpreted_batch_ndims
361+ )
362+
304363
305364class _RealVector (_IndependentConstraint , _SingletonConstraint ):
306365 def __init__ (self ):
@@ -327,6 +386,14 @@ def __repr__(self):
327386 def feasible_like (self , prototype ):
328387 return jax .numpy .broadcast_to (self .upper_bound - 1 , jax .numpy .shape (prototype ))
329388
389+ def tree_flatten (self ):
390+ return (self .upper_bound ,), (("upper_bound" ,), dict ())
391+
392+ def __eq__ (self , other ):
393+ if not isinstance (other , _LessThan ):
394+ return False
395+ return jnp .array_equal (self .upper_bound , other .upper_bound )
396+
330397
331398class _IntegerInterval (Constraint ):
332399 is_discrete = True
@@ -348,6 +415,20 @@ def __repr__(self):
348415 def feasible_like (self , prototype ):
349416 return jax .numpy .broadcast_to (self .lower_bound , jax .numpy .shape (prototype ))
350417
418+ def tree_flatten (self ):
419+ return (self .lower_bound , self .upper_bound ), (
420+ ("lower_bound" , "upper_bound" ),
421+ dict (),
422+ )
423+
424+ def __eq__ (self , other ):
425+ if not isinstance (other , _IntegerInterval ):
426+ return False
427+
428+ return jnp .array_equal (self .lower_bound , other .lower_bound ) & jnp .array_equal (
429+ self .upper_bound , other .upper_bound
430+ )
431+
351432
352433class _IntegerGreaterThan (Constraint ):
353434 is_discrete = True
@@ -366,13 +447,21 @@ def __repr__(self):
366447 def feasible_like (self , prototype ):
367448 return jax .numpy .broadcast_to (self .lower_bound , jax .numpy .shape (prototype ))
368449
450+ def tree_flatten (self ):
451+ return (self .lower_bound ,), (("lower_bound" ,), dict ())
369452
370- class _IntegerPositive (_IntegerGreaterThan , _SingletonConstraint ):
453+ def __eq__ (self , other ):
454+ if not isinstance (other , _IntegerGreaterThan ):
455+ return False
456+ return jnp .array_equal (self .lower_bound , other .lower_bound )
457+
458+
459+ class _IntegerPositive (_SingletonConstraint , _IntegerGreaterThan ):
371460 def __init__ (self ):
372461 super ().__init__ (1 )
373462
374463
375- class _IntegerNonnegative (_IntegerGreaterThan , _SingletonConstraint ):
464+ class _IntegerNonnegative (_SingletonConstraint , _IntegerGreaterThan ):
376465 def __init__ (self ):
377466 super ().__init__ (0 )
378467
@@ -398,19 +487,25 @@ def feasible_like(self, prototype):
398487 )
399488
400489 def __eq__ (self , other ):
401- return (
402- isinstance ( other , _Interval )
403- and self .lower_bound == other .lower_bound
404- and self .upper_bound == other .upper_bound
490+ if not isinstance ( other , _Interval ):
491+ return False
492+ return jnp . array_equal ( self .lower_bound , other .lower_bound ) & jnp . array_equal (
493+ self .upper_bound , other .upper_bound
405494 )
406495
496+ def tree_flatten (self ):
497+ return (self .lower_bound , self .upper_bound ), (
498+ ("lower_bound" , "upper_bound" ),
499+ dict (),
500+ )
407501
408- class _Circular (_Interval , _SingletonConstraint ):
502+
503+ class _Circular (_SingletonConstraint , _Interval ):
409504 def __init__ (self ):
410505 super ().__init__ (- math .pi , math .pi )
411506
412507
413- class _UnitInterval (_Interval , _SingletonConstraint ):
508+ class _UnitInterval (_SingletonConstraint , _Interval ):
414509 def __init__ (self ):
415510 super ().__init__ (0.0 , 1.0 )
416511
@@ -462,6 +557,14 @@ def feasible_like(self, prototype):
462557 value = jax .numpy .pad (jax .numpy .expand_dims (self .upper_bound , - 1 ), pad_width )
463558 return jax .numpy .broadcast_to (value , prototype .shape )
464559
560+ def tree_flatten (self ):
561+ return (self .upper_bound ,), (("upper_bound" ,), dict ())
562+
563+ def __eq__ (self , other ):
564+ if not isinstance (other , _Multinomial ):
565+ return False
566+ return jnp .array_equal (self .upper_bound , other .upper_bound )
567+
465568
466569class _L1Ball (_SingletonConstraint ):
467570 """
@@ -546,7 +649,7 @@ def feasible_like(self, prototype):
546649 return jax .numpy .full_like (prototype , 1 / prototype .shape [- 1 ])
547650
548651
549- class _SoftplusPositive (_GreaterThan , _SingletonConstraint ):
652+ class _SoftplusPositive (_SingletonConstraint , _GreaterThan ):
550653 def __init__ (self ):
551654 super ().__init__ (lower_bound = 0.0 )
552655
0 commit comments