43
43
44
44
from pytensor import tensor as pt
45
45
from pytensor .graph import Apply , Op , node_rewriter
46
- from pytensor .graph .basic import Constant , Variable , clone_get_equiv , graph_inputs , walk
46
+ from pytensor .graph .basic import Constant , Variable , ancestors , clone_get_equiv , graph_inputs , walk
47
47
from pytensor .graph .fg import FunctionGraph
48
48
from pytensor .graph .op import HasInnerGraph
49
49
from pytensor .link .c .type import CType
50
50
from pytensor .raise_op import CheckAndRaise
51
- from pytensor .scalar .basic import Mul
51
+ from pytensor .scalar .basic import GE , LE , Exp , Mul
52
52
from pytensor .tensor .basic import get_underlying_scalar_constant_value
53
- from pytensor .tensor .elemwise import Elemwise
53
+ from pytensor .tensor .elemwise import DimShuffle , Elemwise
54
54
from pytensor .tensor .exceptions import NotScalarConstantError
55
55
from pytensor .tensor .random .op import RandomVariable
56
56
from pytensor .tensor .variable import TensorVariable
@@ -228,6 +228,55 @@ def local_remove_check_parameter(fgraph, node):
228
228
return [node .inputs [0 ]]
229
229
230
230
231
@node_rewriter(tracks=[pt.switch])
def local_remove_useless_bound_switch(fgraph, node):
    """Remove bound checks ensured by the transformations.

    switch(exp(x) >= 0, cond1, -inf) -> cond1 if exp(x) in cond1.

    The reason we don't set it to simply True is that x could be `nan`.
    If we see exp(x) exists in cond1 we assume `nan` will be propagated anyway.

    This isn't guaranteed to be True, for instance if exp(x) is inside another switch statement.
    """
    switch_cond, on_true, on_false = node.inputs

    comparison = switch_cond.owner
    if comparison is None or not isinstance(comparison.op, Elemwise):
        return None

    # Orient the comparison so `bounded` is the side being tested against zero:
    # `0 <= bounded` and `bounded >= 0` are the same check.
    compare_op = comparison.op.scalar_op
    if isinstance(compare_op, LE):
        zero_side, bounded = comparison.inputs
    elif isinstance(compare_op, GE):
        bounded, zero_side = comparison.inputs
    else:
        return None

    # Only fire on the exact `... >= 0 else -inf` pattern.
    if not (isinstance(zero_side, Constant) and zero_side.unique_value == 0):
        return None
    if not (isinstance(on_false, Constant) and on_false.unique_value == -np.inf):
        return None

    def _looks_like_exp(v):
        # True when v is exp(x), possibly wrapped in a single DimShuffle.
        apply_node = v.owner
        if apply_node is None:
            return False
        if isinstance(apply_node.op, Elemwise) and isinstance(apply_node.op.scalar_op, Exp):
            return True
        if isinstance(apply_node.op, DimShuffle):
            inner = apply_node.inputs[0].owner
            return (
                inner is not None
                and isinstance(inner.op, Elemwise)
                and isinstance(inner.op.scalar_op, Exp)
            )
        return False

    # Check if the bounded variable is exp(x) and feeds into the true branch,
    # so any `nan` would propagate there regardless of this switch.
    if _looks_like_exp(bounded) and bounded in ancestors([on_true]):
        return [on_true]

    return None
279
+
231
280
@node_rewriter (tracks = [CheckParameterValue ])
232
281
def local_check_parameter_to_ninf_switch (fgraph , node ):
233
282
if not node .op .can_be_replaced_by_ninf :
@@ -248,17 +297,23 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
248
297
249
298
250
299
# Register the rewrites in the canonicalization database under their own
# function names (not the db name) so they can be targeted individually.
for _rewrite in (
    local_remove_check_parameter,
    local_check_parameter_to_ninf_switch,
    local_remove_useless_bound_switch,
):
    pytensor.compile.optdb["canonicalize"].register(
        _rewrite.__name__,
        _rewrite,
        use_db_name_as_tag=False,
    )
262
317
263
318
class DiracDelta (MeasurableOp , Op ):
264
319
"""An `Op` that represents a Dirac-delta distribution."""
0 commit comments