Skip to content

Commit

Permalink
Generate data_tests with jax backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvis committed Mar 14, 2024
1 parent e6dae54 commit 463565d
Show file tree
Hide file tree
Showing 21 changed files with 396 additions and 0 deletions.
77 changes: 77 additions & 0 deletions _transonic_testing/src/_transonic_testing/__jax__/for_test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# __protected__ from jax import jit
# __protected__ @jit


def func_tmp(arg):
return arg**2


import jax.numpy as np

# __protected__ @jit


def func(a, b):
i = 2
c = 0.1 + i
return a + b + c


# __protected__ @jit


def func0(a, b):
return a + b


# __protected__ @jit


def func2(a, b):
return a - func_tmp(b)


# __protected__ @jit


def func3(c):
return c[0] + 1


# __protected__ @jit


def __for_method__Transmitter____call__(self_freq, inp):
"""My docstring"""
return inp * np.exp(np.arange(len(inp)) * self_freq * 1j)


__code_new_method__Transmitter____call__ = (
"\n\ndef new_method(self, inp):\n return backend_func(self.freq, inp)\n\n"
)
# __protected__ @jit


def __for_method__Transmitter__other_func(self_freq):
return 2 * self_freq


__code_new_method__Transmitter__other_func = (
"\n\ndef new_method(self, ):\n return backend_func(self.freq, )\n\n"
)
# __protected__ @jit


def block0(a, b, n):
# transonic block (
# float a, b;
# int n
# )
result = 0.0
for _ in range(n):
result += a**2 + b**3
return result


arguments_blocks = {"block0": ["a", "b", "n"]}
__transonic__ = ("0.6.3+editable",)
7 changes: 7 additions & 0 deletions data_tests/__jax__/__ext__MyClass2__exterior_import_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const = 1
from __ext__MyClass2__exterior_import_boost_2 import func_import_2
import numpy as np


def func_import():
return const + func_import_2() + np.pi - np.pi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const = 1


def func_import_2():
return const
7 changes: 7 additions & 0 deletions data_tests/__jax__/__ext__func__exterior_import_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const = 1
from __ext__func__exterior_import_boost_2 import func_import_2
import numpy as np


def func_import():
return const + func_import_2() + np.pi - np.pi
5 changes: 5 additions & 0 deletions data_tests/__jax__/__ext__func__exterior_import_boost_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const = 1


def func_import_2():
return const
19 changes: 19 additions & 0 deletions data_tests/__jax__/add_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# __protected__ from jax import jit
# __protected__ @jit


def add(a, b):
return a + b


# __protected__ @jit


def use_add(n=10000):
tmp = 0
for _ in range(n):
tmp = add(tmp, 1)
return tmp


__transonic__ = ("0.6.3+editable",)
9 changes: 9 additions & 0 deletions data_tests/__jax__/assign_func_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# __protected__ from jax import jit
# __protected__ @jit


def func(x):
return x**2


__transonic__ = ("0.6.3+editable",)
18 changes: 18 additions & 0 deletions data_tests/__jax__/block_fluidsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# __protected__ from jax import jit
# __protected__ @jit


def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt):
# transonic block (
# complex128[][][] state_spect_n12, state_spect,
# tendencies_n;
# float64[][] diss2;
# float dt
# )
state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2


arguments_blocks = {
"rk2_step0": ["state_spect_n12", "state_spect", "tendencies_n", "diss2", "dt"]
}
__transonic__ = ("0.6.3+editable",)
19 changes: 19 additions & 0 deletions data_tests/__jax__/blocks_type_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# __protected__ from jax import jit
# __protected__ @jit


def block0(a, b, n):
# transonic block (
# A a; A1 b;
# int n
# )
# transonic block (
# int[:] a, b;
# float n
# )
result = a**2 + b.mean() ** 3 + n
return result


arguments_blocks = {"block0": ["a", "b", "n"]}
__transonic__ = ("0.6.3+editable",)
13 changes: 13 additions & 0 deletions data_tests/__jax__/boosted_class_use_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# __protected__ from jax import jit
import jax.numpy as np
from __ext__MyClass2__exterior_import_boost import func_import

# __protected__ @jit


def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg):
return self_attr1 + self_attr0 + np.abs(arg) + func_import()


__code_new_method__MyClass2__myfunc = "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n"
__transonic__ = ("0.6.3+editable",)
12 changes: 12 additions & 0 deletions data_tests/__jax__/boosted_func_use_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# __protected__ from jax import jit
import jax.numpy as np
from __ext__func__exterior_import_boost import func_import

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max() + func_import()


__transonic__ = ("0.6.3+editable",)
44 changes: 44 additions & 0 deletions data_tests/__jax__/class_blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def block0(a, b, n):
# foo
# transonic block (
# float[][] a, b;
# int n
# ) bar
# foo
# transonic block (
# float[][][] a, b;
# int n
# )
# foobar
result = np.zeros_like(a)
for _ in range(n):
result += a**2 + b**3
return result


# __protected__ @jit


def block1(a, b, n):
# transonic block (
# float[][] a, b;
# int n
# )
# transonic block (
# float[][][] a, b;
# int n
# )
result = np.zeros_like(a)
for _ in range(n):
result += a**2 + b**3
return result


arguments_blocks = {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]}
__transonic__ = ("0.6.3+editable",)
20 changes: 20 additions & 0 deletions data_tests/__jax__/class_rec_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# __protected__ from jax import jit
# __protected__ @jit


def __for_method__Myclass__func(self_attr, self_attr2, arg):
if __for_method__Myclass__func(self_attr, self_attr2, arg - 1) < 1:
return 1
else:
a = __for_method__Myclass__func(
self_attr, self_attr2, arg - 1
) * __for_method__Myclass__func(self_attr, self_attr2, arg - 1)
return (
a
+ self_attr * self_attr2 * arg
+ __for_method__Myclass__func(self_attr, self_attr2, arg - 1)
)


__code_new_method__Myclass__func = "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n"
__transonic__ = ("0.6.3+editable",)
11 changes: 11 additions & 0 deletions data_tests/__jax__/classic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max()


__transonic__ = ("0.6.3+editable",)
10 changes: 10 additions & 0 deletions data_tests/__jax__/default_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# __protected__ from jax import jit
# __protected__ @jit


def func(a=1, b=None, c=1.0):
print(b)
return a + c


__transonic__ = ("0.6.3+editable",)
13 changes: 13 additions & 0 deletions data_tests/__jax__/methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def __for_method__Transmitter____call__(self_arr, self_freq, inp):
"""My docstring"""
return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr)


__code_new_method__Transmitter____call__ = "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n"
__transonic__ = ("0.6.3+editable",)
18 changes: 18 additions & 0 deletions data_tests/__jax__/mixed_classic_type_hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max()


# __protected__ @jit


def func1(a, b):
return a * np.cos(b)


__transonic__ = ("0.6.3+editable",)
16 changes: 16 additions & 0 deletions data_tests/__jax__/no_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# __protected__ from jax import jit
# __protected__ @jit


def func():
return 1


# __protected__ @jit


def func2():
return 1


__transonic__ = ("0.6.3+editable",)
27 changes: 27 additions & 0 deletions data_tests/__jax__/row_sum_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def row_sum(arr, columns):
return arr.T[columns].sum(0)


# __protected__ @jit


def row_sum_loops(arr, columns):
# locals type annotations are used only for Cython
# arr.dtype not supported for memoryview
dtype = type(arr[0, 0])
res = np.empty(arr.shape[0], dtype=dtype)
for i in range(arr.shape[0]):
sum_ = dtype(0)
for j in range(columns.shape[0]):
sum_ += arr[i, columns[j]]
res[i] = sum_
return res


__transonic__ = ("0.6.3+editable",)
Loading

0 comments on commit 463565d

Please sign in to comment.