Skip to content

Commit 30b12eb

Browse files
committed
testing misc
Signed-off-by: Tcc0403 <[email protected]>
1 parent 98de5b9 commit 30b12eb

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/liger_kernel/ops/helion/fused_linear_cross_entropy.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,19 @@ def forward(self, x, target):
274274

275275
device = "cuda"
276276

277-
batch_size = 8
278-
seq_len = 4096
277+
batch_size = 2
278+
seq_len = 1024
279279
hidden_size = 4096
280280
vocab_size = 32000
281+
# batch_size = 2
282+
# seq_len = 256
283+
# hidden_size = 512
284+
# vocab_size = 1024
281285
dtype = torch.float32
282286
reduction = "mean"
283287
ignore_index = -100
288+
rtol = 1e-2
289+
atol = 1e-2
284290

285291
input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True)
286292
weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True)
@@ -300,7 +306,7 @@ def forward(self, x, target):
300306
ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target)
301307
liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target)
302308

303-
torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1)
309+
torch.testing.assert_close(liger_loss, ref_loss, rtol=rtol, atol=atol)
304310

305311
# Backward pass (backward() with reduction=="none" is not supported yet)
306312
if reduction == "none":
@@ -309,16 +315,16 @@ def forward(self, x, target):
309315
liger_loss.backward()
310316
ref_loss.backward()
311317

312-
torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=1e-1, atol=1e-1)
318+
torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol)
313319
torch.testing.assert_close(
314-
liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1
320+
liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol
315321
)
316322

317323

318324
# Benchmark
319-
from helion._testing import run_example
320325
from functools import partial
321326

327+
from helion._testing import run_example
322328

323329
def fwd_bwd_fn(input, target, fn):
324330
loss = fn(input, target)
@@ -328,5 +334,5 @@ def fwd_bwd_fn(input, target, fn):
328334
ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce)
329335

330336

331-
run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=1e-1, atol=1e-1)
332-
run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=1e-1, atol=1e-1)
337+
run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=rtol, atol=atol)
338+
run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)