This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

backend.item in MPS calculation is incompatible with autograd in jax #959

Open
SUSYUSTC opened this issue Jan 21, 2022 · 2 comments


@SUSYUSTC

In the file https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py,
line 319 (res.append(self.backend.item(result.tensor))) and line 479 (return [self.backend.item(o) for o in c])
both call self.backend.item, which is incompatible with automatic differentiation in JAX (and possibly with other backends too).
I haven't checked other files, so they might have similar issues.
Here's a simple example:

import tensornetwork as tn
import numpy as np
import jax
tn.set_default_backend('jax')
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))


def func(x):
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    e = mps.measure_local_operator([gate], [0])
    return e[0]


print(func(1.0))                 # output: (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
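The failure comes from .item() itself: it turns a 0-d array into a Python scalar, which requires a concrete value and cannot carry a gradient, while jax.value_and_grad runs func on traced values. Separately, jax.grad only accepts real-valued outputs, so the reproducer would also need to differentiate the real part of e[0] once item is gone. Below is a minimal pure-JAX sketch of the pattern and of one possible fix (keep the contraction result as a 0-d array instead of calling item); the single-site state psi and the names expectation_with_item and expectation_traceable are illustrative assumptions, not TensorNetwork code.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import numpy as np

Z = jnp.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))
# Illustrative single-site state |+> = (|0> + |1>) / sqrt(2), standing in for the MPS
psi = jnp.array([1.0, 1.0], dtype=jnp.complex64) * (2 ** -0.5)


def expectation_with_item(x):
    # Same pattern as base_mps.py: contract to a 0-d tensor, then call .item().
    # A Python scalar cannot carry a gradient, so this breaks under jax.grad
    # (AttributeError with the JAX version used in this issue).
    gate = jax.scipy.linalg.expm(Z * x)
    return jnp.vdot(psi, gate @ psi).item()


def expectation_traceable(x):
    # Possible fix: return the 0-d array itself instead of a Python scalar.
    gate = jax.scipy.linalg.expm(Z * x)
    return jnp.vdot(psi, gate @ psi).reshape(())


# jax.grad requires a real-valued output, so differentiate the real part.
value, grad = jax.value_and_grad(lambda x: jnp.real(expectation_traceable(x)))(1.0)
print(value, grad)  # <+|exp(xZ)|+> = cosh(x), so the derivative is sinh(x)
```

In the library, the analogous change would be to return result.tensor (and the entries of c) directly rather than passing them through self.backend.item; callers that really need Python numbers can still convert them outside of any traced function.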
@mganahl
Contributor

mganahl commented Jan 21, 2022

Hi, and thanks for the message!
Can you post the full error message as well? Thanks!

@SUSYUSTC
Author


The full output is:

(0.7063742876052856-1.4842953532934189e-08j)
Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 2313, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/core.py", line 568, in __getattr__
    attr = getattr(self.aval, name)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ConcreteArray' object has no attribute 'item'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
AttributeError: 'ConcreteArray' object has no attribute 'item'
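The traceback confirms that the failing call is tensor.item() inside the JAX backend's item, reached from measure_local_operator. Until the library keeps these results as arrays, one user-side workaround is to compute the local expectation value from the site tensors directly with jnp.einsum, so no item call is involved. The sketch below assumes each site tensor has shape (D_left, d, D_right) with boundary bond dimension 1 and applies the operator as <s|O|t>; local_expectation is a hypothetical helper, and whether FiniteMPS exposes its tensors (e.g. as mps.tensors) in exactly this layout is an assumption to check. It normalizes explicitly instead of relying on a canonical form.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import numpy as np


def local_expectation(tensors, op, site):
    """Return <psi|op_site|psi> / <psi|psi> as a 0-d array (no .item()).

    Hypothetical helper: assumes each site tensor has shape (D_left, d, D_right)
    and that the first and last bond dimensions are 1.
    """
    dtype = tensors[0].dtype
    env_op = jnp.ones((1, 1), dtype=dtype)    # left environment with op inserted
    env_norm = jnp.ones((1, 1), dtype=dtype)  # left environment for the norm
    for n, A in enumerate(tensors):
        eye = jnp.eye(A.shape[1], dtype=dtype)
        O = op if n == site else eye
        # Bra tensor (a, s, c) is conjugated, ket tensor is (b, t, d); O is <s|O|t>.
        env_op = jnp.einsum('ab,asc,st,btd->cd', env_op, A.conj(), O, A)
        env_norm = jnp.einsum('ab,asc,st,btd->cd', env_norm, A.conj(), eye, A)
    # Both environments end with shape (1, 1); keep the ratio as a 0-d array.
    return (env_op / env_norm).reshape(())


Z = jnp.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))
# Illustrative 4-site product state |++++>, each site tensor of shape (1, 2, 1)
plus = jnp.array([1.0, 1.0], dtype=jnp.complex64).reshape(1, 2, 1) * (2 ** -0.5)
tensors = [plus] * 4


def func(x):
    gate = jax.scipy.linalg.expm(Z * x)
    # jax.grad needs a real scalar output
    return jnp.real(local_expectation(tensors, gate, 0))


value, grad = jax.value_and_grad(func)(1.0)
print(value, grad)  # cosh(1) and sinh(1) for this product state
```

The product state is chosen only so the result can be checked by hand; the same contraction applies to random site tensors such as those in the original reproducer.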
