Skip to content

Commit df95224

Browse files
Use dummy input variables during Scan rewrites
1 parent 6307c50 commit df95224

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

aesara/scan/rewriting.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
graph_inputs,
2424
io_toposort,
2525
is_in_ancestors,
26+
replace_nominals_with_dummies,
2627
)
2728
from aesara.graph.destroyhandler import DestroyHandler
2829
from aesara.graph.features import ReplaceValidate
@@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
8283
"""
8384
if not isinstance(node.op, Scan):
8485
return False
86+
8587
op = node.op
8688
op_info = op.info
8789
# We only need to take care of sequences and other arguments
@@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
9294
st += op_info.n_sit_sot
9395
st += op_info.n_shared_outs
9496

95-
op_ins = op.inner_inputs
96-
op_outs = op.inner_outputs
97+
op_ins, op_outs = replace_nominals_with_dummies(op.inner_inputs, op.inner_outputs)
9798

9899
# Corresponds to the initial states, which should stay untouched.
99100
# We put those variables aside, and put them back at the end.
@@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
189190
allow_gc=op.allow_gc,
190191
)
191192
nw_outs = nwScan(*nw_outer, return_list=True)
193+
192194
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
193195
else:
194196
return False
@@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node):
207209
if not isinstance(node.op, Scan):
208210
return False
209211

210-
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
212+
node_inputs, node_outputs = replace_nominals_with_dummies(
213+
node.op.inner_inputs, node.op.inner_outputs
214+
)
211215

212216
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
213217
local_fgraph_outs_set = set(node_outputs)
@@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node):
417421
if not isinstance(node.op, Scan):
418422
return False
419423

420-
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
424+
node_inputs, node_outputs = replace_nominals_with_dummies(
425+
node.op.inner_inputs, node.op.inner_outputs
426+
)
421427

422428
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
423429
local_fgraph_outs_set = set(node_outputs)
@@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node):
832838

833839
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
834840
# use
835-
args = ScanArgs(
836-
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
841+
inner_inputs, inner_outputs = replace_nominals_with_dummies(
842+
op.inner_inputs, op.inner_outputs
837843
)
844+
args = ScanArgs(node.inputs, node.outputs, inner_inputs, inner_outputs, op.info)
838845

839846
clients = {}
840847
local_fgraph_topo = io_toposort(
@@ -1694,6 +1701,8 @@ def merge(self, nodes):
16941701
inner_outs = [[] for nd in nodes]
16951702
outer_outs = []
16961703

1704+
# inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs)
1705+
16971706
def rename(ls, suffix):
16981707
for k in ls:
16991708
if k.name:
@@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node):
19671976
# Do a first pass to merge identical external inputs.
19681977
# Equivalent inputs will be stored in inp_equiv, then a new
19691978
# scan node is created without duplicates.
1979+
1980+
inner_inputs, inner_outputs = replace_nominals_with_dummies(
1981+
node.op.inner_inputs, node.op.inner_outputs
1982+
)
1983+
19701984
a = ScanArgs(
19711985
node.inputs,
19721986
node.outputs,
1973-
node.op.inner_inputs,
1974-
node.op.inner_outputs,
1987+
inner_inputs,
1988+
inner_outputs,
19751989
node.op.info,
19761990
)
19771991

@@ -2173,8 +2187,13 @@ def push_out_dot1_scan(fgraph, node):
21732187
# Note that this only works when you need just X[-1] in the end
21742188
# and assumes dimshuffles are applied to vectors before calling dot
21752189
op = node.op
2176-
sitsot_ins = op.inner_sitsot(op.inner_inputs)
2177-
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
2190+
2191+
inner_inputs, inner_outputs = replace_nominals_with_dummies(
2192+
op.inner_inputs, op.inner_outputs
2193+
)
2194+
2195+
sitsot_ins = op.inner_sitsot(inner_inputs)
2196+
sitsot_outs = op.inner_sitsot_outs(inner_outputs)
21782197
outer_sitsot = op.outer_sitsot_outs(node.outputs)
21792198
seqs = op.inner_seqs(op.inner_inputs)
21802199
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
@@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node):
22182237
# First let us split all arguments according to their
22192238
# corresponding categories
22202239

2221-
inner_seqs = op.inner_seqs(op.inner_inputs)
2240+
inner_seqs = op.inner_seqs(inner_inputs)
22222241
outer_seqs = op.outer_seqs(node.inputs)
2223-
inner_mitmot = op.inner_mitmot(op.inner_inputs)
2242+
inner_mitmot = op.inner_mitmot(inner_inputs)
22242243
outer_mitmot = op.outer_mitmot(node.inputs)
2225-
inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs)
2226-
inner_mitsot = op.inner_mitsot(op.inner_inputs)
2244+
inner_mitmot_outs = op.inner_mitmot_outs(inner_outputs)
2245+
inner_mitsot = op.inner_mitsot(inner_inputs)
22272246
outer_mitsot = op.outer_mitsot(node.inputs)
2228-
inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs)
2229-
inner_sitsot = op.inner_sitsot(op.inner_inputs)
2247+
inner_mitsot_outs = op.inner_mitsot_outs(inner_outputs)
2248+
inner_sitsot = op.inner_sitsot(inner_inputs)
22302249
outer_sitsot = op.outer_sitsot(node.inputs)
2231-
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
2250+
inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs)
22322251
outer_nitsot = op.outer_nitsot(node.inputs)
2233-
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
2234-
inner_shared = op.inner_shared(op.inner_inputs)
2252+
inner_nitsot_outs = op.inner_nitsot_outs(inner_outputs)
2253+
inner_shared = op.inner_shared(inner_inputs)
22352254
outer_shared = op.outer_shared(node.inputs)
2236-
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
2237-
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
2255+
inner_shared_outs = op.inner_shared_outs(inner_outputs)
2256+
inner_non_seqs = op.inner_non_seqs(inner_inputs)
22382257
outer_non_seqs = op.outer_non_seqs(node.inputs)
22392258

22402259
new_info = dataclasses.replace(

0 commit comments

Comments
 (0)