12
12
from awkward ._nplikes .numpy import Numpy
13
13
from awkward ._nplikes .numpy_like import NumpyMetadata
14
14
from awkward ._nplikes .typetracer import try_touch_data
15
+ from awkward ._nplikes .virtual import materialize_if_virtual
15
16
from awkward ._typing import Protocol , TypeAlias
16
17
17
18
KernelKeyType : TypeAlias = tuple # Tuple[str, Unpack[Tuple[metadata.dtype, ...]]]
@@ -88,6 +89,8 @@ def _cast(cls, x, t):
88
89
def __call__ (self , * args ) -> None :
89
90
assert len (args ) == len (self ._impl .argtypes )
90
91
92
+ args = materialize_if_virtual (* args )
93
+
91
94
return self ._impl (
92
95
* (self ._cast (x , t ) for x , t in zip (args , self ._impl .argtypes ))
93
96
)
@@ -97,6 +100,8 @@ class JaxKernel(NumpyKernel):
97
100
def __call__ (self , * args ) -> None :
98
101
assert len (args ) == len (self ._impl .argtypes )
99
102
103
+ args = materialize_if_virtual (* args )
104
+
100
105
if not any (Jax .is_tracer_type (type (arg )) for arg in args ):
101
106
return super ().__call__ (* args )
102
107
@@ -138,6 +143,8 @@ def _cast(self, x, type_):
138
143
def __call__ (self , * args ) -> None :
139
144
import awkward ._connect .cuda as ak_cuda
140
145
146
+ args = materialize_if_virtual (* args )
147
+
141
148
cupy = ak_cuda .import_cupy ("Awkward Arrays with CUDA" )
142
149
maxlength = self .max_length (args )
143
150
grid , blocks = self .calc_grid (maxlength ), self .calc_blocks (maxlength )
0 commit comments