1212from awkward ._nplikes .numpy import Numpy
1313from awkward ._nplikes .numpy_like import NumpyMetadata
1414from awkward ._nplikes .typetracer import try_touch_data
15+ from awkward ._nplikes .virtual import materialize_if_virtual
1516from awkward ._typing import Protocol , TypeAlias
1617
1718KernelKeyType : TypeAlias = tuple # Tuple[str, Unpack[Tuple[metadata.dtype, ...]]]
@@ -88,6 +89,8 @@ def _cast(cls, x, t):
8889 def __call__ (self , * args ) -> None :
8990 assert len (args ) == len (self ._impl .argtypes )
9091
92+ args = materialize_if_virtual (* args )
93+
9194 return self ._impl (
9295 * (self ._cast (x , t ) for x , t in zip (args , self ._impl .argtypes ))
9396 )
@@ -97,6 +100,8 @@ class JaxKernel(NumpyKernel):
97100 def __call__ (self , * args ) -> None :
98101 assert len (args ) == len (self ._impl .argtypes )
99102
103+ args = materialize_if_virtual (* args )
104+
100105 if not any (Jax .is_tracer_type (type (arg )) for arg in args ):
101106 return super ().__call__ (* args )
102107
@@ -138,6 +143,8 @@ def _cast(self, x, type_):
138143 def __call__ (self , * args ) -> None :
139144 import awkward ._connect .cuda as ak_cuda
140145
146+ args = materialize_if_virtual (* args )
147+
141148 cupy = ak_cuda .import_cupy ("Awkward Arrays with CUDA" )
142149 maxlength = self .max_length (args )
143150 grid , blocks = self .calc_grid (maxlength ), self .calc_blocks (maxlength )
0 commit comments