Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 88876e0

Browse files
committedSep 17, 2024·
fix: merge new translate
1 parent 1c7378c commit 88876e0

File tree

4 files changed

+98
-82
lines changed

4 files changed

+98
-82
lines changed
 

‎jaxadi/_convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ._compile import compile as compile_fn
88

99

10-
def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
10+
def convert(casadi_fn: Function, compile=False, num_threads=1) -> Callable[..., Any]:
1111
"""
1212
Convert given casadi function into python
1313
callable based on JAX backend, optionally
@@ -17,7 +17,7 @@ def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
1717
:param compile (bool): Whether to AOT compile function
1818
:return (Callable[..., Any]): Resulting python function
1919
"""
20-
jax_str = translate(casadi_fn)
20+
jax_str = translate(casadi_fn, num_threads=num_threads)
2121
jax_fn = declare(jax_str)
2222

2323
if compile:

‎jaxadi/_ops.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -49,50 +49,50 @@
4949
)
5050

5151
OP_JAX_VALUE_DICT = {
52-
OP_ASSIGN: "work[{0}]",
53-
OP_ADD: "work[{0}] + work[{1}]",
54-
OP_SUB: "work[{0}] - work[{1}]",
55-
OP_MUL: "work[{0}] * work[{1}]",
56-
OP_DIV: "work[{0}] / work[{1}]",
57-
OP_NEG: "-work[{0}]",
58-
OP_EXP: "jnp.exp(work[{0}])",
59-
OP_LOG: "jnp.log(work[{0}])",
60-
OP_POW: "jnp.power(work[{0}], work[{1}])",
61-
OP_CONSTPOW: "jnp.power(work[{0}], work[{1}])",
62-
OP_SQRT: "jnp.sqrt(work[{0}])",
63-
OP_SQ: "work[{0}] * work[{0}]",
64-
OP_TWICE: "2 * work[{0}]",
65-
OP_SIN: "jnp.sin(work[{0}])",
66-
OP_COS: "jnp.cos(work[{0}])",
67-
OP_TAN: "jnp.tan(work[{0}])",
68-
OP_ASIN: "jnp.arcsin(work[{0}])",
69-
OP_ACOS: "jnp.arccos(work[{0}])",
70-
OP_ATAN: "jnp.arctan(work[{0}])",
71-
OP_LT: "work[{0}] < work[{1}]",
72-
OP_LE: "work[{0}] <= work[{1}]",
73-
OP_EQ: "work[{0}] == work[{1}]",
74-
OP_NE: "work[{0}] != work[{1}]",
75-
OP_NOT: "jnp.logical_not(work[{0}])",
76-
OP_AND: "jnp.logical_and(work[{0}], work[{1}])",
77-
OP_OR: "jnp.logical_or(work[{0}], work[{1}])",
78-
OP_FLOOR: "jnp.floor(work[{0}])",
79-
OP_CEIL: "jnp.ceil(work[{0}])",
80-
OP_FMOD: "jnp.fmod(work[{0}], work[{1}])",
81-
OP_FABS: "jnp.abs(work[{0}])",
82-
OP_SIGN: "jnp.sign(work[{0}])",
83-
OP_COPYSIGN: "jnp.copysign(work[{0}], work[{1}])",
84-
OP_IF_ELSE_ZERO: "jnp.where(work[{0}] == 0, 0, work[{1}])",
85-
OP_ERF: "jax.scipy.special.erf(work[{0}])",
86-
OP_FMIN: "jnp.minimum(work[{0}], work[{1}])",
87-
OP_FMAX: "jnp.maximum(work[{0}], work[{1}])",
88-
OP_INV: "1.0 / work[{0}]",
89-
OP_SINH: "jnp.sinh(work[{0}])",
90-
OP_COSH: "jnp.cosh(work[{0}])",
91-
OP_TANH: "jnp.tanh(work[{0}])",
92-
OP_ASINH: "jnp.arcsinh(work[{0}])",
93-
OP_ACOSH: "jnp.arccosh(work[{0}])",
94-
OP_ATANH: "jnp.arctanh(work[{0}])",
95-
OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])",
52+
OP_ASSIGN: "{0}",
53+
OP_ADD: "{0}+{1}",
54+
OP_SUB: "{0}-{1}",
55+
OP_MUL: "{0}*{1}",
56+
OP_DIV: "{0}/{1}",
57+
OP_NEG: "-{0}",
58+
OP_EXP: "jnp.exp({0})",
59+
OP_LOG: "jnp.log({0})",
60+
OP_POW: "jnp.power({0}, {1})",
61+
OP_CONSTPOW: "jnp.power({0}, {1})",
62+
OP_SQRT: "jnp.sqrt({0})",
63+
OP_SQ: "{0} * {0}",
64+
OP_TWICE: "2 * {0}",
65+
OP_SIN: "jnp.sin({0})",
66+
OP_COS: "jnp.cos({0})",
67+
OP_TAN: "jnp.tan({0})",
68+
OP_ASIN: "jnp.arcsin({0})",
69+
OP_ACOS: "jnp.arccos({0})",
70+
OP_ATAN: "jnp.arctan({0})",
71+
OP_LT: "{0} < {1}",
72+
OP_LE: "{0} <= {1}",
73+
OP_EQ: "{0} == {1}",
74+
OP_NE: "{0} != {1}",
75+
OP_NOT: "jnp.logical_not({0})",
76+
OP_AND: "jnp.logical_and({0}, {1})",
77+
OP_OR: "jnp.logical_or({0}, {1})",
78+
OP_FLOOR: "jnp.floor({0})",
79+
OP_CEIL: "jnp.ceil({0})",
80+
OP_FMOD: "jnp.fmod({0}, {1})",
81+
OP_FABS: "jnp.abs({0})",
82+
OP_SIGN: "jnp.sign({0})",
83+
OP_COPYSIGN: "jnp.copysign({0}, {1})",
84+
OP_IF_ELSE_ZERO: "jnp.where({0} == 0, 0, {1})",
85+
OP_ERF: "jax.scipy.special.erf({0})",
86+
OP_FMIN: "jnp.minimum({0}, {1})",
87+
OP_FMAX: "jnp.maximum({0}, {1})",
88+
OP_INV: "1.0/{0}",
89+
OP_SINH: "jnp.sinh({0})",
90+
OP_COSH: "jnp.cosh({0})",
91+
OP_TANH: "jnp.tanh({0})",
92+
OP_ASINH: "jnp.arcsinh({0})",
93+
OP_ACOSH: "jnp.arccosh({0})",
94+
OP_ATANH: "jnp.arctanh({0})",
95+
OP_ATAN2: "jnp.arctan2({0}, {1})",
9696
OP_CONST: "{0:.16f}",
9797
OP_INPUT: "inputs[{0}][{1}, {2}]",
9898
OP_OUTPUT: "work[{0}][0]",

‎jaxadi/_stages.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from ._ops import OP_JAX_VALUE_DICT
33
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
44
import re
5+
from tqdm import tqdm
6+
from multiprocessing import Pool, cpu_count
57

68

79
class Stage:
@@ -56,6 +58,9 @@ def stage_generator(func: Function) -> str:
5658
n_instr = func.n_instructions()
5759
n_out = func.n_out() # number of outputs in the function
5860
n_in = func.n_in() # number of outputs in the function
61+
n_w = func.sz_w()
62+
63+
workers = [""] * n_w
5964

6065
# Get the shapes of input and output
6166
out_shapes = [func.size_out(i) for i in range(n_out)]
@@ -67,23 +72,23 @@ def stage_generator(func: Function) -> str:
6772
const_instr = [func.instruction_constant(i) for i in range(n_instr)]
6873

6974
stages = []
70-
for k in range(n_instr):
75+
for k in tqdm(range(n_instr)):
7176
op = operations[k]
7277
o_idx = output_idx[k]
7378
i_idx = input_idx[k]
7479
operation = Operation()
7580
operation.op = op
7681
if op == OP_CONST:
77-
operation.output_idx = o_idx[0]
78-
operation.value = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"
79-
# codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
82+
workers[o_idx[0]
83+
] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"
84+
8085
elif op == OP_INPUT:
8186
this_shape = in_shapes[i_idx[0]]
8287
rows, cols = this_shape # Get the shape of the output
8388
row_number = i_idx[1] % rows # Compute row index for JAX
8489
column_number = i_idx[1] // rows # Compute column index for JAX
85-
operation.output_idx = o_idx[0]
86-
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
90+
workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(
91+
i_idx[0], row_number, column_number)
8792
elif op == OP_OUTPUT:
8893
operation = OutputOperation()
8994
operation.op = op
@@ -94,28 +99,25 @@ def stage_generator(func: Function) -> str:
9499
operation.exact_idx2 = column_number
95100
operation.output_idx = o_idx[0]
96101
operation.work_idx.append(i_idx[0])
97-
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
102+
operation.value = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
103+
stage = Stage()
104+
stage.output_idx.append(operation.output_idx)
105+
stage.work_idx.extend(operation.work_idx)
106+
stage.ops.append(operation)
107+
stages.append(stage)
98108
elif op == OP_SQ:
99-
operation.output_idx = o_idx[0]
100-
operation.work_idx.append(i_idx[0])
101-
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
109+
workers[o_idx[0]] = "(" + \
110+
OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) + ")"
102111
elif OP_JAX_VALUE_DICT[op].count("}") == 2:
103-
operation.output_idx = o_idx[0]
104-
operation.work_idx.extend([i_idx[0], i_idx[1]])
105-
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1])
112+
workers[o_idx[0]] = "(" + OP_JAX_VALUE_DICT[op].format(
113+
workers[i_idx[0]], workers[i_idx[1]]) + ")"
106114
elif OP_JAX_VALUE_DICT[op].count("}") == 1:
107-
operation.output_idx = o_idx[0]
108-
operation.work_idx.append(i_idx[0])
109-
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
115+
workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
110116
else:
111117
raise Exception("Unknown CasADi operation: " + str(op))
118+
print(sum(len(s) for s in workers))
112119

113-
stage = Stage()
114-
stage.output_idx.append(operation.output_idx)
115-
stage.work_idx.extend(operation.work_idx)
116-
stage.ops.append(operation)
117-
stages.append(stage)
118-
120+
print("finished stages")
119121
return stages
120122

121123

@@ -146,7 +148,7 @@ def combine_outputs(stages: List[Stage]) -> str:
146148
rows = "[" + ", ".join(row_indices) + "]"
147149
columns = "[" + ", ".join(column_indices) + "]"
148150
values_str = ", ".join(values)
149-
command = f" outputs[{output_idx}] = outputs[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
151+
command = f" o[{output_idx}] = o[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
150152
commands.append(command)
151153

152154
# Combine all the commands into a single string
@@ -177,19 +179,32 @@ def recursive_subs(stages: List[Stage], idx: int) -> str:
177179
for i in range(idx - 1, -1, -1):
178180
if stages[i].ops[0].output_idx == number and stages[i].ops[0].op != OP_OUTPUT:
179181
# Recursively replace the found work[<number>] with expanded value
180-
expanded_value = recursive_subs(stages, i)
181-
result = result.replace(f"work[{number}]", expanded_value)
182+
stages[i].ops[0].value = recursive_subs(stages, i)
183+
result = result.replace(
184+
f"work[{number}]", stages[i].ops[0].value)
182185
break
183186

184187
return f"({result})"
185188

186189

187-
def squeeze(stages: List[Stage]) -> List[Stage]:
190+
def process_stage(args):
191+
stages, i = args
192+
if len(stages[i].ops) != 0:
193+
stages[i].ops[0].value = recursive_subs(stages, i)
194+
return stages[i]
195+
return None
196+
197+
198+
def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]:
188199
new_stages = []
189-
for i in range(len(stages)):
190-
if len(stages[i].ops) != 0 and stages[i].ops[0].op == OP_OUTPUT:
191-
stages[i].ops[0].value = recursive_subs(stages, i)
192-
new_stages.append(stages[i])
200+
working_stages = []
201+
for i, stage in enumerate(stages):
202+
if len(stage.ops) != 0 and stage.ops[0].op == OP_OUTPUT:
203+
working_stages.append((i, stage))
204+
for i in tqdm(range(len(working_stages))):
205+
i, stage = working_stages[i]
206+
stage.value = recursive_subs(stages, i)
207+
new_stages.append(stage)
193208

194209
cmd = combine_outputs(new_stages)
195210
return cmd

‎jaxadi/_translate.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
from ._stages import stage_generator, squeeze
33

44

5-
def translate(func: Function, add_jit=False, add_import=False) -> str:
6-
stages = stage_generator(func)
7-
stages = squeeze(stages)
8-
# get information about casadi function
5+
def translate(func: Function, add_jit=False, add_import=False, num_threads=1) -> str:
96
n_out = func.n_out() # number of outputs in the function
107

118
# get the shapes of input and output
129
out_shapes = [func.size_out(i) for i in range(n_out)]
10+
print(out_shapes)
11+
stages = stage_generator(func)
12+
stages = squeeze(stages, num_threads=num_threads)
13+
# get information about casadi function
1314

1415
# generate string with complete code
1516
codegen = ""
@@ -20,13 +21,13 @@ def translate(func: Function, add_jit=False, add_import=False) -> str:
2021
# combine all inputs into a single list
2122
codegen += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n"
2223
# output variables
23-
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"
24+
codegen += f" o = [jnp.zeros(out) for out in {out_shapes}]\n"
2425

2526
# for stage in stages:
2627
# codegen += stage.codegen()
2728
codegen += stages
2829

2930
# footer
30-
codegen += "\n return outputs\n"
31+
codegen += "\n return o\n"
3132

3233
return codegen

0 commit comments

Comments
 (0)
Please sign in to comment.