@@ -274,13 +274,19 @@ def forward(self, x, target):
 
 device = "cuda"
 
-batch_size = 8
-seq_len = 4096
+batch_size = 2
+seq_len = 1024
 hidden_size = 4096
 vocab_size = 32000
+# batch_size = 2
+# seq_len = 256
+# hidden_size = 512
+# vocab_size = 1024
 dtype = torch.float32
 reduction = "mean"
 ignore_index = -100
+rtol = 1e-2
+atol = 1e-2
 
 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True)
 weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True)
@@ -300,7 +306,7 @@ def forward(self, x, target):
 ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target)
 liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target)
 
-torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1)
+torch.testing.assert_close(liger_loss, ref_loss, rtol=rtol, atol=atol)
 
 # Backward pass (backward() with reduction=="none" is not supported yet)
 if reduction == "none":
@@ -309,16 +315,16 @@ def forward(self, x, target):
 liger_loss.backward()
 ref_loss.backward()
 
-torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=1e-1, atol=1e-1)
+torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol)
 torch.testing.assert_close(
-    liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1
+    liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol
 )
 
 
 # Benchmark
-from helion._testing import run_example
 from functools import partial
 
+from helion._testing import run_example
 
 def fwd_bwd_fn(input, target, fn):
     loss = fn(input, target)
@@ -328,5 +334,5 @@ def fwd_bwd_fn(input, target, fn):
 ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce)
 
 
-run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=1e-1, atol=1e-1)
-run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=1e-1, atol=1e-1)
+run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=rtol, atol=atol)
+run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=rtol, atol=atol)
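
For context, the change hoists the comparison tolerances into shared rtol/atol variables and tightens them from 1e-1 to 1e-2, so the forward loss check, both gradient checks, and both run_example calls use one setting. A minimal, self-contained sketch of that shared-tolerance pattern, using only the public torch.testing API (the tensors here are hypothetical, not the example's real workload):

import torch

# Hypothetical shared tolerances, mirroring the diff's new rtol/atol.
rtol = 1e-2
atol = 1e-2

expected = torch.randn(8, 8)
# Perturb by a relative 1e-3, well inside rtol = 1e-2.
actual = expected * (1 + 1e-3)

# Passes: |actual - expected| <= atol + rtol * |expected| elementwise.
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

Defining the tolerances once means a later tightening or loosening is a one-line change instead of an edit to every assertion and benchmark call.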