Skip to content

Commit d2ca1ad

Browse files
committed
fix: pre-commit
1 parent d56d48c commit d2ca1ad

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

jaxadi/_convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[...,
2020
"""
2121
if translate is None:
2222
translate = graph_translate
23-
23+
2424
jax_str = translate(casadi_fn)
2525
jax_fn = declare(jax_str)
2626

jaxadi/_expand.py

+1
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]:
201201
cmd = combine_outputs(new_stages)
202202
return cmd
203203

204+
204205
def translate(func: Function, add_jit=False, add_import=False) -> str:
205206
stages = stage_generator(func)
206207
stages = squeeze(stages)

jaxadi/_graph.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@
33
creation, traversion, code-generation and
44
compression/fusion if necessary/possible
55
"""
6+
67
from casadi import Function
78
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
89
from collections import deque
910

1011
from ._ops import OP_JAX_VALUE_DICT
1112

13+
1214
def sort_by_height(graph, antigraph, heights):
1315
nodes = [[] for i in range(max(heights) + 1)]
1416
for i, h in enumerate(heights):
1517
nodes[h].append(i)
1618

1719
return nodes
1820

21+
1922
def codegen(graph, antigraph, heights, output_map, values):
2023
sorted_nodes = sort_by_height(graph, antigraph, heights)
2124
code = ""
@@ -27,14 +30,10 @@ def codegen(graph, antigraph, heights, output_map, values):
2730
if node in output_map:
2831
oo = output_map[node]
2932
if outputs.get(oo[0], None) is None:
30-
outputs[oo[0]] = {
31-
'rows': [],
32-
'cols': [],
33-
'values': []
34-
}
35-
outputs[oo[0]]['rows'].append(oo[1])
36-
outputs[oo[0]]['cols'].append(oo[2])
37-
outputs[oo[0]]['values'].append(values[node])
33+
outputs[oo[0]] = {"rows": [], "cols": [], "values": []}
34+
outputs[oo[0]]["rows"].append(oo[1])
35+
outputs[oo[0]]["cols"].append(oo[2])
36+
outputs[oo[0]]["values"].append(values[node])
3837
else:
3938
if len(assignment) > 1:
4039
assignment += ", "
@@ -47,6 +46,7 @@ def codegen(graph, antigraph, heights, output_map, values):
4746

4847
return code
4948

49+
5050
def compute_heights(func, graph, antigraph):
5151
heights = [0 for _ in range(len(graph))]
5252
current_layer = set()
@@ -70,6 +70,7 @@ def compute_heights(func, graph, antigraph):
7070

7171
return heights
7272

73+
7374
def create_graph(func: Function):
7475
N = func.n_instructions()
7576
graph = [[] for _ in range(N)]
@@ -121,7 +122,6 @@ def create_graph(func: Function):
121122
else:
122123
raise Exception("Unknown CasADi operation: " + str(op))
123124

124-
125125
return graph, antigraph, output_map, values
126126

127127

@@ -142,5 +142,3 @@ def translate(func: Function, add_jit=False, add_import=False):
142142
code += " return outputs"
143143

144144
return code
145-
146-

jaxadi/_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
OP_TANH,
4848
OP_TWICE,
4949
)
50+
5051
OP_JAX_VALUE_DICT = {
5152
OP_ASSIGN: "work[{0}]",
5253
OP_ADD: "work[{0}] + work[{1}]",

0 commit comments

Comments
 (0)