-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate data_tests with jax backend
- Loading branch information
Showing
21 changed files
with
396 additions
and
0 deletions.
There are no files selected for viewing
77 changes: 77 additions & 0 deletions
77
_transonic_testing/src/_transonic_testing/__jax__/for_test_init.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def func_tmp(arg): | ||
return arg**2 | ||
|
||
|
||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def func(a, b): | ||
i = 2 | ||
c = 0.1 + i | ||
return a + b + c | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def func0(a, b): | ||
return a + b | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def func2(a, b): | ||
return a - func_tmp(b) | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def func3(c): | ||
return c[0] + 1 | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def __for_method__Transmitter____call__(self_freq, inp): | ||
"""My docstring""" | ||
return inp * np.exp(np.arange(len(inp)) * self_freq * 1j) | ||
|
||
|
||
__code_new_method__Transmitter____call__ = ( | ||
"\n\ndef new_method(self, inp):\n return backend_func(self.freq, inp)\n\n" | ||
) | ||
# __protected__ @jit | ||
|
||
|
||
def __for_method__Transmitter__other_func(self_freq): | ||
return 2 * self_freq | ||
|
||
|
||
__code_new_method__Transmitter__other_func = ( | ||
"\n\ndef new_method(self, ):\n return backend_func(self.freq, )\n\n" | ||
) | ||
# __protected__ @jit | ||
|
||
|
||
def block0(a, b, n): | ||
# transonic block ( | ||
# float a, b; | ||
# int n | ||
# ) | ||
result = 0.0 | ||
for _ in range(n): | ||
result += a**2 + b**3 | ||
return result | ||
|
||
|
||
arguments_blocks = {"block0": ["a", "b", "n"]} | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
const = 1 | ||
from __ext__MyClass2__exterior_import_boost_2 import func_import_2 | ||
import numpy as np | ||
|
||
|
||
def func_import(): | ||
return const + func_import_2() + np.pi - np.pi |
5 changes: 5 additions & 0 deletions
5
data_tests/__jax__/__ext__MyClass2__exterior_import_boost_2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
const = 1 | ||
|
||
|
||
def func_import_2(): | ||
return const |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
const = 1 | ||
from __ext__func__exterior_import_boost_2 import func_import_2 | ||
import numpy as np | ||
|
||
|
||
def func_import(): | ||
return const + func_import_2() + np.pi - np.pi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
const = 1 | ||
|
||
|
||
def func_import_2(): | ||
return const |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def add(a, b): | ||
return a + b | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def use_add(n=10000): | ||
tmp = 0 | ||
for _ in range(n): | ||
tmp = add(tmp, 1) | ||
return tmp | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def func(x): | ||
return x**2 | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt): | ||
# transonic block ( | ||
# complex128[][][] state_spect_n12, state_spect, | ||
# tendencies_n; | ||
# float64[][] diss2; | ||
# float dt | ||
# ) | ||
state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2 | ||
|
||
|
||
arguments_blocks = { | ||
"rk2_step0": ["state_spect_n12", "state_spect", "tendencies_n", "diss2", "dt"] | ||
} | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def block0(a, b, n): | ||
# transonic block ( | ||
# A a; A1 b; | ||
# int n | ||
# ) | ||
# transonic block ( | ||
# int[:] a, b; | ||
# float n | ||
# ) | ||
result = a**2 + b.mean() ** 3 + n | ||
return result | ||
|
||
|
||
arguments_blocks = {"block0": ["a", "b", "n"]} | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
from __ext__MyClass2__exterior_import_boost import func_import | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg): | ||
return self_attr1 + self_attr0 + np.abs(arg) + func_import() | ||
|
||
|
||
__code_new_method__MyClass2__myfunc = "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n" | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
from __ext__func__exterior_import_boost import func_import | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def func(a, b): | ||
return (a * np.log(b)).max() + func_import() | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def block0(a, b, n): | ||
# foo | ||
# transonic block ( | ||
# float[][] a, b; | ||
# int n | ||
# ) bar | ||
# foo | ||
# transonic block ( | ||
# float[][][] a, b; | ||
# int n | ||
# ) | ||
# foobar | ||
result = np.zeros_like(a) | ||
for _ in range(n): | ||
result += a**2 + b**3 | ||
return result | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def block1(a, b, n): | ||
# transonic block ( | ||
# float[][] a, b; | ||
# int n | ||
# ) | ||
# transonic block ( | ||
# float[][][] a, b; | ||
# int n | ||
# ) | ||
result = np.zeros_like(a) | ||
for _ in range(n): | ||
result += a**2 + b**3 | ||
return result | ||
|
||
|
||
arguments_blocks = {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]} | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def __for_method__Myclass__func(self_attr, self_attr2, arg): | ||
if __for_method__Myclass__func(self_attr, self_attr2, arg - 1) < 1: | ||
return 1 | ||
else: | ||
a = __for_method__Myclass__func( | ||
self_attr, self_attr2, arg - 1 | ||
) * __for_method__Myclass__func(self_attr, self_attr2, arg - 1) | ||
return ( | ||
a | ||
+ self_attr * self_attr2 * arg | ||
+ __for_method__Myclass__func(self_attr, self_attr2, arg - 1) | ||
) | ||
|
||
|
||
__code_new_method__Myclass__func = "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n" | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def func(a, b): | ||
return (a * np.log(b)).max() | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def func(a=1, b=None, c=1.0): | ||
print(b) | ||
return a + c | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def __for_method__Transmitter____call__(self_arr, self_freq, inp): | ||
"""My docstring""" | ||
return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr) | ||
|
||
|
||
__code_new_method__Transmitter____call__ = "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n" | ||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def func(a, b): | ||
return (a * np.log(b)).max() | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def func1(a, b): | ||
return a * np.cos(b) | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# __protected__ from jax import jit | ||
# __protected__ @jit | ||
|
||
|
||
def func(): | ||
return 1 | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def func2(): | ||
return 1 | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# __protected__ from jax import jit | ||
import jax.numpy as np | ||
|
||
# __protected__ @jit | ||
|
||
|
||
def row_sum(arr, columns): | ||
return arr.T[columns].sum(0) | ||
|
||
|
||
# __protected__ @jit | ||
|
||
|
||
def row_sum_loops(arr, columns): | ||
# locals type annotations are used only for Cython | ||
# arr.dtype not supported for memoryview | ||
dtype = type(arr[0, 0]) | ||
res = np.empty(arr.shape[0], dtype=dtype) | ||
for i in range(arr.shape[0]): | ||
sum_ = dtype(0) | ||
for j in range(columns.shape[0]): | ||
sum_ += arr[i, columns[j]] | ||
res[i] = sum_ | ||
return res | ||
|
||
|
||
__transonic__ = ("0.6.3+editable",) |
Oops, something went wrong.