Skip to content

Commit 2fb8d88

Browse files
committed
Remove need for explicit left expand_dims in inputs of Elemwise
1 parent 55b00d1 commit 2fb8d88

28 files changed

+1068
-848
lines changed

README.rst

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,36 +57,35 @@ Getting started
5757
d = a/a + (M + a).dot(v)
5858
5959
pytensor.dprint(d)
60-
# Add [id A]
61-
# ├─ ExpandDims{axis=0} [id B]
62-
# └─ True_div [id C]
63-
# ├─ a [id D]
64-
# └─ a [id D]
65-
# └─ dot [id E]
66-
# ├─ Add [id F]
67-
# │ ├─ M [id G]
68-
# │ └─ ExpandDims{axes=[0, 1]} [id H]
69-
# └─ a [id D]
70-
# └─ v [id I]
60+
# Add [id A]
61+
# ├─ True_div [id B]
62+
# ├─ a [id C]
63+
# └─ a [id C]
64+
# └─ Squeeze{axis=1} [id D]
65+
# └─ Dot [id E]
66+
# ├─ Add [id F]
67+
# │ ├─ M [id G]
68+
# │ └─ a [id C]
69+
# └─ ExpandDims{axis=1} [id H]
70+
# └─ v [id I]
7171
7272
f_d = pytensor.function([a, v, M], d)
7373
7474
# `a/a` -> `1` and the dot product is replaced with a BLAS function
7575
# (i.e. CGemv)
7676
pytensor.dprint(f_d)
77-
# Add [id A] 5
78-
# ├─ [1.] [id B]
79-
# └─ CGemv{inplace} [id C] 4
80-
# ├─ AllocEmpty{dtype='float64'} [id D] 3
81-
# │ └─ Shape_i{0} [id E] 2
77+
# Add [id A] 4
78+
# ├─ 1.0 [id B]
79+
# └─ CGemv{inplace} [id C] 3
80+
# ├─ AllocEmpty{dtype='float64'} [id D] 2
81+
# │ └─ Shape_i{0} [id E] 1
8282
# │ └─ M [id F]
83-
# ├─ 1.0 [id G]
84-
# ├─ Add [id H] 1
83+
# ├─ 1.0 [id B]
84+
# ├─ Add [id G] 0
8585
# │ ├─ M [id F]
86-
# │ └─ ExpandDims{axes=[0, 1]} [id I] 0
87-
# │ └─ a [id J]
88-
# ├─ v [id K]
89-
# └─ 0.0 [id L]
86+
# │ └─ a [id H]
87+
# ├─ v [id I]
88+
# └─ 0.0 [id J]
9089
9190
See `the PyTensor documentation <https://pytensor.readthedocs.io/en/latest/>`__ for in-depth tutorials.
9291

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,9 @@ def impl(*inputs):
383383
type(op),
384384
tuple(op.inplace_pattern.items()),
385385
input_bc_patterns,
386+
output_bc_patterns,
386387
scalar_cache_key,
388+
2, # cache version
387389
)
388390
)
389391
elemwise_key = sha256(elemwise_key.encode()).hexdigest()

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _vectorized(
125125
raise TypeError("allow_core_scalar must be literal.")
126126
allow_core_scalar = allow_core_scalar.literal_value
127127

128-
batch_ndim = len(input_bc_patterns[0])
128+
batch_ndim = len(output_bc_patterns[0])
129129
nin = len(constant_inputs_types) + len(input_types)
130130
nout = len(output_bc_patterns)
131131

@@ -138,13 +138,6 @@ def _vectorized(
138138
if not all(isinstance(input, types.Array) for input in input_types):
139139
raise TypingError("Vectorized inputs must be arrays.")
140140

141-
if not all(
142-
len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
143-
):
144-
raise TypingError(
145-
"Vectorized broadcastable patterns must have the same length."
146-
)
147-
148141
core_input_types = []
149142
for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
150143
core_ndim = input_type.ndim - len(bc_pattern)
@@ -291,16 +284,21 @@ def compute_itershape(
291284
size: list[ir.Instruction] | None,
292285
):
293286
one = ir.IntType(64)(1)
294-
batch_ndim = len(broadcast_pattern[0])
287+
batch_ndim = max((len(p) for p in broadcast_pattern), default=0)
295288
shape = [None] * batch_ndim
296289
if size is not None:
297290
shape = size
298291
for i in range(batch_ndim):
299292
for j, (bc, in_shape) in enumerate(
300293
zip(broadcast_pattern, in_shapes, strict=True)
301294
):
302-
length = in_shape[i]
303-
if bc[i]:
295+
# Offset for inputs with fewer dims than batch_ndim
296+
offset = batch_ndim - len(bc)
297+
if i < offset:
298+
# Implicit broadcast dim — no array dim to check
299+
continue
300+
length = in_shape[i - offset]
301+
if bc[i - offset]:
304302
with builder.if_then(
305303
builder.icmp_unsigned("!=", length, one), likely=False
306304
):
@@ -336,8 +334,11 @@ def compute_itershape(
336334
for j, (bc, in_shape) in enumerate(
337335
zip(broadcast_pattern, in_shapes, strict=True)
338336
):
339-
length = in_shape[i]
340-
if bc[i]:
337+
offset = batch_ndim - len(bc)
338+
if i < offset:
339+
continue
340+
length = in_shape[i - offset]
341+
if bc[i - offset]:
341342
with builder.if_then(
342343
builder.icmp_unsigned("!=", length, one), likely=False
343344
):
@@ -452,6 +453,7 @@ def make_loop_call(
452453
# output_scope_set = mod.add_metadata([input_scope, output_scope])
453454

454455
zero = ir.Constant(ir.IntType(64), 0)
456+
batch_ndim = len(iter_shape)
455457

456458
# Setup loops and initialize accumulators for outputs
457459
# This part corresponds to opening the loops
@@ -480,9 +482,12 @@ def make_loop_call(
480482
for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True):
481483
core_ndim = input_type.ndim - len(bc)
482484

483-
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
484-
zero
485-
] * core_ndim
485+
# For inputs with fewer batch dims than the loop, skip leading loop indices
486+
offset = batch_ndim - len(bc)
487+
idxs_bc = [
488+
zero if bc_dim else idx
489+
for idx, bc_dim in zip(idxs[offset:], bc, strict=True)
490+
] + [zero] * core_ndim
486491
ptr = cgutils.get_item_pointer2(
487492
context,
488493
builder,

pytensor/scan/rewriting.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
Alloc,
5757
AllocEmpty,
5858
atleast_Nd,
59+
expand_dims,
5960
get_scalar_constant_value,
6061
)
6162
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -504,6 +505,21 @@ def add_to_replace(y):
504505

505506
to_remove_set.add(nd)
506507

508+
# When inner Elemwise inputs have different ndims, lower-ndim
509+
# inputs are implicitly left-padded. Outer equivalents of inputs
510+
# with a time dimension need broadcast dims inserted right after
511+
# the time dim (position 0) to match that implicit padding.
512+
# E.g., inner (v,) broadcasting with (a, v) → outer (t, v)
513+
# must become (t, 1, v) so it broadcasts with (a, v) to (t, a, v).
514+
inner_max_ndim = max(x.type.ndim for x in nd.inputs)
515+
for i, x in enumerate(nd.inputs):
516+
has_time = x in inner_seqs_set or x in to_replace_set
517+
n_pad = inner_max_ndim - x.type.ndim
518+
if has_time and n_pad > 0:
519+
outside_ins[i] = expand_dims(
520+
outside_ins[i], axis=tuple(range(1, 1 + n_pad))
521+
)
522+
507523
# Do not call make_node for test_value
508524
nw_outer_node = nd.op.make_node(*outside_ins)
509525

pytensor/sparse/rewriting.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor.sparse.basic import csm_properties
1717
from pytensor.sparse.math import usmm
1818
from pytensor.tensor import blas
19-
from pytensor.tensor.basic import as_tensor_variable, cast
19+
from pytensor.tensor.basic import as_tensor_variable, atleast_Nd, cast
2020
from pytensor.tensor.math import mul, neg, sub
2121
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
2222
from pytensor.tensor.shape import shape, specify_shape
@@ -957,6 +957,9 @@ def local_usmm_csx(fgraph, node):
957957
if y.type.dtype != dtype_out:
958958
return False
959959

960+
# UsmmCscDense requires alpha to be 2-d with shape (1, 1)
961+
if alpha.ndim < 2:
962+
alpha = atleast_Nd(alpha, n=2)
960963
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)]
961964
return False
962965

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def _get_underlying_scalar_constant_value(
379379
ret = [[None]]
380380
v.owner.op.perform(v.owner, const, ret)
381381
return np.asarray(ret[0][0].copy())
382-
# In fast_compile, we don't enable local_fill_to_alloc, so
382+
# In fast_compile, we don't enable local_second_to_alloc, so
383383
# we need to investigate Second as Alloc. So elemwise
384384
# don't disable the check for Second.
385385
elif isinstance(op, Elemwise):

0 commit comments

Comments
 (0)