From 46d54135eb2bf41162317d4de07355f4bf79023f Mon Sep 17 00:00:00 2001 From: "Ashwin V. Mohanan" Date: Thu, 14 Mar 2024 22:24:02 +0100 Subject: [PATCH] Generate data_tests with jax backend --- .../__jax__/for_test_init.py | 77 +++++++++++++++++++ .../__ext__MyClass2__exterior_import_boost.py | 7 ++ ..._ext__MyClass2__exterior_import_boost_2.py | 5 ++ .../__ext__func__exterior_import_boost.py | 7 ++ .../__ext__func__exterior_import_boost_2.py | 5 ++ data_tests/__jax__/add_inline.py | 19 +++++ data_tests/__jax__/assign_func_boost.py | 9 +++ data_tests/__jax__/block_fluidsim.py | 18 +++++ data_tests/__jax__/blocks_type_hints.py | 19 +++++ .../__jax__/boosted_class_use_import.py | 13 ++++ data_tests/__jax__/boosted_func_use_import.py | 12 +++ data_tests/__jax__/class_blocks.py | 44 +++++++++++ data_tests/__jax__/class_rec_calls.py | 20 +++++ data_tests/__jax__/classic.py | 11 +++ data_tests/__jax__/default_params.py | 10 +++ data_tests/__jax__/methods.py | 13 ++++ data_tests/__jax__/mixed_classic_type_hint.py | 18 +++++ data_tests/__jax__/no_arg.py | 16 ++++ data_tests/__jax__/row_sum_boost.py | 27 +++++++ data_tests/__jax__/subpackages.py | 33 ++++++++ data_tests/__jax__/type_hint_notemplate.py | 13 ++++ 21 files changed, 396 insertions(+) create mode 100644 _transonic_testing/src/_transonic_testing/__jax__/for_test_init.py create mode 100644 data_tests/__jax__/__ext__MyClass2__exterior_import_boost.py create mode 100644 data_tests/__jax__/__ext__MyClass2__exterior_import_boost_2.py create mode 100644 data_tests/__jax__/__ext__func__exterior_import_boost.py create mode 100644 data_tests/__jax__/__ext__func__exterior_import_boost_2.py create mode 100644 data_tests/__jax__/add_inline.py create mode 100644 data_tests/__jax__/assign_func_boost.py create mode 100644 data_tests/__jax__/block_fluidsim.py create mode 100644 data_tests/__jax__/blocks_type_hints.py create mode 100644 data_tests/__jax__/boosted_class_use_import.py create mode 100644 data_tests/__jax__/boosted_func_use_import.py create mode 100644 data_tests/__jax__/class_blocks.py create mode 100644 data_tests/__jax__/class_rec_calls.py create mode 100644 data_tests/__jax__/classic.py create mode 100644 data_tests/__jax__/default_params.py create mode 100644 data_tests/__jax__/methods.py create mode 100644 data_tests/__jax__/mixed_classic_type_hint.py create mode 100644 data_tests/__jax__/no_arg.py create mode 100644 data_tests/__jax__/row_sum_boost.py create mode 100644 data_tests/__jax__/subpackages.py create mode 100644 data_tests/__jax__/type_hint_notemplate.py diff --git a/_transonic_testing/src/_transonic_testing/__jax__/for_test_init.py b/_transonic_testing/src/_transonic_testing/__jax__/for_test_init.py new file mode 100644 index 0000000..23a4d25 --- /dev/null +++ b/_transonic_testing/src/_transonic_testing/__jax__/for_test_init.py @@ -0,0 +1,77 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func_tmp(arg): + return arg**2 + + +import jax.numpy as np + +# __protected__ @jit + + +def func(a, b): + i = 2 + c = 0.1 + i + return a + b + c + + +# __protected__ @jit + + +def func0(a, b): + return a + b + + +# __protected__ @jit + + +def func2(a, b): + return a - func_tmp(b) + + +# __protected__ @jit + + +def func3(c): + return c[0] + 1 + + +# __protected__ @jit + + +def __for_method__Transmitter____call__(self_freq, inp): + """My docstring""" + return inp * np.exp(np.arange(len(inp)) * self_freq * 1j) + + +__code_new_method__Transmitter____call__ = ( + "\n\ndef new_method(self, inp):\n return backend_func(self.freq, inp)\n\n" +) +# __protected__ @jit + + +def __for_method__Transmitter__other_func(self_freq): + return 2 * self_freq + + +__code_new_method__Transmitter__other_func = ( + "\n\ndef new_method(self, ):\n return backend_func(self.freq, )\n\n" +) +# __protected__ @jit + + +def block0(a, b, n): + # transonic block ( + # float a, b; + # int n + # ) + result = 0.0 + for _ in range(n): + result += a**2 + b**3 + return result + + +arguments_blocks = {"block0": ["a", "b", "n"]} +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/__ext__MyClass2__exterior_import_boost.py b/data_tests/__jax__/__ext__MyClass2__exterior_import_boost.py new file mode 100644 index 0000000..365c3a0 --- /dev/null +++ b/data_tests/__jax__/__ext__MyClass2__exterior_import_boost.py @@ -0,0 +1,7 @@ +const = 1 +from __ext__MyClass2__exterior_import_boost_2 import func_import_2 +import numpy as np + + +def func_import(): + return const + func_import_2() + np.pi - np.pi diff --git a/data_tests/__jax__/__ext__MyClass2__exterior_import_boost_2.py b/data_tests/__jax__/__ext__MyClass2__exterior_import_boost_2.py new file mode 100644 index 0000000..d12395e --- /dev/null +++ b/data_tests/__jax__/__ext__MyClass2__exterior_import_boost_2.py @@ -0,0 +1,5 @@ +const = 1 + + +def func_import_2(): + return const diff --git a/data_tests/__jax__/__ext__func__exterior_import_boost.py b/data_tests/__jax__/__ext__func__exterior_import_boost.py new file mode 100644 index 0000000..7be565f --- /dev/null +++ b/data_tests/__jax__/__ext__func__exterior_import_boost.py @@ -0,0 +1,7 @@ +const = 1 +from __ext__func__exterior_import_boost_2 import func_import_2 +import numpy as np + + +def func_import(): + return const + func_import_2() + np.pi - np.pi diff --git a/data_tests/__jax__/__ext__func__exterior_import_boost_2.py b/data_tests/__jax__/__ext__func__exterior_import_boost_2.py new file mode 100644 index 0000000..d12395e --- /dev/null +++ b/data_tests/__jax__/__ext__func__exterior_import_boost_2.py @@ -0,0 +1,5 @@ +const = 1 + + +def func_import_2(): + return const diff --git a/data_tests/__jax__/add_inline.py b/data_tests/__jax__/add_inline.py new file mode 100644 index 0000000..de71101 --- /dev/null +++ b/data_tests/__jax__/add_inline.py @@ -0,0 +1,19 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def add(a, b): + return a + b + + +# __protected__ @jit + + +def use_add(n=10000): + tmp = 0 + for _ in range(n): + tmp = add(tmp, 1) + return tmp + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/assign_func_boost.py b/data_tests/__jax__/assign_func_boost.py new file mode 100644 index 0000000..a1cd38d --- /dev/null +++ b/data_tests/__jax__/assign_func_boost.py @@ -0,0 +1,9 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(x): + return x**2 + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/block_fluidsim.py b/data_tests/__jax__/block_fluidsim.py new file mode 100644 index 0000000..992a358 --- /dev/null +++ b/data_tests/__jax__/block_fluidsim.py @@ -0,0 +1,18 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt): + # transonic block ( + # complex128[][][] state_spect_n12, state_spect, + # tendencies_n; + # float64[][] diss2; + # float dt + # ) + state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2 + + +arguments_blocks = { + "rk2_step0": ["state_spect_n12", "state_spect", "tendencies_n", "diss2", "dt"] +} +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/blocks_type_hints.py b/data_tests/__jax__/blocks_type_hints.py new file mode 100644 index 0000000..10c5842 --- /dev/null +++ b/data_tests/__jax__/blocks_type_hints.py @@ -0,0 +1,19 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def block0(a, b, n): + # transonic block ( + # A a; A1 b; + # int n + # ) + # transonic block ( + # int[:] a, b; + # float n + # ) + result = a**2 + b.mean() ** 3 + n + return result + + +arguments_blocks = {"block0": ["a", "b", "n"]} +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/boosted_class_use_import.py b/data_tests/__jax__/boosted_class_use_import.py new file mode 100644 index 0000000..3fe292f --- /dev/null +++ b/data_tests/__jax__/boosted_class_use_import.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +import jax.numpy as np +from __ext__MyClass2__exterior_import_boost import func_import + +# __protected__ @jit + + +def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg): + return self_attr1 + self_attr0 + np.abs(arg) + func_import() + + +__code_new_method__MyClass2__myfunc = "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n" +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/boosted_func_use_import.py b/data_tests/__jax__/boosted_func_use_import.py new file mode 100644 index 0000000..ce3cf41 --- /dev/null +++ b/data_tests/__jax__/boosted_func_use_import.py @@ -0,0 +1,12 @@ +# __protected__ from jax import jit +import jax.numpy as np +from __ext__func__exterior_import_boost import func_import + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + func_import() + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/class_blocks.py b/data_tests/__jax__/class_blocks.py new file mode 100644 index 0000000..6503120 --- /dev/null +++ b/data_tests/__jax__/class_blocks.py @@ -0,0 +1,44 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def block0(a, b, n): + # foo + # transonic block ( + # float[][] a, b; + # int n + # ) bar + # foo + # transonic block ( + # float[][][] a, b; + # int n + # ) + # foobar + result = np.zeros_like(a) + for _ in range(n): + result += a**2 + b**3 + return result + + +# __protected__ @jit + + +def block1(a, b, n): + # transonic block ( + # float[][] a, b; + # int n + # ) + # transonic block ( + # float[][][] a, b; + # int n + # ) + result = np.zeros_like(a) + for _ in range(n): + result += a**2 + b**3 + return result + + +arguments_blocks = {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]} +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/class_rec_calls.py b/data_tests/__jax__/class_rec_calls.py new file mode 100644 index 0000000..5d0adf8 --- /dev/null +++ b/data_tests/__jax__/class_rec_calls.py @@ -0,0 +1,20 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def __for_method__Myclass__func(self_attr, self_attr2, arg): + if __for_method__Myclass__func(self_attr, self_attr2, arg - 1) < 1: + return 1 + else: + a = __for_method__Myclass__func( + self_attr, self_attr2, arg - 1 + ) * __for_method__Myclass__func(self_attr, self_attr2, arg - 1) + return ( + a + + self_attr * self_attr2 * arg + + __for_method__Myclass__func(self_attr, self_attr2, arg - 1) + ) + + +__code_new_method__Myclass__func = "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n" +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/classic.py b/data_tests/__jax__/classic.py new file mode 100644 index 0000000..57a272f --- /dev/null +++ b/data_tests/__jax__/classic.py @@ -0,0 +1,11 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/default_params.py b/data_tests/__jax__/default_params.py new file mode 100644 index 0000000..7280ddc --- /dev/null +++ b/data_tests/__jax__/default_params.py @@ -0,0 +1,10 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(a=1, b=None, c=1.0): + print(b) + return a + c + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/methods.py b/data_tests/__jax__/methods.py new file mode 100644 index 0000000..cb01723 --- /dev/null +++ b/data_tests/__jax__/methods.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def __for_method__Transmitter____call__(self_arr, self_freq, inp): + """My docstring""" + return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr) + + +__code_new_method__Transmitter____call__ = "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n" +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/mixed_classic_type_hint.py b/data_tests/__jax__/mixed_classic_type_hint.py new file mode 100644 index 0000000..3b05f02 --- /dev/null +++ b/data_tests/__jax__/mixed_classic_type_hint.py @@ -0,0 +1,18 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + + +# __protected__ @jit + + +def func1(a, b): + return a * np.cos(b) + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/no_arg.py b/data_tests/__jax__/no_arg.py new file mode 100644 index 0000000..447a8d3 --- /dev/null +++ b/data_tests/__jax__/no_arg.py @@ -0,0 +1,16 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(): + return 1 + + +# __protected__ @jit + + +def func2(): + return 1 + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/row_sum_boost.py b/data_tests/__jax__/row_sum_boost.py new file mode 100644 index 0000000..2793ee1 --- /dev/null +++ b/data_tests/__jax__/row_sum_boost.py @@ -0,0 +1,27 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def row_sum(arr, columns): + return arr.T[columns].sum(0) + + +# __protected__ @jit + + +def row_sum_loops(arr, columns): + # locals type annotations are used only for Cython + # arr.dtype not supported for memoryview + dtype = type(arr[0, 0]) + res = np.empty(arr.shape[0], dtype=dtype) + for i in range(arr.shape[0]): + sum_ = dtype(0) + for j in range(columns.shape[0]): + sum_ += arr[i, columns[j]] + res[i] = sum_ + return res + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/subpackages.py b/data_tests/__jax__/subpackages.py new file mode 100644 index 0000000..7b9c43e --- /dev/null +++ b/data_tests/__jax__/subpackages.py @@ -0,0 +1,33 @@ +# __protected__ from jax import jit +from numpy.fft import rfft +from numpy.random import randn +from numpy.linalg import matrix_power +from scipy.special import jv + +# __protected__ @jit + + +def test_np_fft(u): + u_fft = rfft(u) + return u_fft + + +# __protected__ @jit + + +def test_np_linalg_random(u): + nx, ny = u.shape + u[:] = randn(nx, ny) + u2 = u.T * u + u4 = matrix_power(u2, 2) + return u4 + + +# __protected__ @jit + + +def test_sp_special(v, x): + return jv(v, x) + + +__transonic__ = ("0.6.3+editable",) diff --git a/data_tests/__jax__/type_hint_notemplate.py b/data_tests/__jax__/type_hint_notemplate.py new file mode 100644 index 0000000..b07cd6b --- /dev/null +++ b/data_tests/__jax__/type_hint_notemplate.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def compute(a, b, c, d, e): + print(e) + tmp = a + b + if 1 and 2: + tmp *= 2 + return tmp + + +__transonic__ = ("0.6.3+editable",)