-
Notifications
You must be signed in to change notification settings - Fork 706
[WIP] Bump jax 0.7.2 #8604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[WIP] Bump jax 0.7.2 #8604
Conversation
- Remove custom_partial_eval_rule from _capture_qnode.py which was causing finite-diff to fail with JAX 0.7.2 - Document all 4 patches in jax_patches.py as REQUIRED (verified via testing) - Remove unnecessary xfails in finite-diff and JVP tests - All capture tests pass (1583/1583)
- Comment out external-libraries-tests from all-tests-passed and upload-reports needs - The test job itself already has conditional skip logic - This allows CI to pass without waiting for external library tests
| - core-tests | ||
| - all-interfaces-tests | ||
| - external-libraries-tests | ||
| # - external-libraries-tests # Skipped for JAX 0.7.2 migration |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
temp skip until 0.7.2 is fully compatible with catalyst
Co-authored-by: Hong-Sheng Zheng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part is also needed due to the new Literal in jax 0.7.2
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #8604 +/- ##
==========================================
- Coverage 99.43% 96.40% -3.03%
==========================================
Files 587 588 +1
Lines 61879 62162 +283
==========================================
- Hits 61529 59927 -1602
- Misses 350 2235 +1885 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Somethings should also be updated:
def _tuple_to_slice(t):
"""Convert a tuple representation of a slice back to a slice object.
JAX converts slice objects to tuples for hashability in jaxpr parameters.
This function converts them back to slice objects for use with indexing.
Args:
t: Either a slice object (returned as-is) or a tuple (start, stop, step)
Returns:
slice: A slice object
"""
if isinstance(t, tuple) and len(t) == 3:
return slice(*t)
return t
def _is_dict_like_tuple(t):
"""Checks if a tuple t is structured like a list of (key, value) pairs."""
return isinstance(t, tuple) and all(isinstance(item, tuple) and len(item) == 2 for item in t)
def _tuple_to_dict(t):
"""
Recursively converts JAX-hashable tuple representations back to dicts,
and list-like tuples back to lists.
Args:
t: The item to convert. Can be a dict, a tuple, or a scalar.
Returns:
The converted dict, list, or the original scalar value.
"""
if not isinstance(t, (dict, tuple, list)):
return t
if isinstance(t, dict):
return {k: _tuple_to_dict(v) for k, v in t.items()}
if isinstance(t, list):
return [_tuple_to_dict(item) for item in t]
if isinstance(t, tuple):
# A. Dict-like tuple: Convert to dict, then recurse on values
if _is_dict_like_tuple(t):
# This handles the main (key, value) pair structure
return {key: _tuple_to_dict(value) for key, value in t}
# B. List-like tuple: Convert to list, then recurse on elements
else:
return [_tuple_to_dict(item) for item in t]
return t
@ExpandTransformsInterpreter.register_primitive(_create_transform_primitive())
def _(
self, *invals, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs, transform
): # pylint: disable=too-many-arguments
args = invals[_tuple_to_slice(args_slice)]
consts = invals[_tuple_to_slice(consts_slice)]
targs = invals[_tuple_to_slice(targs_slice)]
tkwargs = _tuple_to_dict(tkwargs)
def wrapper(*inner_args):
return copy(self).eval(inner_jaxpr, consts, *inner_args)
jaxpr = jax.make_jaxpr(wrapper)(*args)
jaxpr = transform.plxpr_transform(jaxpr.jaxpr, jaxpr.consts, targs, tkwargs, *args)
return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *args)
length = qml.math.size(state_vector)
global_phase = qml.math.sum(-1 * qml.math.angle(state_vector) / length)
self.queue[op_hash] = [coeff, summand]And cannot compare tracer with bool: if coeff == 1: |
Context:
It seems that jax 0.7.2 has some "good" feature that we don't need const var anymore.
Description of the Change:
Benefits:
Possible Drawbacks:
Related GitHub Issues: