1
+ import unittest
2
+ import torch
3
+ import torch_xla2 as tx
4
+ import torch_xla2 .export
5
+ import torch_xla2 .train
6
+ from torch .testing ._internal .common_utils import TestCase
7
+
8
+
9
+ class TrainTest (unittest .TestCase ):
10
+
11
+ def setUp (self ):
12
+ torch .manual_seed (0 )
13
+ torch_xla2 .enable_accuracy_mode ()
14
+
15
+ def test_scan_module (self ):
16
+ x = torch .arange (300 ).reshape (3 , 100 ).to (torch .float32 )
17
+ layers = [
18
+ torch .nn .Linear (100 , 100 ),
19
+ torch .nn .Linear (100 , 100 ),
20
+ torch .nn .Linear (100 , 100 ),
21
+ torch .nn .Linear (100 , 100 ),
22
+ ]
23
+ # repetitively applies the linear
24
+ result = x
25
+ for layer in layers :
26
+ result = layer (result )
27
+
28
+ model = tx .train .ScannedModule (
29
+ layers
30
+ )
31
+
32
+ with torch_xla2 .default_env ():
33
+ x = x .to ('jax' )
34
+ model .to ('jax' )
35
+ result2 = model (x )
36
+ torch .testing .assert_allclose (result , result2 .to ('cpu' ))
37
+
38
+ def test_train_step_can_run (self ):
39
+ import optax
40
+ with torch_xla2 .default_env ():
41
+ model = torch .nn .Linear (100 , 100 )
42
+ model .to ('jax' )
43
+ weights = model .state_dict ()
44
+ x = torch .randn (2 , 100 ).to ('jax' )
45
+ y = torch .tensor ([1 , 2 ]).to ('jax' )
46
+
47
+ def model_fn (weight , buffers , args ):
48
+ return torch .func .functional_call (model , weight , args )
49
+
50
+ loss_fn = torch .nn .CrossEntropyLoss ()
51
+
52
+ optimizer = optax .adam (0.01 )
53
+ opt_state = tx .interop .call_jax (optimizer .init , weights )
54
+
55
+ step = tx .train .make_train_step (model_fn , loss_fn , optimizer )
56
+ print (step (weights , {}, opt_state , x , y ))
57
+
58
+
59
+ if __name__ == '__main__' :
60
+ unittest .main ()
0 commit comments