Skip to content

Commit fc5a4f3

Browse files
committed
Remove useless switch on log transformed parameters
1 parent 49382ba commit fc5a4f3

File tree

3 files changed

+75
-9
lines changed

3 files changed

+75
-9
lines changed

Diff for: pymc/logprob/utils.py

+60-5
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@
4343

4444
from pytensor import tensor as pt
4545
from pytensor.graph import Apply, Op, node_rewriter
46-
from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk
46+
from pytensor.graph.basic import Constant, Variable, ancestors, clone_get_equiv, graph_inputs, walk
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.graph.op import HasInnerGraph
4949
from pytensor.link.c.type import CType
5050
from pytensor.raise_op import CheckAndRaise
51-
from pytensor.scalar.basic import Mul
51+
from pytensor.scalar.basic import GE, LE, Exp, Mul
5252
from pytensor.tensor.basic import get_underlying_scalar_constant_value
53-
from pytensor.tensor.elemwise import Elemwise
53+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
5454
from pytensor.tensor.exceptions import NotScalarConstantError
5555
from pytensor.tensor.random.op import RandomVariable
5656
from pytensor.tensor.variable import TensorVariable
@@ -228,6 +228,55 @@ def local_remove_check_parameter(fgraph, node):
228228
return [node.inputs[0]]
229229

230230

231+
@node_rewriter(tracks=[pt.switch])
232+
def local_remove_useless_bound_switch(fgraph, node):
233+
"""Remove bound checks ensured by the transformations.
234+
235+
switch(exp(x) >= 0, cond1, -inf) -> cond1 if exp(x) in cond1.
236+
237+
The reason we don't simply set it to True is that x could be `nan`.
238+
If we see exp(x) exists in cond1 we assume `nan` will be propagated anyway.
239+
240+
This isn't guaranteed to hold, for instance if exp(x) is inside another switch statement.
241+
"""
242+
cond, true_branch, false_branch = node.inputs
243+
if not (cond.owner is not None and isinstance(cond.owner.op, Elemwise)):
244+
return
245+
scalar_op = cond.owner.op.scalar_op
246+
if isinstance(scalar_op, LE):
247+
maybe_zero, var = cond.owner.inputs
248+
elif isinstance(scalar_op, GE):
249+
var, maybe_zero = cond.owner.inputs
250+
else:
251+
return None
252+
253+
if not (
254+
(isinstance(maybe_zero, Constant) and maybe_zero.unique_value == 0)
255+
and (isinstance(false_branch, Constant) and false_branch.unique_value == -np.inf)
256+
):
257+
return None
258+
259+
# Check if var is exp(x) and x is present in the true branch
260+
if (
261+
var.owner is not None
262+
and (
263+
(isinstance(var.owner.op, Elemwise) and isinstance(var.owner.op.scalar_op, Exp))
264+
or (
265+
isinstance(var.owner.op, DimShuffle)
266+
and (
267+
var.owner.inputs[0].owner is not None
268+
and isinstance(var.owner.inputs[0].owner.op, Elemwise)
269+
and isinstance(var.owner.inputs[0].owner.op.scalar_op, Exp)
270+
)
271+
)
272+
)
273+
and var in ancestors([true_branch])
274+
):
275+
return [true_branch]
276+
277+
return None
278+
279+
231280
@node_rewriter(tracks=[CheckParameterValue])
232281
def local_check_parameter_to_ninf_switch(fgraph, node):
233282
if not node.op.can_be_replaced_by_ninf:
@@ -248,17 +297,23 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
248297

249298

250299
pytensor.compile.optdb["canonicalize"].register(
251-
"local_remove_check_parameter",
300+
local_remove_check_parameter.__name__,
252301
local_remove_check_parameter,
253302
use_db_name_as_tag=False,
254303
)
255304

256305
pytensor.compile.optdb["canonicalize"].register(
257-
"local_check_parameter_to_ninf_switch",
306+
local_check_parameter_to_ninf_switch.__name__,
258307
local_check_parameter_to_ninf_switch,
259308
use_db_name_as_tag=False,
260309
)
261310

311+
pytensor.compile.optdb["canonicalize"].register(
312+
local_remove_useless_bound_switch.__name__,
313+
local_remove_useless_bound_switch,
314+
use_db_name_as_tag=False,
315+
)
316+
262317

263318
class DiracDelta(MeasurableOp, Op):
264319
"""An `Op` that represents a Dirac-delta distribution."""

Diff for: pymc/pytensorf.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -937,12 +937,16 @@ def compile(
937937
check_bounds = model.check_bounds
938938
except TypeError:
939939
check_bounds = True
940-
check_parameter_opt = (
941-
"local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
942-
)
940+
if check_bounds:
941+
check_parameter_opt = ("local_check_parameter_to_ninf_switch",)
942+
else:
943+
check_parameter_opt = (
944+
"local_remove_check_parameter",
945+
"local_remove_useless_bound_switch",
946+
)
943947

944948
mode = get_mode(mode)
945-
opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
949+
opt_qry = mode.provided_optimizer.including("random_make_inplace", *check_parameter_opt)
946950
mode = Mode(linker=mode.linker, optimizer=opt_qry)
947951
pytensor_function = pytensor.function(
948952
inputs,

Diff for: tests/test_pytensorf.py

+7
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ def test_check_parameters_can_be_replaced_by_ninf(self):
354354
with pytest.raises(ParameterValueError, match="test"):
355355
fn([-1, 2, 3])
356356

357+
def test_useless_bound_switch(self):
358+
# Without default transform the switch is never removed
359+
# Even if check_bounds = False
360+
with pm.Model(check_bounds=False) as m:
361+
x = pm.HalfNormal("x", default_transform=None)
362+
assert m.compile_logp()({"x": -1}) == -np.inf
363+
357364
def test_compile_pymc_sets_rng_updates(self):
358365
rng = pytensor.shared(np.random.default_rng(0))
359366
x = pm.Normal.dist(rng=rng)

0 commit comments

Comments
 (0)