Commit 5e19503

WIP: Sharding for updated test framework (#13)
* Sharding for updated test framework
* Grid configs
1 parent bb3ef72 commit 5e19503
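
The diff below replaces per-device placement (`jax.jit(backend=...)` / `jax.jit(device=...)` plus `jax.device_put`) with mesh-based sharding: each jitted call now declares `out_shardings` and runs inside a `with cfg.mesh_...:` context, since (per the in-diff comment) `jax.jit` cannot take both `backend=`/`device=` and shardings at once. A minimal, self-contained sketch of that pattern; the mesh shape, the "data" axis name, and the array sizes are illustrative assumptions, not taken from this repo:

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices; "data" is an assumed axis name.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))
out_sharding = NamedSharding(mesh, PartitionSpec("data"))

# jax.jit rejects backend=/device= together with shardings, so placement is
# expressed via out_shardings and the surrounding mesh context instead.
@partial(jax.jit, out_shardings=out_sharding)
def fwd():
    # Leading axis sized to divide evenly across the mesh.
    return jnp.arange(float(8 * jax.device_count()))

with mesh:
    out = fwd()  # result is laid out according to out_sharding
print(out.sharding)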

File tree

3 files changed: +199 -104 lines

axlearn/common/mixture_of_experts_neuron_test.py

Lines changed: 45 additions & 41 deletions
@@ -20,7 +20,7 @@
 from axlearn.common.utils_neuron import TestConfig
 from axlearn.common.utils_neuron import get_training_configs

-jax.config.update('jax_platform_name', 'cpu')
+#jax.config.update('jax_platform_name', 'cpu') # Do we need this ?

 MODULE_UNIT_TEST_ATOL=1e-6
 MODULE_UNIT_TEST_RTOL=1e-3
@@ -40,20 +40,20 @@ def _fwd_call(self, layer, state, inputs):
     @parameterized.named_parameters(get_training_configs())
     def test_fwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, device=jax.devices(cfg.test.device)[0])
-        def test_fwd_call(inputs):
-            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, inputs)
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)  # cannot specify both backend and sharding together
+        def test_fwd_call():
+            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.test_inputs)
             return test_output

-        @partial(jax.jit, device=jax.devices(cfg.golden.device)[0])
-        def golden_fwd_call(inputs):
-            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, inputs)
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_fwd_call():
+            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.golden_inputs)
             return golden_output

-        inputs_test = jax.device_put(cfg.inputs, jax.devices(cfg.test.device)[0])
-        test_output = test_fwd_call(inputs_test)
-        inputs_golden = jax.device_put(cfg.inputs, jax.devices(cfg.golden.device)[0])
-        golden_output = golden_fwd_call(inputs_golden)
+        with cfg.mesh_test:
+            test_output = test_fwd_call()
+        with cfg.mesh_golden:
+            golden_output = golden_fwd_call()

         if cfg.conv_output != None:
             test_output = cfg.conv_output(test_output)
@@ -64,28 +64,28 @@ def golden_fwd_call(inputs):
     @parameterized.named_parameters(get_training_configs())
     def test_bwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend=cfg.test.device)
-        def test_bwd_call(inputs):
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)
+        def test_bwd_call():
             def loss_fn(state):
-                test_output, _ = self._fwd_call(cfg.test_layer, state, inputs)
+                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.test_inputs)
                 return cfg.loss_fn(test_output)

             loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.test_state)
             return loss, grads

-        @partial(jax.jit, backend=cfg.golden.device)
-        def golden_bwd_call(inputs):
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_bwd_call():
             def loss_fn(state):
-                golden_output, _ = self._fwd_call(cfg.golden_layer, state, inputs)
+                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.golden_inputs)
                 return cfg.loss_fn(golden_output)

             loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.golden_state)
             return loss, grads

-        inputs_test = jax.device_put(cfg.inputs, jax.devices(cfg.test.device)[0])
-        test_loss, test_grads = test_bwd_call(inputs_test)
-        inputs_golden = jax.device_put(cfg.inputs, jax.devices(cfg.golden.device)[0])
-        golden_loss, golden_grads = golden_bwd_call(inputs_golden)
+        with cfg.mesh_test:
+            test_loss, test_grads = test_bwd_call()
+        with cfg.mesh_golden:
+            golden_loss, golden_grads = golden_bwd_call()

         # Transfer results to CPU before comparison
         test_loss = jax.tree_map(jax.device_get, test_loss)
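
The comparison step above pulls both pytrees to host memory before asserting closeness. A small self-contained sketch of that device-to-host hop, with dummy pytrees standing in for the real losses and grads and the module's MODULE_UNIT_TEST_ATOL/RTOL values as tolerances:

import jax
import jax.numpy as jnp
import numpy as np

# Dummy grad pytrees standing in for test_grads / golden_grads above.
test_grads = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
golden_grads = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

# Device arrays -> host numpy, leaf by leaf, then an elementwise tolerance check.
test_host = jax.tree_map(jax.device_get, test_grads)
golden_host = jax.tree_map(jax.device_get, golden_grads)
jax.tree_map(
    lambda a, b: np.testing.assert_allclose(a, b, atol=1e-6, rtol=1e-3),
    test_host,
    golden_host,
)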
@@ -111,48 +111,52 @@ def _fwd_call(self, layer, state, inputs):
     @parameterized.named_parameters(get_training_configs(is_unit=True))
     def test_fwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend="cpu")
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)  # cannot specify both backend and sharding together
         def test_fwd_call():
-            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.inputs)
+            test_output, _ = self._fwd_call(cfg.test_layer, cfg.test_state, cfg.test_inputs)
             return test_output

-        @partial(jax.jit, backend="cpu")
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
         def golden_fwd_call():
-            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.inputs)
+            golden_output, _ = self._fwd_call(cfg.golden_layer, cfg.golden_state, cfg.golden_inputs)
             return golden_output

-        test_output = test_fwd_call()
-        golden_output = golden_fwd_call()
+        with cfg.mesh_test:
+            test_output = test_fwd_call()
+        with cfg.mesh_golden:
+            golden_output = golden_fwd_call()

         if cfg.conv_output != None:
             test_output = cfg.conv_output(test_output)

         # Transfer results to CPU before comparison
         self.assertNestedAllClose(jax.device_get(test_output), jax.device_get(golden_output))

-    @parameterized.named_parameters(get_training_configs())
+    @parameterized.named_parameters(get_training_configs(is_unit=True))
     def test_bwd_correctness(self, cfg: TestConfig):

-        @partial(jax.jit, backend="cpu")
-        def test_bwd_call(state):
+        @partial(jax.jit, out_shardings=cfg.out_shard_test)
+        def test_bwd_call():
             def loss_fn(state):
-                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.inputs)
+                test_output, _ = self._fwd_call(cfg.test_layer, state, cfg.test_inputs)
                 return cfg.loss_fn(test_output)

-            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state)
+            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.test_state)
             return loss, grads

-        @partial(jax.jit, backend="cpu")
-        def golden_bwd_call(state):
+        @partial(jax.jit, out_shardings=cfg.out_shard_golden)
+        def golden_bwd_call():
             def loss_fn(state):
-                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.inputs)
+                golden_output, _ = self._fwd_call(cfg.golden_layer, state, cfg.golden_inputs)
                 return cfg.loss_fn(golden_output)

-            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state)
+            loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(cfg.golden_state)
             return loss, grads

-        test_loss, test_grads = test_bwd_call(cfg.test_state)
-        golden_loss, golden_grads = golden_bwd_call(cfg.golden_state)
+        with cfg.mesh_test:
+            test_loss, test_grads = test_bwd_call()
+        with cfg.mesh_golden:
+            golden_loss, golden_grads = golden_bwd_call()

         # Transfer results to CPU before comparison
         test_loss = jax.tree_map(jax.device_get, test_loss)
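
For orientation, here is a hypothetical sketch of the fields the updated tests read off `TestConfig`. The real definition lives in `axlearn.common.utils_neuron` (imported at the top of the file); only the fields referenced in this diff are listed, and the types are guesses, not the repo's actual signatures:

from dataclasses import dataclass
from typing import Any, Callable, Optional

from jax.sharding import Mesh, NamedSharding

@dataclass
class TestConfig:
    test_layer: Any                  # layer under test and its golden reference
    golden_layer: Any
    test_state: Any                  # parameter pytrees for each side
    golden_state: Any
    test_inputs: Any                 # per-side inputs (replace the shared cfg.inputs)
    golden_inputs: Any
    mesh_test: Mesh                  # mesh contexts the jitted calls run under
    mesh_golden: Mesh
    out_shard_test: NamedSharding    # out_shardings for the jitted calls
    out_shard_golden: NamedSharding
    loss_fn: Callable[[Any], Any]
    conv_output: Optional[Callable[[Any], Any]] = None  # optional output converter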
