Commit 0be0fbd

Improve assume_pure SPMD functionality (#9360)
1 parent 55a7540 commit 0be0fbd

File tree: 4 files changed, +123 -12 lines changed

test/test_assume_pure_spmd.py

Lines changed: 79 additions & 1 deletion
```diff
@@ -1,13 +1,26 @@
+from copy import copy
 import os
 import sys
 import unittest
 
 import numpy as np
 import torch
+import torch.nn as nn
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.experimental.assume_pure import assume_pure
+from torch_xla.experimental.assume_pure import PureModule, assume_pure
 from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, Mesh
+from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
+
+
+def get_2d_mesh(name1: str, name2: str):
+  num_devices = xr.global_runtime_device_count()
+  dim1_size = 2
+  assert num_devices % 2 == 0
+  dim2_size = num_devices // dim1_size
+  devices = np.arange(xr.global_runtime_device_count())
+  mesh_shape = (dim1_size, dim2_size)
+  return Mesh(devices, mesh_shape=mesh_shape, axis_names=(name1, name2))
 
 
 class AssumePureSpmdTest(unittest.TestCase):
@@ -56,6 +69,44 @@ def test_assume_pure_works_with_mark_sharding_with_gradients(self):
     self.assertIn(f'devices=[{N}',
                   torch_xla._XLAC._get_xla_sharding_spec(x.grad))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding_nested(self):
+    mesh = get_2d_mesh("model", "batch")
+    set_global_mesh(mesh)
+    x = torch.randn((8, 4, 5, 128), device='xla')
+    result = assume_pure(mark_sharding)(x, mesh,
+                                        (("model", "batch"), None, None, None))
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self):
+    mesh = get_2d_mesh("model", "batch")
+    set_global_mesh(mesh)
+    x = torch.randn((8, 4, 5, 128)).to('xla').requires_grad_(True)
+    result = assume_pure(mark_sharding_with_gradients)(
+        x, mesh, (("model", "batch"), None, None, None))
+    result.sum().backward()
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+    assert x.grad is not None
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(x.grad))
+
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
   @unittest.skipIf(
@@ -94,6 +145,33 @@ def test_convert_to_jax_mesh_shuffled(self):
         np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
     )
 
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
+  def test_pure_module(self):
+    """Test tracing `nn.Linear` and `EinsumLinear` with `assume_pure`."""
+    for transform in [apply_xla_patch_to_nn_linear, lambda x: x]:
+      with torch_xla.device():
+        # Arrange
+        original = nn.Linear(4, 8)
+        replaced = PureModule(transform(copy(original)))
+        inputs = torch.ones((4,))
+        torch_xla.sync()
+
+        # Act
+        original_output = original(inputs)
+        original_output.sum().backward()
+        replaced_output = replaced(inputs)
+        replaced_output.sum().backward()
+        torch_xla.sync()
+
+        # Assert
+        torch.testing.assert_close(original_output, replaced_output)
+        torch.testing.assert_close(original.weight.grad,
+                                   replaced._module.weight.grad)
+        torch.testing.assert_close(original.bias.grad,
+                                   replaced._module.bias.grad)
+
 
 if __name__ == '__main__':
   test = unittest.main()
```
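For context, the nested partition specs exercised by the new tests shard a single tensor dimension across more than one mesh axis. Below is a minimal sketch of that usage outside the test harness; it assumes an even number of XLA devices, and the shapes and axis names simply mirror the tests above (nothing here is part of the commit itself).

```python
# Minimal sketch, assuming an even number of XLA devices (e.g. a TPU slice).
# Shapes and axis names mirror the tests above; this is illustrative only.
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh, mark_sharding, set_global_mesh
from torch_xla.experimental.assume_pure import assume_pure

num_devices = xr.global_runtime_device_count()
mesh = Mesh(
    np.arange(num_devices),
    mesh_shape=(2, num_devices // 2),
    axis_names=('model', 'batch'))
set_global_mesh(mesh)

x = torch.randn((8, 4, 5, 128), device='xla')
# Dim 0 is sharded across both the 'model' and 'batch' axes; the rest replicate.
x = assume_pure(mark_sharding)(x, mesh, (('model', 'batch'), None, None, None))
torch_xla.sync(wait=True)
```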

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -644,9 +644,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   tx = maybe_get_torchax()
   if tx is not None and isinstance(t, tx.tensor.Tensor):
     from jax.sharding import PartitionSpec as P, NamedSharding
-    op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
     jmesh = mesh.get_jax_mesh()
-    t.shard_(NamedSharding(jmesh, P(*op_sharding)))
+    t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t
 
   op_sharding = mesh.get_op_sharding(partition_spec)
@@ -986,8 +985,9 @@ def apply_xla_patch_to_nn_linear(module: torch.nn.Module):
   for name, child in module.named_children():
     if isinstance(child,
                   torch.nn.Linear) and not isinstance(child, EinsumLinear):
-      einsum_linear = EinsumLinear(
-          child.in_features, child.out_features, bias=child.bias is not None)
+      with torch.device('meta'):
+        einsum_linear = EinsumLinear(
+            child.in_features, child.out_features, bias=child.bias is not None)
       einsum_linear.load_state_dict(
           child.state_dict(), strict=True, assign=True)
       setattr(module, name, einsum_linear)
```
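A note on the `mark_sharding` change: `jax.sharding.PartitionSpec` already accepts a tuple of mesh axis names per tensor dimension, so stringifying each entry would mangle nested specs such as `('model', 'batch')`. A small sketch of the kind of spec that now passes through unchanged (axis names are illustrative):

```python
# Sketch only: PartitionSpec takes a tuple of mesh axis names per dimension,
# which is what forwarding partition_spec directly preserves.
from jax.sharding import PartitionSpec as P

spec = P(('model', 'batch'), None, None, None)  # dim 0 sharded over both axes
```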

torch_xla/experimental/assume_pure.py

Lines changed: 30 additions & 3 deletions
```diff
@@ -1,9 +1,9 @@
-from copy import copy
-from functools import wraps
+from functools import wraps, partial
 from typing import Dict
 
 import torch
-from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
+import torch.nn as nn
+from torch.utils._pytree import tree_flatten, tree_unflatten
 import torch_xla
 from torch_xla._internal.jax_workarounds import requires_jax
 import torch_xla.core.xla_builder as xb
@@ -48,6 +48,33 @@ def j2t_autograd(fn):
       fn, call_jax=lambda fn, *args: xb.call_jax(fn, args))
 
 
+class PureModule(nn.Module):
+  """Wraps a module whose forward pass is known to be free of side-effects and whose
+  behavior only depends on the inputs.
+
+  It behaves as if decorating the wrapped module's functionalized forward pass with `@assume_pure`.
+
+  This wrapper has a few advantages over the underlying module:
+  - `PureModule`s will only be traced once.
+  - Framework profile scopes added via `xp.Trace` will show up in both the forward
+    and the backward pass.
+  """
+
+  def __init__(self, module: nn.Module) -> None:
+    super().__init__()
+    self._module = module
+    self._pure_forward = assume_pure(partial(_pure_forward, self._module))
+
+  def forward(self, *args, **kwargs):
+    params = dict(self._module.named_parameters())
+    buffers = dict(self._module.named_buffers())
+    return self._pure_forward(params, buffers, args, kwargs)
+
+
+def _pure_forward(module, params, buffers, args, kwargs):
+  return torch.func.functional_call(module, (params, buffers), args, kwargs)
+
+
 def make_fake_inputs(input):
   """Creates a fake input for the given input torch tensor. If the input
   is not a tensor, it returns the input as is.
```
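For reference, a minimal usage sketch of the new `PureModule` wrapper on an XLA device; the layer and input sizes are illustrative and not taken from the commit.

```python
# Minimal sketch: wrap a side-effect-free module so its functionalized forward
# pass is traced through `assume_pure`. Sizes here are illustrative.
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.assume_pure import PureModule

with torch_xla.device():
  layer = PureModule(nn.Linear(128, 64))
  out = layer(torch.randn(8, 128))
  out.sum().backward()  # grads accumulate on layer._module.weight / .bias
torch_xla.sync()
```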

torchax/torchax/ops/jaten.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -1346,7 +1346,7 @@ def reduce_fn(a, b):
 try:
 
   @op(torch.ops.xla.max_pool2d_forward)
-  def _xla_max_pool2d_foward(*args, **kwargs):
+  def _xla_max_pool2d_forward(*args, **kwargs):
     return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
 
   @op(torch.ops.xla.aot_mark_sharding)
@@ -1357,11 +1357,17 @@ def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str):
     pmesh = xs.Mesh.from_str(mesh)
     assert pmesh is not None
     partition_spec_eval = ast.literal_eval(partition_spec)
-    op_sharding = tuple(
-        str(i) if i is not None else i for i in partition_spec_eval)
     jmesh = pmesh.get_jax_mesh()
     return jax.lax.with_sharding_constraint(
-        t, NamedSharding(jmesh, P(*op_sharding)))
+        t, NamedSharding(jmesh, P(*partition_spec_eval)))
+
+  @op(torch.ops.xla.einsum_linear_forward)
+  def _xla_einsum_linear_forward(input, weight, bias):
+    with jax.named_scope('einsum_linear_forward'):
+      product = jax.numpy.einsum('...n,mn->...m', input, weight)
+      if bias is not None:
+        return product + bias
+      return product
 
 except AttributeError:
   pass
```
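The new `einsum_linear_forward` lowering uses the contraction `'...n,mn->...m'`, which is the usual linear-layer product `input @ weight.T`, matching `torch.nn.Linear` semantics. A quick sketch of that equivalence (array shapes are illustrative):

```python
# Sketch: '...n,mn->...m' contracts input features against weight rows,
# i.e. input @ weight.T, the same product nn.Linear computes.
import jax.numpy as jnp
import numpy as np

x = np.random.randn(8, 4).astype(np.float32)   # (..., in_features)
w = np.random.randn(16, 4).astype(np.float32)  # (out_features, in_features)
np.testing.assert_allclose(
    jnp.einsum('...n,mn->...m', x, w), x @ w.T, rtol=1e-5)
```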
