
Commit f33f6e8

Add test and fix comment
1 parent c7f3893 commit f33f6e8

2 files changed: +60 -3 lines changed

New test file (+60 lines):
@@ -0,0 +1,60 @@
import unittest
import torch
import torch_xla2 as tx
import torch_xla2.export
import torch_xla2.train
from torch.testing._internal.common_utils import TestCase


class TrainTest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(0)
        torch_xla2.enable_accuracy_mode()

    def test_scan_module(self):
        x = torch.arange(300).reshape(3, 100).to(torch.float32)
        layers = [
            torch.nn.Linear(100, 100),
            torch.nn.Linear(100, 100),
            torch.nn.Linear(100, 100),
            torch.nn.Linear(100, 100),
        ]
        # repeatedly applies the linear layers
        result = x
        for layer in layers:
            result = layer(result)

        model = tx.train.ScannedModule(layers)

        with torch_xla2.default_env():
            x = x.to('jax')
            model.to('jax')
            result2 = model(x)
            torch.testing.assert_allclose(result, result2.to('cpu'))

    def test_train_step_can_run(self):
        import optax
        with torch_xla2.default_env():
            model = torch.nn.Linear(100, 100)
            model.to('jax')
            weights = model.state_dict()
            x = torch.randn(2, 100).to('jax')
            y = torch.tensor([1, 2]).to('jax')

            def model_fn(weight, buffers, args):
                return torch.func.functional_call(model, weight, args)

            loss_fn = torch.nn.CrossEntropyLoss()

            optimizer = optax.adam(0.01)
            opt_state = tx.interop.call_jax(optimizer.init, weights)

            step = tx.train.make_train_step(model_fn, loss_fn, optimizer)
            print(step(weights, {}, opt_state, x, y))


if __name__ == '__main__':
    unittest.main()

experimental/torch_xla2/torch_xla2/train.py (-3 lines):
@@ -30,9 +30,6 @@ def make_train_step(model_fn,
     optax_optimizer: the optimizer from optax library. for example, optax.adam
     remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
       to do gradient checkpointing. If None, then it means checkpoint everything.
-    mark_fsdp_sharding_axis: str. A string name for marking sharding for
-      fsdp. It must be an axis that exists in the current mesh.
-      if None, then no sharding is specified (i.e. for single device)
   """
   env = torch_xla2.default_env()
   def loss(weights, buffers, args, label): # inputs are XLATensor
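
For context, a minimal sketch of driving make_train_step with an explicit remat policy, assuming the signature implied by the docstring above (model_fn, loss_fn, optax_optimizer, remat_policy). The remat_policy keyword and the choice of jax.checkpoint_policies.nothing_saveable are illustrative assumptions; the rest mirrors test_train_step_can_run from the new test file.

import jax
import optax
import torch
import torch_xla2 as tx
import torch_xla2.interop
import torch_xla2.train

with tx.default_env():
    model = torch.nn.Linear(100, 100)
    model.to('jax')
    weights = model.state_dict()

    def model_fn(weight, buffers, args):
        # Functional call so the step can differentiate w.r.t. the weight pytree.
        return torch.func.functional_call(model, weight, args)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = optax.adam(0.01)
    opt_state = tx.interop.call_jax(optimizer.init, weights)

    # Assumption: remat_policy is accepted as a keyword, per the docstring;
    # nothing_saveable recomputes all activations during the backward pass.
    step = tx.train.make_train_step(
        model_fn, loss_fn, optimizer,
        remat_policy=jax.checkpoint_policies.nothing_saveable)

    x = torch.randn(2, 100).to('jax')
    y = torch.tensor([1, 2]).to('jax')
    print(step(weights, {}, opt_state, x, y))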
