Enable more efficient batching over multiple (co)tangent vectors #244

@jpbrodrick89

Summary

Allow multiple (co)tangent vectors to be supplied at once, so that the corresponding JVPs/VJPs can be calculated efficiently with a single HTTP request to the Tesseract.

Why is this needed?

It is very common to need a JVP or VJP with respect to multiple (co)tangent vectors, for example when calculating a Jacobian to perform a Newton solve. While the jacobian endpoint already exists, it is never called by our open source pipeline builder tesseract-jax; instead, either vector_jacobian_product (for jax.jacrev, which jax.jacobian aliases) or jacobian_vector_product (for jax.jacfwd) is called multiple times sequentially. This is inefficient for several reasons:

  • Parallelisation/SIMD is not possible, as the calls are all made synchronously.
  • We pay the latency of the HTTP request (and any other per-call overhead) once per (co)tangent vector.
  • We cannot take advantage of tricks, such as those jax.vmap would provide, that avoid unnecessary re-computation of primals on the forward pass.
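
To illustrate the last point in plain JAX, outside of any Tesseract: the sketch below compares calling jax.jvp once per tangent with linearizing once and mapping the resulting linear function over a stack of tangents. The function f and the inputs are hypothetical stand-ins, not part of tesseract-jax.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for the function a Tesseract differentiates.
    return jnp.sin(x) * x[::-1]


x = jnp.arange(1.0, 4.0)       # primal input, shape (3,)
tangents = jnp.eye(3)[:2]      # two tangent vectors, shape (2, 3)

# Sequential: one jvp call per tangent; the primal is evaluated every time.
sequential = jnp.stack([jax.jvp(f, (x,), (t,))[1] for t in tangents])

# Batched: evaluate the primal once, then vmap the linear map over tangents.
_, f_jvp = jax.linearize(f, x)
batched = jax.vmap(f_jvp)(tangents)   # shape (2, 3)
```

Both routes give the same tangent outputs; the batched one evaluates the primal once and leaves JAX free to vectorise over the leading axis.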

There are alternatives. Allowing async calls with pmap and num_workers > 1 should be possible and would fix the first two issues. Re-writing the batching rule to call jax.jacobian and calculate a matrix-matrix product could also work. However, enabling batched inputs to jacobian_vector_product and vector_jacobian_product would be more efficient in most cases (e.g. if the number of (co)tangent vectors is > 1 and << n).
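
As a rough sketch of that trade-off on the VJP side, again in plain JAX with a hypothetical f: the full-Jacobian route needs one reverse pass per output element before the matrix-matrix product, whereas a batched VJP needs only one pullback per cotangent vector.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function mapping 4 inputs to 3 outputs.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]), x[3] ** 2 + x[0]])


x = jnp.array([1.0, 2.0, 3.0, 4.0])
cotangents = jnp.array([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 1.0]])   # two cotangents, shape (2, 3)

# Option A: build the full Jacobian, then take a matrix-matrix product.
J = jax.jacrev(f)(x)                 # shape (3, 4)
via_jacobian = cotangents @ J        # shape (2, 4)

# Option B: one forward pass, then a pullback batched over the cotangents.
_, f_vjp = jax.vjp(f, x)
(via_vjp,) = jax.vmap(f_vjp)(cotangents)   # shape (2, 4)
```

The two results agree; which is cheaper depends on how the number of (co)tangent vectors compares with the size of the Jacobian.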

Usage example

My proposal is for jacobian_vector_product and vector_jacobian_product to accept (at least one) additional dimension on the (co)tangent vector inputs. We may decide to insist that the additional dimension is added to all elements; if not, the remaining elements could be expanded to match. The Tesseract would inspect the shape of the (co)tangent vector, work out whether an additional dimension has been added, and if so call a batched version of itself.
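
A minimal sketch of the shape-based dispatch I have in mind, written in plain JAX for a single array input. The helper jvp_maybe_batched and the function f are hypothetical names for illustration only; this is not the existing Tesseract or tesseract-jax API.

```python
import jax
import jax.numpy as jnp


def jvp_maybe_batched(f, x, tangent):
    """Hypothetical helper: dispatch on whether tangent has a leading batch axis."""
    _, f_jvp = jax.linearize(f, x)   # primal is evaluated once
    if tangent.shape == x.shape:
        # Same shape as the primal: a single tangent vector.
        return f_jvp(tangent)
    if tangent.ndim == x.ndim + 1 and tangent.shape[1:] == x.shape:
        # One extra leading dimension: treat it as a batch of tangents.
        return jax.vmap(f_jvp)(tangent)
    raise ValueError(
        f"tangent shape {tangent.shape} is incompatible with input shape {x.shape}"
    )


def f(x):
    # Hypothetical stand-in for a Tesseract's apply function.
    return jnp.sin(x) * x


x = jnp.array([0.5, 1.0, 1.5])
single = jvp_maybe_batched(f, x, jnp.ones(3))        # shape (3,)
batch = jvp_maybe_batched(f, x, jnp.ones((2, 3)))    # shape (2, 3)
```

With several inputs, each (co)tangent element would need the same check, which is where the question of requiring the extra dimension on all elements comes in.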
