23
23
graph_inputs ,
24
24
io_toposort ,
25
25
is_in_ancestors ,
26
+ replace_nominals_with_dummies ,
26
27
)
27
28
from aesara .graph .destroyhandler import DestroyHandler
28
29
from aesara .graph .features import ReplaceValidate
@@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
82
83
"""
83
84
if not isinstance (node .op , Scan ):
84
85
return False
86
+
85
87
op = node .op
86
88
op_info = op .info
87
89
# We only need to take care of sequences and other arguments
@@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
92
94
st += op_info .n_sit_sot
93
95
st += op_info .n_shared_outs
94
96
95
- op_ins = op .inner_inputs
96
- op_outs = op .inner_outputs
97
+ op_ins , op_outs = replace_nominals_with_dummies (op .inner_inputs , op .inner_outputs )
97
98
98
99
# Corresponds to the initial states, which should stay untouched.
99
100
# We put those variables aside, and put them back at the end.
@@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
189
190
allow_gc = op .allow_gc ,
190
191
)
191
192
nw_outs = nwScan (* nw_outer , return_list = True )
193
+
192
194
return dict ([("remove" , [node ])] + list (zip (node .outputs , nw_outs )))
193
195
else :
194
196
return False
@@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node):
207
209
if not isinstance (node .op , Scan ):
208
210
return False
209
211
210
- node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
212
+ node_inputs , node_outputs = replace_nominals_with_dummies (
213
+ node .op .inner_inputs , node .op .inner_outputs
214
+ )
211
215
212
216
local_fgraph_topo = io_toposort (node_inputs , node_outputs )
213
217
local_fgraph_outs_set = set (node_outputs )
@@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node):
417
421
if not isinstance (node .op , Scan ):
418
422
return False
419
423
420
- node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
424
+ node_inputs , node_outputs = replace_nominals_with_dummies (
425
+ node .op .inner_inputs , node .op .inner_outputs
426
+ )
421
427
422
428
local_fgraph_topo = io_toposort (node_inputs , node_outputs )
423
429
local_fgraph_outs_set = set (node_outputs )
@@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node):
832
838
833
839
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
834
840
# use
835
- args = ScanArgs (
836
- node . inputs , node . outputs , op .inner_inputs , op .inner_outputs , op . info
841
+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
842
+ op .inner_inputs , op .inner_outputs
837
843
)
844
+ args = ScanArgs (node .inputs , node .outputs , inner_inputs , inner_outputs , op .info )
838
845
839
846
clients = {}
840
847
local_fgraph_topo = io_toposort (
@@ -1694,6 +1701,8 @@ def merge(self, nodes):
1694
1701
inner_outs = [[] for nd in nodes ]
1695
1702
outer_outs = []
1696
1703
1704
+ # inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs)
1705
+
1697
1706
def rename (ls , suffix ):
1698
1707
for k in ls :
1699
1708
if k .name :
@@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node):
1967
1976
# Do a first pass to merge identical external inputs.
1968
1977
# Equivalent inputs will be stored in inp_equiv, then a new
1969
1978
# scan node created without duplicates.
1979
+
1980
+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
1981
+ node .op .inner_inputs , node .op .inner_outputs
1982
+ )
1983
+
1970
1984
a = ScanArgs (
1971
1985
node .inputs ,
1972
1986
node .outputs ,
1973
- node . op . inner_inputs ,
1974
- node . op . inner_outputs ,
1987
+ inner_inputs ,
1988
+ inner_outputs ,
1975
1989
node .op .info ,
1976
1990
)
1977
1991
@@ -2173,8 +2187,13 @@ def push_out_dot1_scan(fgraph, node):
2173
2187
# Note that this works when only you need X[-1] in the end
2174
2188
# and assumes dimshuffle are applied to vectors before calling dot
2175
2189
op = node .op
2176
- sitsot_ins = op .inner_sitsot (op .inner_inputs )
2177
- sitsot_outs = op .inner_sitsot_outs (op .inner_outputs )
2190
+
2191
+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
2192
+ op .inner_inputs , op .inner_outputs
2193
+ )
2194
+
2195
+ sitsot_ins = op .inner_sitsot (inner_inputs )
2196
+ sitsot_outs = op .inner_sitsot_outs (inner_outputs )
2178
2197
outer_sitsot = op .outer_sitsot_outs (node .outputs )
2179
2198
seqs = op .inner_seqs (op .inner_inputs )
2180
2199
for inp , out , outer_out in zip (sitsot_ins , sitsot_outs , outer_sitsot ):
@@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node):
2218
2237
# First let us split all arguments according to their
2219
2238
# corresponding categories
2220
2239
2221
- inner_seqs = op .inner_seqs (op . inner_inputs )
2240
+ inner_seqs = op .inner_seqs (inner_inputs )
2222
2241
outer_seqs = op .outer_seqs (node .inputs )
2223
- inner_mitmot = op .inner_mitmot (op . inner_inputs )
2242
+ inner_mitmot = op .inner_mitmot (inner_inputs )
2224
2243
outer_mitmot = op .outer_mitmot (node .inputs )
2225
- inner_mitmot_outs = op .inner_mitmot_outs (op . inner_outputs )
2226
- inner_mitsot = op .inner_mitsot (op . inner_inputs )
2244
+ inner_mitmot_outs = op .inner_mitmot_outs (inner_outputs )
2245
+ inner_mitsot = op .inner_mitsot (inner_inputs )
2227
2246
outer_mitsot = op .outer_mitsot (node .inputs )
2228
- inner_mitsot_outs = op .inner_mitsot_outs (op . inner_outputs )
2229
- inner_sitsot = op .inner_sitsot (op . inner_inputs )
2247
+ inner_mitsot_outs = op .inner_mitsot_outs (inner_outputs )
2248
+ inner_sitsot = op .inner_sitsot (inner_inputs )
2230
2249
outer_sitsot = op .outer_sitsot (node .inputs )
2231
- inner_sitsot_outs = op .inner_sitsot_outs (op . inner_outputs )
2250
+ inner_sitsot_outs = op .inner_sitsot_outs (inner_outputs )
2232
2251
outer_nitsot = op .outer_nitsot (node .inputs )
2233
- inner_nitsot_outs = op .inner_nitsot_outs (op . inner_outputs )
2234
- inner_shared = op .inner_shared (op . inner_inputs )
2252
+ inner_nitsot_outs = op .inner_nitsot_outs (inner_outputs )
2253
+ inner_shared = op .inner_shared (inner_inputs )
2235
2254
outer_shared = op .outer_shared (node .inputs )
2236
- inner_shared_outs = op .inner_shared_outs (op . inner_outputs )
2237
- inner_non_seqs = op .inner_non_seqs (op . inner_inputs )
2255
+ inner_shared_outs = op .inner_shared_outs (inner_outputs )
2256
+ inner_non_seqs = op .inner_non_seqs (inner_inputs )
2238
2257
outer_non_seqs = op .outer_non_seqs (node .inputs )
2239
2258
2240
2259
new_info = dataclasses .replace (
0 commit comments