 from axlearn.common.utils_neuron import TestConfig
 from axlearn.common.utils_neuron import get_training_configs

-jax.config.update('jax_platform_name', 'cpu')
+# jax.config.update('jax_platform_name', 'cpu')  # Do we need this?

 MODULE_UNIT_TEST_ATOL = 1e-6
 MODULE_UNIT_TEST_RTOL = 1e-3
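Context for the change above: `jax.config.update('jax_platform_name', 'cpu')` forces the entire process onto the CPU backend, which would defeat the per-test mesh/sharding placement introduced below. If a single computation still needs to be pinned to CPU, a narrower option is `jax.default_device`; a minimal sketch, assuming a CPU backend is available (not code from this PR):

```python
# Minimal sketch: pin one computation to CPU without a process-wide override.
import jax
import jax.numpy as jnp

cpu0 = jax.devices("cpu")[0]          # assumes the CPU backend is present
with jax.default_device(cpu0):
    y = (jnp.ones((4, 4)) @ jnp.ones((4, 4))).block_until_ready()
print(y.devices())                    # {CpuDevice(id=0)} in recent JAX
```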
@@ -40,20 +40,20 @@ def _fwd_call(self, layer, state, inputs):
     @parameterized.named_parameters(get_training_configs())
     def test_fwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, device=jax.devices(cfg.test.device)[0])
-        def test_fwd_call(inputs):
-            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, inputs)
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)  # cannot specify both backend and sharding together
+        def test_fwd_call():
+            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.test_inputs)
             return test_output

-        @partial(jax.jit, device=jax.devices(cfg.golden.device)[0])
-        def golden_fwd_call(inputs):
-            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, inputs)
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_fwd_call():
+            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.golden_inputs)
             return golden_output

-        inputs_test = jax.device_put(cfg.inputs, jax.devices(cfg.test.device)[0])
-        test_output = test_fwd_call(inputs_test)
-        inputs_golden = jax.device_put(cfg.inputs, jax.devices(cfg.golden.device)[0])
-        golden_output = golden_fwd_call(inputs_golden)
+        with cfg.mesh_test:
+            test_output = test_fwd_call()
+        with cfg.mesh_golden:
+            golden_output = golden_fwd_call()

         if cfg.conv_output != None:
             test_output = cfg.conv_output(test_output)
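The hunk above replaces explicit `device=` placement and `jax.device_put` with shardings: `jax.jit` rejects `backend=`/`device=` when `in_shardings`/`out_shardings` are also given (hence the inline comment), and running each jitted call inside a `Mesh` context lets the named axes in the sharding resolve. The diff does not show how `TestConfig` builds `mesh_test`/`out_shard_test`; the construction below is an assumption for illustration, not code from the repo:

```python
# Hypothetical construction of the mesh/sharding pair consumed above; the
# field names follow this PR, but the build logic is assumed.
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()                          # CPU or Neuron devices
mesh_test = Mesh(devices, axis_names=("data",))  # 1-D device mesh
out_shard_test = NamedSharding(mesh_test, PartitionSpec("data"))
```

Under `with mesh_test:`, `PartitionSpec("data")` shards the leading output axis across the mesh's `data` axis.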
@@ -64,28 +64,28 @@ def golden_fwd_call(inputs):
     @parameterized.named_parameters(get_training_configs())
     def test_bwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend=cfg.test.device)
-        def test_bwd_call(inputs):
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)
+        def test_bwd_call():
             def loss_fn(state):
-                test_output, _ = self._fwd_call(cfg.test_layer, state, inputs)
+                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.test_inputs)
                 return cfg.loss_fn(test_output)

             loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.test_state)
-            return loss, grads
+            return loss, grads

-        @partial(jax.jit, backend=cfg.golden.device)
-        def golden_bwd_call(inputs):
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_bwd_call():
             def loss_fn(state):
-                golden_output, _ = self._fwd_call(cfg.golden_layer, state, inputs)
+                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.golden_inputs)
                 return cfg.loss_fn(golden_output)

             loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.golden_state)
             return loss, grads

-        inputs_test = jax.device_put(cfg.inputs, jax.devices(cfg.test.device)[0])
-        test_loss, test_grads = test_bwd_call(inputs_test)
-        inputs_golden = jax.device_put(cfg.inputs, jax.devices(cfg.golden.device)[0])
-        golden_loss, golden_grads = golden_bwd_call(inputs_golden)
+        with cfg.mesh_test:
+            test_loss, test_grads = test_bwd_call()
+        with cfg.mesh_golden:
+            golden_loss, golden_grads = golden_bwd_call()

         # Transfer results to CPU before comparison
         test_loss = jax.tree_map(jax.device_get, test_loss)
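A note on the backward path above: since the new jitted closures take no arguments, `jax.value_and_grad` differentiates only with respect to the explicit argument of the inner `loss_fn` (the state), while the inputs are closed over as constants. A self-contained toy version of the pattern, with illustrative names not taken from axlearn:

```python
# Toy version of the closed-over-inputs value_and_grad pattern used above.
import jax
import jax.numpy as jnp

inputs = jnp.ones((2, 3))                 # closed over, treated as constant
state = {"w": jnp.full((3, 3), 0.5)}      # stand-in for cfg.test_state

def loss_fn(state):
    return jnp.sum((inputs @ state["w"]) ** 2)

loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state)
# grads has the same pytree structure as state: {"w": ...}
```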
@@ -111,48 +111,52 @@ def _fwd_call(self, layer, state, inputs):
     @parameterized.named_parameters(get_training_configs(is_unit=True))
     def test_fwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend="cpu")
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)  # cannot specify both backend and sharding together
         def test_fwd_call():
-            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.inputs)
+            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.test_inputs)
             return test_output

-        @partial(jax.jit, backend="cpu")
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
         def golden_fwd_call():
-            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.inputs)
+            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.golden_inputs)
             return golden_output
-
-        test_output = test_fwd_call()
-        golden_output = golden_fwd_call()
+
+        with cfg.mesh_test:
+            test_output = test_fwd_call()
+        with cfg.mesh_golden:
+            golden_output = golden_fwd_call()

         if cfg.conv_output != None:
             test_output = cfg.conv_output(test_output)

         # Transfer results to CPU before comparison
         self.assertNestedAllClose(jax.device_get(test_output), jax.device_get(golden_output))
-    @parameterized.named_parameters(get_training_configs())
+    @parameterized.named_parameters(get_training_configs(is_unit=True))
     def test_bwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend="cpu")
-        def test_bwd_call(state):
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)
+        def test_bwd_call():
             def loss_fn(state):
-                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.inputs)
+                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.test_inputs)
                 return cfg.loss_fn(test_output)

-            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state)
-            return loss, grads
+            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.test_state)
+            return loss, grads

-        @partial(jax.jit, backend="cpu")
-        def golden_bwd_call(state):
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_bwd_call():
             def loss_fn(state):
-                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.inputs)
+                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.golden_inputs)
                 return cfg.loss_fn(golden_output)

-            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state)
+            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.golden_state)
             return loss, grads

-        test_loss, test_grads = test_bwd_call(cfg.test_state)
-        golden_loss, golden_grads = golden_bwd_call(cfg.golden_state)
+        with cfg.mesh_test:
+            test_loss, test_grads = test_bwd_call()
+        with cfg.mesh_golden:
+            golden_loss, golden_grads = golden_bwd_call()

         # Transfer results to CPU before comparison
         test_loss = jax.tree_map(jax.device_get, test_loss)
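One small modernization note on the final context line: `jax.tree_map` is deprecated in recent JAX releases in favor of `jax.tree_util.tree_map` (or `jax.tree.map` on the newest versions). A drop-in sketch of the same host-transfer step, over a stand-in pytree rather than the real loss/grads structures:

```python
# Equivalent host-transfer with the non-deprecated API.
import jax
import jax.numpy as jnp

grads = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}   # stand-in pytree
host_grads = jax.tree_util.tree_map(jax.device_get, grads)
```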