A "lazy" / "meta" implementation of the array api? #777
-
This would be super useful indeed. I suspect it's not a small amount of work. For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea. That's probably one of the hairiest parts to do, and a good start. But correctly doing all shape calculations for all functions in the API is also a large job. Perhaps others know of reusable functionality elsewhere for this? For PyTorch, I believe it's too baked into the library to be reusable standalone.
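To illustrate what "meta" indexing means in practice, here is a minimal sketch that infers the result shape of basic indexing (ints, slices, `None` and a single `Ellipsis`) without touching any data. `index_shape` is a hypothetical helper, not ndindex's API; it follows Python's slice semantics and deliberately leaves out advanced (integer-array and boolean) indexing, which is where much of the hairiness lives.

```python
def index_shape(shape, index):
    """Infer the shape of x[index] for basic indexing, without any data.

    Handles ints, slices, None and at most one Ellipsis; boolean and
    integer-array (advanced) indexing are out of scope for this sketch.
    """
    if not isinstance(index, tuple):
        index = (index,)
    n_ellipsis = sum(1 for i in index if i is Ellipsis)
    if n_ellipsis > 1:
        raise IndexError("an index can only have a single ellipsis")
    # Entries that consume an axis of the input (everything but None and ...).
    n_consuming = sum(1 for i in index if i is not None and i is not Ellipsis)
    if n_consuming > len(shape):
        raise IndexError("too many indices")
    fill = (slice(None),) * (len(shape) - n_consuming)
    if n_ellipsis:
        # Expand the Ellipsis into as many full slices as needed.
        pos = index.index(Ellipsis)
        index = index[:pos] + fill + index[pos + 1:]
    else:
        # Missing trailing indices behave like full slices.
        index = index + fill

    out = []
    dims = iter(shape)
    for i in index:
        if i is None:
            out.append(1)  # None inserts a new axis of length 1
        elif isinstance(i, slice):
            n = next(dims)
            # slice.indices clips start/stop/step to the axis length, and the
            # length of the equivalent range is the output length.
            out.append(len(range(*i.indices(n))))
        elif isinstance(i, int) and not isinstance(i, bool):
            n = next(dims)
            if not -n <= i < n:
                raise IndexError(f"index {i} is out of bounds for size {n}")
            # An integer index drops the axis, so nothing is appended.
        else:
            raise TypeError(f"unsupported index: {i!r}")
    return tuple(out)
```

For example, `index_shape((5, 6, 7), (Ellipsis, None, slice(None, None, 2)))` returns `(5, 6, 1, 4)`.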
-
Thanks a lot @rgommers for the response!
I only partly agree, because there is nothing particularly difficult about it. The expected behavior is well defined and the API is already specified, so there is no tricky code to be figured out. It is just a matter of implementing the already defined behavior "diligently".
This is indeed a great start; I was not aware of this project.
I think the effort could actually be limited, because looking at https://github.com/numpy/numpy/tree/main/numpy/array_api, the files already group the API into operations with the same shape-computation behavior, i.e. element-wise, indexing, searching, statistical, etc. For each group the behavior only needs to be defined once; the rest is filling in boilerplate code. In addition there are broadcasting and indexing, which always apply. I'm less sure about the dtype promotion, but this must have been coded somewhere already as well.
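On the dtype promotion point: the standard defines promotion within a kind (signed integers, unsigned integers, floating point) and between mixed signed/unsigned integers, and leaves mixed integer/floating-point promotion unspecified. A minimal sketch of those rules for the real numeric dtypes could look like the following. `promote_dtypes` is a hypothetical helper (not the standard's `result_type`), dtypes are plain strings for simplicity, and complex dtypes and `bool`-with-numeric pairs are simply not promoted.

```python
# Bit widths for the integer and floating-point dtypes covered here.
_SIGNED = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
_UNSIGNED = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}
_FLOAT = {"float32": 32, "float64": 64}


def promote_dtypes(a, b):
    """Promote two dtype names; return None for pairs this sketch does not
    promote (e.g. uint64 with a signed integer, or integer with float)."""
    if a == b:
        return a
    for table in (_SIGNED, _UNSIGNED, _FLOAT):
        if a in table and b in table:
            # Same kind: promote to the wider of the two.
            return a if table[a] >= table[b] else b
    if a in _UNSIGNED and b in _SIGNED:
        a, b = b, a  # normalise so that a is signed and b is unsigned
    if a in _SIGNED and b in _UNSIGNED:
        # A signed type holding every value of the unsigned one needs twice
        # the unsigned width (uint8 + int8 -> int16); beyond 64 bits: no dtype.
        bits = max(_SIGNED[a], 2 * _UNSIGNED[b])
        return f"int{bits}" if bits <= 64 else None
    return None  # mixed kinds, bool with numeric, or unknown names
```

Pairs like `uint16` with `int8` come out as `int32`, while `uint64` with any signed integer yields `None`, matching the gaps in the standard's integer table.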
I agree PyTorch is already too large a dependency. From a quick search I only found https://github.com/NeuralEnsemble/lazyarray, which seems to be unmaintained. It also takes a different approach: building a graph and then delaying the evaluation.
-
I'd like to get a better idea of the actual implementation effort and just share some more thoughts on this idea.
-
Is this not a valid shape? The spec for
-
Indeed this is already part of the spec. I missed that before!
-
Actually I'd be interested in starting a repo and playing around with this a bit. Any preference for a name, @rgommers or @lucascolley? What about
-
@adonath Xarray has an internal version of exactly this, which actually gets used by default every time you use xarray to open data from disk without dask installed. We have had an open issue about the idea of lifting this functionality out into a separate library for years: pydata/xarray#5081. Our lazy indexing classes aren't very well publicised, but there are some docs here, and the implementation is all in this file. The key object is our … The one (major) limitation is that our implementation only supports indexing, not its dual, concatenation, so it would eagerly compute if you tried to call …
This would be amazing! If there is interest from other parties in creating this package I think xarray would be very interested to collaborate & could have a lot to offer.
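The core idea behind lazy indexing of this kind can be sketched in one dimension with plain Python ranges: successive slices are composed into a single selection and the data is only read at the end. `LazySlicedArray1D` is a hypothetical name and this is not xarray's implementation, which handles many dimensions and backend arrays; the sketch only shows the composition idea.

```python
class LazySlicedArray1D:
    """Hypothetical lazily sliced 1-D wrapper (not xarray's actual API).

    Slicing only composes index ranges; the wrapped data is read once,
    when materialize() is called.
    """

    def __init__(self, data, positions=None):
        self._data = data
        # Positions of the wrapped data that the current view selects.
        self._positions = range(len(data)) if positions is None else positions

    @property
    def shape(self):
        return (len(self._positions),)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError("this sketch only supports slices")
        # Slicing a range yields another range, so composing any number of
        # slices is O(1) and never touches the underlying data.
        return LazySlicedArray1D(self._data, self._positions[key])

    def materialize(self):
        return [self._data[i] for i in self._positions]
```

Here `LazySlicedArray1D(list(range(100)))[10:90][::3][5:]` reports shape `(22,)` before any element is read.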
-
(I got an email saying @dcherian posted this comment, which I now can't find)
https://autoray.readthedocs.io/en/latest/lazy_computation.html seems relevant.
---
(my reply)
Investigating the computational graph, including cost and memory usage,
of a calculation ahead of time.
This is also interesting because it's the same basic insight that Cubed is
based around, i.e. that if you know the shape and dtype before evaluation, you also
know the size, and hence the memory usage.
https://github.com/cubed-dev/cubed
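That insight is simple arithmetic: the number of elements is the product of the shape, multiplied by the item size of the dtype. A minimal sketch, assuming NumPy-style item sizes (including one byte per `bool`); `nbytes` and `ITEMSIZE` are hypothetical names, not Cubed's API:

```python
import math

# Item sizes in bytes, assuming NumPy-style storage (one byte per bool).
ITEMSIZE = {
    "bool": 1,
    "int8": 1, "uint8": 1,
    "int16": 2, "uint16": 2,
    "int32": 4, "uint32": 4, "float32": 4,
    "int64": 8, "uint64": 8, "float64": 8,
    "complex64": 8, "complex128": 16,
}


def nbytes(shape, dtype):
    """Bytes needed to hold an array of this shape and dtype.

    math.prod(()) is 1, so a 0-D array counts as a single element.
    """
    return math.prod(shape) * ITEMSIZE[dtype]
```

For instance, `nbytes((100_000, 100_000), "float64")` is 80_000_000_000 bytes (80 GB), which a dry run can flag before anything is allocated.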
-
We recently open-sourced our ONNX-backed lazy implementation of the array API, ndonnx. Being ONNX-backed means that it is 100% lazy yet provides full type and shape inference for every operation.
-
Hey - I just saw this. For what it's worth, JAX is built around the concept of abstract evaluation, which I think is what is being described here: a way to compute output shapes and dtypes without doing any computation. JAX v0.4.32 will also include Array API support in its default namespace. So if you'd like a way to do this kind of lazy/abstract evaluation over array API implementations using an existing package, you can use `jax.eval_shape`, for example:

```python
In [1]: import jax

In [2]: def f(x):
   ...:     xp = x.__array_namespace__()  # built-in in JAX v0.4.32 or newer
   ...:     return xp.outer(x, x[:-1])
   ...:

In [3]: x = jax.numpy.arange(4)

In [4]: jax.eval_shape(f, x)
Out[4]: ShapeDtypeStruct(shape=(4, 3), dtype=int32)
```
-
Thanks everyone for the numerous and diverse comments! The initial motivation for this issue was to potentially initiate work on a dedicated package; however, it became obvious that there have already been multiple efforts from the community to implement similar functionality. Many are mentioned in this discussion thread; here is a short summary:
It seems all options are more or less compatible with the array API, and those that are not yet will be soon. At this point I think there is no strong motivation for another implementation, as potential users can simply choose among the existing options. Of course there is the question of unifying or consolidating efforts, but the functionality is typically bound to the actual array implementation, especially with regard to device handling and additional functionality not supported by the array API. The abstract evaluation step is also typically where a computational graph is built for compilation, which again is tightly coupled to the actual implementation of the array API. So it does not make sense for existing packages to switch to an independent implementation. I think the only remaining motivation for an independent package might be the minimal-dependency, standalone approach. I would keep the discussion open for now so that others can comment and add new options; for now it serves more as an entry point for users.
-
In addition to the already available implementations of the array API, I think it could be interesting to have a lazy / meta implementation of the standard. What I mean is a small, standalone library with minimal dependencies, compatible with the array API, that infers the shape and dtype of resulting arrays without ever allocating the data or executing any FLOPs.
PyTorch already has something like this with the "meta" device. For example:
However, this misses the device handling, for example, because the device is constrained to "meta". I presume Dask must have something very similar. JAX must also have something very similar for jitted computations; however, I think it is only exposed to users through a different API, `jax.eval_shape()`, and not via a "meta" array object. Similarly to the torch example, one would use a hypothetical library `lazy_array_api`:
The use case I have in mind is mostly debugging, validation and testing of computationally intensive algorithms ("dry runs"). For now I just wanted to share the idea and bring it up for discussion.
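The "dry run" idea can be sketched in pure Python. `MetaArray` and `broadcast_shapes` are hypothetical names (not `lazy_array_api` or any existing library, although NumPy's `numpy.broadcast_shapes` plays a similar role to the helper); only a few element-wise operators and reductions are covered, and mixed dtypes are simply rejected. It shows how one shape rule per group of functions covers every function in that group:

```python
import math


def broadcast_shapes(*shapes):
    """Array API broadcasting: align trailing axes; size-1 axes stretch."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


class MetaArray:
    """Hypothetical data-free array that tracks only shape and dtype."""

    def __init__(self, shape, dtype="float64"):
        self.shape = tuple(shape)
        self.dtype = dtype

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def size(self):
        return math.prod(self.shape)

    def __repr__(self):
        return f"MetaArray(shape={self.shape}, dtype={self.dtype})"

    # Element-wise group: a single broadcasting rule serves every operator.
    def _elementwise(self, other):
        if isinstance(other, MetaArray):
            if other.dtype != self.dtype:
                # Simplification: a fuller version would consult the
                # standard's type promotion table here.
                raise TypeError("mixed dtypes are not handled in this sketch")
            other_shape = other.shape
        else:
            other_shape = ()  # Python scalars broadcast like 0-D arrays
        return MetaArray(broadcast_shapes(self.shape, other_shape), self.dtype)

    __add__ = __radd__ = __sub__ = __rsub__ = _elementwise
    __mul__ = __rmul__ = _elementwise

    # Statistical group: a single rule serves sum, max, min, ...
    def _reduce(self, axis=None, keepdims=False):
        if axis is None:
            axis = tuple(range(self.ndim))
        elif isinstance(axis, int):
            axis = (axis,)
        for a in axis:
            if not -self.ndim <= a < self.ndim:
                raise ValueError(f"axis {a} is out of bounds for ndim {self.ndim}")
        reduced = {a % self.ndim for a in axis}
        shape = tuple(
            1 if i in reduced else n
            for i, n in enumerate(self.shape)
            if keepdims or i not in reduced
        )
        return MetaArray(shape, self.dtype)

    sum = max = min = _reduce
```

With `x = MetaArray((1000, 1, 50))` and `y = MetaArray((200, 1))`, the expression `(x * y + 1.0).sum(axis=-1)` evaluates to `MetaArray(shape=(1000, 200), dtype=float64)` without allocating the 1000 × 200 × 50 intermediate.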