Description
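Minimal reproduction:

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  # the float64 result below implies 64-bit mode was enabled
def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))
res = jax.jacobian(my_func)(jnp.array([6, 7, 8], dtype=jnp.float64))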
jax.jacobian returns:
Array([[ 1.70539252, -0. , nan],
[ 0. , -6.38262843, nan],
[ 0. , -0. , nan]], dtype=float64)
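A quick check of the forward pass helps locate where the nan comes from. This is a minimal sketch that redefines my_func from the reproduction so it runs on its own: 8 rad falls where cosine is negative, so log(cos(8)) is nan and the third output is already nan before any differentiation happens.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match the float64 setup of the report

def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))

x = jnp.array([6.0, 7.0, 8.0], dtype=jnp.float64)
cos_x = jnp.cos(x)  # cos(8) is about -0.146, i.e. negative
out = my_func(x)    # log of a negative float is nan, so out[2] is nan

print(cos_x)
print(out)
```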
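The nan at [2, 2] is expected, since cos(8) < 0 makes log(cos(8)) nan. But the nan entries at [0, 2] and [1, 2] (first two rows) should be 0: the first two outputs do not depend on arg[2], so their derivatives with respect to it are zero.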
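The off-diagonal nan appears to come from IEEE 754 arithmetic in the backward pass rather than from the derivative itself. jax.jacobian is reverse mode (an alias of jax.jacrev), so row 0 is the vector-Jacobian product with cotangent e_0. In that pullback, the product rule sends ct * log(cos(x)) to the first factor, and entry 2 becomes 0.0 * nan, which IEEE 754 defines as nan. A minimal sketch, redefining my_func so it runs standalone:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))

x = jnp.array([6.0, 7.0, 8.0], dtype=jnp.float64)

# Row 0 of the reverse-mode Jacobian is the VJP with cotangent e_0.
out, pullback = jax.vjp(my_func, x)
(row0,) = pullback(jnp.array([1.0, 0.0, 0.0], dtype=jnp.float64))

# The zero cotangent entry is multiplied by log(cos(8)), which is nan.
b = jnp.log(jnp.cos(x))
print(row0)        # entry 2 is nan even though d out[0] / d x[2] is 0
print(0.0 * b[2])  # nan: zero times nan does not give zero
```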
Would you please take a look?
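In the meantime, one possible workaround is the "double where" pattern described in the JAX FAQ entry on NaN gradients when using where: feed log a safe value wherever cos(x) <= 0, so no nan enters the backward pass. This sketch assumes it is acceptable for the undefined diagonal entry [2, 2] to come back as 0 instead of nan; my_func_safe is a hypothetical name, not part of JAX.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))

def my_func_safe(arg):
    # Hypothetical workaround: mask inputs where log(cos(x)) is undefined.
    x = arg[:3]
    c = jnp.cos(x)
    valid = c > 0
    # Inner where: log only ever sees positive values.
    safe_c = jnp.where(valid, c, 1.0)
    # Outer where: keep the forward result nan where the original is undefined.
    return jnp.where(valid, x * jnp.log(safe_c), jnp.nan)

x = jnp.array([6.0, 7.0, 8.0], dtype=jnp.float64)
jac = jax.jacobian(my_func)(x)
jac_safe = jax.jacobian(my_func_safe)(x)
print(jac)       # nan at [0, 2], [1, 2] and [2, 2]
print(jac_safe)  # finite everywhere; [2, 2] becomes 0 instead of nan
```

The outer where zeroes the cotangent for invalid entries, and the inner where keeps log away from negative inputs, so neither side of the product produces 0 * nan.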
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.0
jaxlib: 0.6.0
numpy: 2.0.2
python: 3.11.10 (main, Oct 3 2024, 07:29:13) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='panda', release='5.15.167.4-microsoft-standard-WSL2', version='#1 SMP Tue Nov 5 00:21:55 UTC 2024', machine='x86_64')
$ nvidia-smi
Sun Jun 1 10:51:19 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.55.01              Driver Version: 576.40         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   67C    P0             49W /  120W |    2113MiB /   8192MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |