3
3
creation, traversion, code-generation and
4
4
compression/fusion if necessary/possible
5
5
"""
6
+
6
7
from casadi import Function
7
8
from casadi import OP_CONST , OP_INPUT , OP_OUTPUT , OP_SQ , Function
8
9
from collections import deque
9
10
10
11
from ._ops import OP_JAX_VALUE_DICT
11
12
13
+
12
14
def sort_by_height (graph , antigraph , heights ):
13
15
nodes = [[] for i in range (max (heights ) + 1 )]
14
16
for i , h in enumerate (heights ):
15
17
nodes [h ].append (i )
16
18
17
19
return nodes
18
20
21
+
19
22
def codegen (graph , antigraph , heights , output_map , values ):
20
23
sorted_nodes = sort_by_height (graph , antigraph , heights )
21
24
code = ""
@@ -27,14 +30,10 @@ def codegen(graph, antigraph, heights, output_map, values):
27
30
if node in output_map :
28
31
oo = output_map [node ]
29
32
if outputs .get (oo [0 ], None ) is None :
30
- outputs [oo [0 ]] = {
31
- 'rows' : [],
32
- 'cols' : [],
33
- 'values' : []
34
- }
35
- outputs [oo [0 ]]['rows' ].append (oo [1 ])
36
- outputs [oo [0 ]]['cols' ].append (oo [2 ])
37
- outputs [oo [0 ]]['values' ].append (values [node ])
33
+ outputs [oo [0 ]] = {"rows" : [], "cols" : [], "values" : []}
34
+ outputs [oo [0 ]]["rows" ].append (oo [1 ])
35
+ outputs [oo [0 ]]["cols" ].append (oo [2 ])
36
+ outputs [oo [0 ]]["values" ].append (values [node ])
38
37
else :
39
38
if len (assignment ) > 1 :
40
39
assignment += ", "
@@ -47,6 +46,7 @@ def codegen(graph, antigraph, heights, output_map, values):
47
46
48
47
return code
49
48
49
+
50
50
def compute_heights (func , graph , antigraph ):
51
51
heights = [0 for _ in range (len (graph ))]
52
52
current_layer = set ()
@@ -70,6 +70,7 @@ def compute_heights(func, graph, antigraph):
70
70
71
71
return heights
72
72
73
+
73
74
def create_graph (func : Function ):
74
75
N = func .n_instructions ()
75
76
graph = [[] for _ in range (N )]
@@ -121,7 +122,6 @@ def create_graph(func: Function):
121
122
else :
122
123
raise Exception ("Unknown CasADi operation: " + str (op ))
123
124
124
-
125
125
return graph , antigraph , output_map , values
126
126
127
127
@@ -142,5 +142,3 @@ def translate(func: Function, add_jit=False, add_import=False):
142
142
code += " return outputs"
143
143
144
144
return code
145
-
146
-
0 commit comments