Skip to content

Commit 1eb5c3f

Browse files
committed
try kernels
1 parent 2ab00f7 commit 1eb5c3f

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/awkward/_kernels.py

+7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from awkward._nplikes.numpy import Numpy
1313
from awkward._nplikes.numpy_like import NumpyMetadata
1414
from awkward._nplikes.typetracer import try_touch_data
15+
from awkward._nplikes.virtual import materialize_if_virtual
1516
from awkward._typing import Protocol, TypeAlias
1617

1718
KernelKeyType: TypeAlias = tuple # Tuple[str, Unpack[Tuple[metadata.dtype, ...]]]
@@ -88,6 +89,8 @@ def _cast(cls, x, t):
8889
def __call__(self, *args) -> None:
8990
assert len(args) == len(self._impl.argtypes)
9091

92+
args = materialize_if_virtual(*args)
93+
9194
return self._impl(
9295
*(self._cast(x, t) for x, t in zip(args, self._impl.argtypes))
9396
)
@@ -97,6 +100,8 @@ class JaxKernel(NumpyKernel):
97100
def __call__(self, *args) -> None:
98101
assert len(args) == len(self._impl.argtypes)
99102

103+
args = materialize_if_virtual(*args)
104+
100105
if not any(Jax.is_tracer_type(type(arg)) for arg in args):
101106
return super().__call__(*args)
102107

@@ -138,6 +143,8 @@ def _cast(self, x, type_):
138143
def __call__(self, *args) -> None:
139144
import awkward._connect.cuda as ak_cuda
140145

146+
args = materialize_if_virtual(*args)
147+
141148
cupy = ak_cuda.import_cupy("Awkward Arrays with CUDA")
142149
maxlength = self.max_length(args)
143150
grid, blocks = self.calc_grid(maxlength), self.calc_blocks(maxlength)

src/awkward/_nplikes/numpy.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from awkward._nplikes.dispatch import register_nplike
99
from awkward._nplikes.numpy_like import NumpyMetadata
1010
from awkward._nplikes.placeholder import PlaceholderArray
11+
from awkward._nplikes.virtual import VirtualArray
1112
from awkward._typing import TYPE_CHECKING, Final, Literal
1213

1314
if TYPE_CHECKING:
@@ -48,7 +49,8 @@ def is_own_array_type(cls, type_: type) -> bool:
4849
return issubclass(type_, numpy.ndarray)
4950

5051
def is_c_contiguous(self, x: NDArray | PlaceholderArray) -> bool:
51-
if isinstance(x, PlaceholderArray):
52+
# TODO: What should this do for virtual arrays?
53+
if isinstance(x, (PlaceholderArray, VirtualArray)):
5254
return True
5355
else:
5456
return x.flags["C_CONTIGUOUS"] # type: ignore[attr-defined]

0 commit comments

Comments
 (0)