
JAX jacobian returns unexpected NaN #29152

Open
@PcdPanda

Description

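import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  # the float64 output below implies x64 mode was enabled

def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))

res = jax.jacobian(my_func)(jnp.array([6, 7, 8], dtype=jnp.float64))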
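For reference (a derivation added here, not part of the original report): `my_func` acts elementwise on `arg[:3]`, so with a length-3 input the chain rule gives a diagonal Jacobian. Indices are 0-based to match `arg[:3]`, and the numeric values are rounded to four decimals.

```latex
\begin{aligned}
f_i(x) &= x_i \log(\cos x_i), \qquad i = 0, 1, 2, \\
\frac{\partial f_i}{\partial x_j} &=
\begin{cases}
\log(\cos x_i) - x_i \tan x_i, & i = j, \\
0, & i \neq j,
\end{cases} \\
\left.\frac{\partial f_0}{\partial x_0}\right|_{x_0 = 6} &\approx -0.0406 + 1.7460 = 1.7054, \\
\left.\frac{\partial f_1}{\partial x_1}\right|_{x_1 = 7} &\approx -0.2825 - 6.1001 = -6.3826, \\
\cos 8 &\approx -0.1455 < 0 .
\end{aligned}
```

Since cos 8 < 0, f_2 has no real value at x_2 = 8, so only the (2, 2) entry is undefined; the (0, 2) and (1, 2) entries are exactly 0. The two finite diagonal values agree with the diagonal entries in the output of this report.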

jax.jacobian returns:

Array([[ 1.70539252, -0.        ,         nan],
       [ 0.        , -6.38262843,         nan],
       [ 0.        , -0.        ,         nan]], dtype=float64)
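One plausible explanation (my reading of the output, not a confirmed answer from the JAX team): `jax.jacobian` is an alias of `jax.jacrev`, so each row i is a vector-Jacobian product with a one-hot cotangent. The VJP of the product `arg[:3] * jnp.log(jnp.cos(arg[:3]))` multiplies the cotangent by `log(cos(arg))`, which is NaN for the third element. Because 0 * NaN is NaN under IEEE 754 arithmetic, the zero cotangent does not mask it, and the whole third column becomes NaN. The sketch below reproduces row 0 with `jax.vjp`; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def my_func(arg):
    return arg[:3] * jnp.log(jnp.cos(arg[:3]))

x = jnp.array([6.0, 7.0, 8.0], dtype=jnp.float64)

# Forward pass: cos(8) < 0, so log(cos(8)) and the third output are NaN.
y, vjp_fn = jax.vjp(my_func, x)
print(y)

# Row 0 of the reverse-mode Jacobian is the VJP with the one-hot cotangent e_0.
(row0,) = vjp_fn(jnp.array([1.0, 0.0, 0.0]))
# The third entry is NaN: the zero cotangent gets multiplied by log(cos(8)) = NaN.
print(row0)

# The underlying IEEE 754 rule: zero times NaN is NaN, not zero.
print(jnp.float64(0.0) * jnp.nan)
```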

But the NaN entries in the first two rows should be 0, because the first two outputs do not depend on the third input.
Could you please take a look?
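If the goal is just the mathematically expected Jacobian for this particular function, one workaround (a sketch, assuming the function stays purely elementwise over the whole input, as it is here with a length-3 input) is to differentiate the scalar per-element function with `jax.vmap(jax.grad(...))` and place the results on the diagonal. No cross-element cotangent products are formed, so the NaN from the third element cannot leak into the other rows. The names `scalar_f` and `elementwise_jacobian` are hypothetical helpers, not JAX APIs.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def scalar_f(t):
    # Per-element version of my_func: t * log(cos(t)) for a scalar t.
    return t * jnp.log(jnp.cos(t))

def elementwise_jacobian(f, x):
    # Hypothetical helper: valid only when output i depends on input i alone,
    # in which case the Jacobian is diagonal with entries f'(x_i).
    return jnp.diag(jax.vmap(jax.grad(f))(x))

x = jnp.array([6.0, 7.0, 8.0], dtype=jnp.float64)
J = elementwise_jacobian(scalar_f, x)
# Off-diagonal entries are exact zeros from jnp.diag; only J[2, 2] is NaN,
# because log(cos(8)) is undefined over the reals (cos(8) < 0).
print(J)
```

This only works because the function is elementwise. For functions that mix elements, masking invalid inputs before the log with `jnp.where` (the pattern discussed in the JAX FAQ on NaN gradients) is the more general approach.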

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.6.0
jaxlib: 0.6.0
numpy:  2.0.2
python: 3.11.10 (main, Oct  3 2024, 07:29:13) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='panda', release='5.15.167.4-microsoft-standard-WSL2', version='#1 SMP Tue Nov 5 00:21:55 UTC 2024', machine='x86_64')


$ nvidia-smi
Sun Jun  1 10:51:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.55.01              Driver Version: 576.40         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   67C    P0             49W /  120W |    2113MiB /   8192MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Labels: question (Questions for the JAX team)