
Commit 55a7540

Add an option to not use dlpack. (#9304)
Parent: 20899c7

File tree: 15 files changed (+119 −94 lines)

benchmarks/util.py
Lines changed: 2 additions & 2 deletions

@@ -72,10 +72,10 @@ def is_xla_device_available(devkind, use_xla2: bool = False):
 
 def move_to_device(item, device, torch_xla2: bool = False):
   if torch_xla2:
-    import torch_xla2
+    import torchax
     import jax
     move_to_device_func = lambda t: jax.device_put(
-        torch_xla2.tensor.t2j(t), device)
+        torchax.default_env().t2j_copy(t), device)
   else:
 
     def move_to_device_func(tensor: torch.Tensor) -> torch.Tensor:
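For context, the replacement API can be exercised on its own. A minimal sketch of the new conversion path (assumes torchax and jax are installed; not part of this commit):

import torch
import torchax
import jax

env = torchax.default_env()
t = torch.randn(4, 4)
arr = env.t2j_copy(t)  # torch.Tensor -> jax.Array, always by copy
moved = jax.device_put(arr, jax.devices()[0])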

torchax/examples/train_llama/utils.py
Lines changed: 2 additions & 1 deletion

@@ -200,7 +200,8 @@ def _shard_fsdp_style(self, state_dict, sharding=None):
       sharding = self.x_sharding
 
     def move_one_tensor(x):
-      jval = torchax.tensor.t2j(x)
+      env = torchax.default_env()
+      jval = env.t2j_copy(x)
       return sharded_device_put(jval, sharding)
 
     if isinstance(state_dict, torch.Tensor):
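A self-contained sketch of the same pattern outside the trainer (the mesh and axis names here are illustrative, and sharded_device_put from the example file is replaced with plain jax.device_put, which also accepts a Sharding):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import torch
import torchax

env = torchax.default_env()
mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))
sharding = NamedSharding(mesh, PartitionSpec('fsdp'))

def move_one_tensor(x: torch.Tensor):
  jval = env.t2j_copy(x)  # copy torch weights into a jax.Array
  return jax.device_put(jval, sharding)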

torchax/test/gemma/test_gemma.py
Lines changed: 2 additions & 1 deletion

@@ -71,7 +71,8 @@ def test_gemma(self):
     )
 
     weights, jax_func = torchax.extract_jax(model)
-    inputs_jax = pytree.tree_map_only(torch.Tensor, torchax.tensor.t2j, inputs)
+    env = torchax.default_env()
+    inputs_jax = env.t2j_copy(inputs)
 
     import jax
     print(jax.jit(jax_func)(weights, inputs_jax))
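Note that env.t2j_copy maps over a whole pytree of tensors, which is what lets it replace the explicit pytree.tree_map_only call here. A small sketch with a hypothetical input dict (not the gemma inputs):

import torch
import torchax

env = torchax.default_env()
inputs = {'ids': torch.ones(1, 8, dtype=torch.long), 'mask': torch.ones(1, 8)}
inputs_jax = env.t2j_copy(inputs)  # same dict structure, jax.Array leaves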

torchax/test/llama/test_llama.py
Lines changed: 3 additions & 4 deletions

@@ -88,8 +88,8 @@ def make_cache(args, batch_size):
       m_prefill = torch.export.export(m, sample_input_prefill)
 
     weights, mj_prefill = torchax.export.exported_program_to_jax(m_prefill)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_prefill)
+    env = torchax.default_env()
+    sample_inputs = env.t2j_copy(sample_input_prefill)
     print('Prefill', mj_prefill(weights, sample_inputs))
 
     sample_input_decode = (

@@ -103,8 +103,7 @@ def make_cache(args, batch_size):
     with torch.no_grad():
       m_decode = torch.export.export(m, sample_input_decode)
     weights, mj_decode = torchax.export.exported_program_to_jax(m_decode)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_decode)
+    sample_inputs = env.t2j_copy(sample_input_decode)
     print('Decode', mj_decode(weights, sample_inputs))

torchax/test/moe/moe_test.py
Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def test_moe_layer(self):
     x_xla = env.to_xla(x)
     with jax.default_matmul_precision('float32'):
       res_xla = model_xla(x_xla)
-    res2 = torchax.tensor.j2t(res_xla._elem)
+    res2 = res_xla.to('cpu')
     print('max diff', torch.max((res - res2).abs()))
 
     self.assertTrue(torch.allclose(res2, res, atol=1e-2))
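The .to('cpu') call above uses the ordinary torch device-move API on a torchax Tensor instead of reaching into the private _elem field. A minimal sketch of the round trip (assumes a default jax backend is available):

import torch
import torchax

with torchax.default_env():
  x = torch.randn(2, 2, device='jax')  # torchax.tensor.Tensor
  res = x.to('cpu')                    # plain torch.Tensor; replaces j2t(x._elem)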

torchax/test/test_context.py
Lines changed: 14 additions & 21 deletions

@@ -47,8 +47,7 @@ def test_same_manual_seed(self):
       y = torch.randn((3, 3))
       self.assertIsInstance(y, tensor.Tensor)
 
-    self.assertTrue(
-        torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))
+    self.assertTrue(torch.allclose(x, y))
 
   def test_different_manual_seed(self):
     with xla_env:

@@ -60,36 +59,30 @@ def test_different_manual_seed(self):
       y = torch.randn((3, 3))
       self.assertIsInstance(y, tensor.Tensor)
 
-    self.assertFalse(
-        torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))
+    self.assertFalse(torch.allclose(x, y))
 
   def test_jit_with_rng(self):
 
-    @xla_env
-    def random_op():
-      x = torch.randn(3, 3)
-      y = torch.randn(3, 3)
-      return x @ y
+    with xla_env:
+
+      def random_op():
+        x = torch.randn(3, 3)
+        y = torch.randn(3, 3)
+        return x @ y
 
-    random_jit = torchax.interop.jax_jit(random_op)
-    self.assertIsInstance(random_jit(), tensor.Tensor)
+      random_jit = torchax.interop.jax_jit(random_op)
+      self.assertIsInstance(random_jit(), tensor.Tensor)
 
-    # Result always expected to be the same for a jitted function because seeds
-    # are baked in
-    torch.testing.assert_close(
-        torchax.tensor.j2t(random_jit()._elem),
-        torchax.tensor.j2t(random_jit()._elem),
-        atol=0,
-        rtol=0)
+      # Result always expected to be the same for a jitted function because seeds
+      # are baked in
+      torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)
 
   def test_generator_seed(self):
     with xla_env:
       x = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
       y = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
 
-    # Values will be different, but still check device, layout, dtype, etc
-    torch.testing.assert_close(
-        torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem))
+    torch.testing.assert_close(x, y)
 
   def test_buffer(self):
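The behavior this test pins down can be reproduced standalone; a sketch, assuming xla_env is the env returned by torchax.default_env():

import torch
import torchax
import torchax.interop

env = torchax.default_env()
with env:

  def random_op():
    return torch.randn(3, 3) @ torch.randn(3, 3)

  random_jit = torchax.interop.jax_jit(random_op)
  # The RNG seed is baked in at trace time, so two calls agree exactly.
  torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)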

torchax/test/test_conv.py
Lines changed: 6 additions & 4 deletions

@@ -55,24 +55,26 @@ def forward(self, x):
 class ConvTest(base_test_util.TestCase):
 
   def test_conv1(self):
+    env = torchax.default_env()
     m = CustomConv1()
     arg = torch.randn((20, 1, 50))
     res = m(arg)
 
     jax_weights, jax_func = torchax.extract_jax(m)
-    arg = torchax.tensor.t2j(arg)
+    arg = env.t2j_copy(arg)
     res2 = jax_func(jax_weights, (arg,))
-    res2_torch = torchax.tensor.j2t(res2)
+    res2_torch = env.j2t_copy(res2)
     self.assertTrue(torch.allclose(res, res2_torch))
 
   def test_conv2(self):
+    env = torchax.default_env()
     m = CustomConv2()
     arg = torch.randn((20, 4, 50, 100))
     res = m(arg)
     jax_weights, jax_func = torchax.extract_jax(m)
-    arg = torchax.tensor.t2j(arg)
+    arg = env.t2j_copy(arg)
     res2 = jax_func(jax_weights, (arg,))
-    res2_torch = torchax.tensor.j2t(res2)
+    res2_torch = env.j2t_copy(res2)
     self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))
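The round trip these conv tests rely on, reduced to its core (hypothetical Linear module; tolerances as in test_conv2):

import torch
import torchax

env = torchax.default_env()
m = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
res = m(x)

jax_weights, jax_func = torchax.extract_jax(m)
res2 = env.j2t_copy(jax_func(jax_weights, (env.t2j_copy(x),)))
assert torch.allclose(res, res2, atol=1e-4, rtol=1e-4)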

torchax/test/test_exports.py
Lines changed: 7 additions & 5 deletions

@@ -43,12 +43,14 @@ def test_interpolate(self):
     model = Interpolate()
     ans = model(*arg)
 
+    env = torchax.default_env()
+
     with torch.no_grad():
       exported = torch.export.export(model, arg)
       weights, func = torchax.export.exported_program_to_jax(exported)
-      argj = tensor.t2j(arg[0])
+      argj = env.t2j_copy(arg[0])
       ans2 = jax.jit(func)(weights, (argj,))[0]
-      ans2 = tensor.j2t(ans2)
+      ans2 = env.j2t_copy(ans2)
       self.assertTrue(torch.allclose(ans, ans2, atol=1e-3))
 
     # Convert to StableHLO

@@ -67,11 +69,11 @@ def test_constant(self):
 
     with torch.no_grad():
       exported = torch.export.export(model, arg)
-
+      env = torchax.default_env()
       weights, func = torchax.export.exported_program_to_jax(exported)
-      argj = tensor.t2j(arg[0])
+      argj = env.t2j_copy(arg[0])
       ans2 = jax.jit(func)(weights, (argj,))[0]
-      ans2 = tensor.j2t(ans2)
+      ans2 = env.j2t_copy(ans2)
       self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))
 
     # Convert to StableHLO

torchax/test/test_functions.py
Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ def test_tensor_constructor(self, arg, kwargs=None):
     actual = torch.tensor(arg, device='jax', **kwargs)
     self.assertIsInstance(actual, torchax.tensor.Tensor)
 
-    torch.testing.assert_close(torchax.tensor.j2t(actual._elem), expected)
+    torch.testing.assert_close(actual.to('cpu'), expected)
 
   def test_dont_capture_conversion(self):
     t = torch.tensor([1, 2, 3])

@@ -86,7 +86,7 @@ def test_rms_norm(self):
     model.to('jax')
     x = x.to('jax')
     res2 = model(x)
-    self.assertTrue(torch.allclose(res, torchax.tensor.j2t(res2.jax())))
+    self.assertTrue(torch.allclose(res, res2.to('cpu')))
 
   def test_randn_requires_grad(self):
     x = torch.randn((3, 3), requires_grad=True, device='jax')
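A sketch of the constructor-plus-readback pattern used above (assumes the default env is active so device='jax' is routed to torchax):

import torch
import torchax

expected = torch.tensor([1.0, 2.0])
with torchax.default_env():
  actual = torch.tensor([1.0, 2.0], device='jax')  # torchax.tensor.Tensor
  torch.testing.assert_close(actual.to('cpu'), expected)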

torchax/test/test_image.py
Lines changed: 4 additions & 11 deletions

@@ -10,14 +10,6 @@
 import torchax.interop
 
 
-def to_xla_tensor(tensorstree):
-  return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))
-
-
-def to_torch_tensor(tensorstree):
-  return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))
-
-
 @partial(jax.jit, static_argnums=(1, 2, 3, 4))
 def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool,
                  antialias: bool, method: str):

@@ -53,8 +45,9 @@ def test_resampling_combinations_bicubic(self, antialias, align_corners):
         align_corners=align_corners,
         antialias=antialias)
 
-    with torchax.default_env():
-      input_tensor_xla = to_xla_tensor(input_tensor)
+    env = torchax.default_env()
+    with env:
+      input_tensor_xla = env.to_xla(input_tensor)
       input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
       upsampled_tensor_xla = upsample_jit(
          input_tensor_xla,

@@ -63,7 +56,7 @@ def test_resampling_combinations_bicubic(self, antialias, align_corners):
           antialias=antialias,
           method=method)
 
-      upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
+      upsampled_tensor_xla = env.j2t_copy(upsampled_tensor_xla)
       abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)
 
     assert torch.allclose(
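The two local helpers can be deleted because the environment now covers both directions itself; a sketch of the replacement round trip (names illustrative):

import torch
import torchax
import torchax.interop

env = torchax.default_env()
x = torch.randn(1, 3, 8, 8)
with env:
  x_xla = env.to_xla(x)                    # torchax.tensor.Tensor
  x_jax = torchax.interop.jax_view(x_xla)  # underlying jax.Array view
  # ... call a jax.jit'ed function on x_jax here ...
  back = env.j2t_copy(x_jax)               # copy back to a torch.Tensor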

torchax/torchax/__init__.py
Lines changed: 3 additions & 2 deletions

@@ -43,9 +43,10 @@ def extract_jax(mod: torch.nn.Module, env=None):
   """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
-  states = mod.state_dict()
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+  states = env.t2j_copy(states)
 
   #@jax.jit
   def jax_func(states, inputs):
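One visible effect of switching from mod.state_dict() to named_buffers() plus named_parameters() is that non-persistent buffers, which state_dict() omits, are now included in the extracted pytree. Typical usage is unchanged; a short sketch with a toy module:

import torch
import torchax
import jax

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
weights, jax_func = torchax.extract_jax(model)  # weights: pytree of jax arrays

env = torchax.default_env()
inputs = (env.t2j_copy(torch.randn(2, 4)),)
print(jax.jit(jax_func)(weights, inputs))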

torchax/torchax/config.py
Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,9 @@ class Configuration:
   # If true, we will convert Views into torchax.Tensors eagerly
   force_materialize_views: bool = False
 
+  # Use DLPack for converting jax.Arrays <-> torch.Tensors
+  use_dlpack_for_data_conversion: bool = False
+
   # Flash attention
   use_tpu_flash_attention: bool = False
   shmap_flash_attention: bool = False
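This hunk only declares the flag; the call sites that consume it are not shown in this excerpt. A hypothetical consumption site (env.config and the helper name below are assumptions, not code from this commit) could thread the flag into the converters from mappings.py:

import torch
from torch.utils import _pytree as pytree
from torchax.ops import mappings

def t2j_copy_sketch(env, values):
  # Hypothetical wiring: setting use_dlpack_for_data_conversion=False
  # would force the numpy copy path inside mappings.t2j below.
  use_dlpack = env.config.use_dlpack_for_data_conversion
  return pytree.tree_map_only(
      torch.Tensor, lambda t: mappings.t2j(t, use_dlpack=use_dlpack), values)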

torchax/torchax/export.py
Lines changed: 5 additions & 4 deletions

@@ -4,8 +4,9 @@
 from typing import Any, Dict, Tuple
 import torch
 from torch.utils import _pytree as pytree
+import torchax
 from torchax import tensor
-from torchax.ops import ops_registry
+from torchax.ops import ops_registry, mappings
 from torchax import decompositions
 import jax
 import jax.export

@@ -108,8 +109,8 @@ def func(states, inputs):
 
   if export_raw:
     return names, states, func
-
-  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+  env = torchax.default_env()
+  states = env.t2j_copy(states)
   return states, func
 

@@ -135,7 +136,7 @@ def _get_dim(d):
 
   tensor_meta = arg_meta['tensor_meta']
   shape = [_get_dim(d) for d in tensor_meta.shape]
-  return jax.ShapeDtypeStruct(shape, tensor.t2j_dtype(tensor_meta.dtype))
+  return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 
 def _get_inputs(exported):
   """Return placeholders with input metadata"""

torchax/torchax/ops/mappings.py
Lines changed: 28 additions & 9 deletions

@@ -7,7 +7,7 @@
 import torch.utils._mode_utils as mode_utils
 
 
-def t2j(t):
+def t2j(t, use_dlpack=True):
   is_bool = False
   if t.dtype == torch.bool:
     is_bool = True

@@ -18,9 +18,14 @@ def t2j(t):
   if not t.is_contiguous():
     t = t.contiguous()
 
-  try:
-    res = jaxdl.from_dlpack(t)
-  except Exception:
+  res = None
+  if use_dlpack:
+    try:
+      res = jaxdl.from_dlpack(t)
+    except Exception:
+      pass
+
+  if res is None:
     # https://github.com/google/jax/issues/7657
     # https://github.com/google/jax/issues/17784
     if t.dtype == torch.bfloat16:

@@ -37,15 +42,29 @@ def t2j(t):
   return res
 
 
-def j2t(x):
+def j2t(x, use_dlpack=True):
   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-    try:
-      dl = jaxdl.to_dlpack(x)
-      res = torchdl.from_dlpack(dl)
-    except Exception:
+    res = None
+    if use_dlpack:
+      try:
+        dl = jaxdl.to_dlpack(x)
+        res = torchdl.from_dlpack(dl)
+      except Exception:
+        res = None
+
+    orig_dtype = None
+    if res is None:
+      orig_dtype = None
+      if x.dtype == jnp.bfloat16.dtype:
+        orig_dtype = x.dtype
+        x = x.astype(jnp.float32.dtype)
       res = torch.from_numpy(numpy.asarray(x))
+
     if x.dtype == jnp.bool_:
       res = res.to(torch.bool)
+
+    if orig_dtype is not None:
+      res = res.to(j2t_dtype(orig_dtype))
   return res
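With the new keyword, callers can opt out of DLPack explicitly; a short sketch of both paths (the numpy fallback also handles the bfloat16 round trip through float32, as shown above):

import torch
from torchax.ops import mappings

t = torch.randn(3, 3)
a = mappings.t2j(t)                    # tries DLPack first (zero-copy when it works)
b = mappings.t2j(t, use_dlpack=False)  # forces the numpy-based copy path

x = mappings.j2t(b, use_dlpack=False)  # jax.Array -> torch.Tensor, by copy
assert torch.allclose(t, x)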
