Commit 98de5b9

Unfuse forward/backward and use lock
Signed-off-by: Tcc0403 <[email protected]>
1 parent 8781e8e commit 98de5b9

File tree: 1 file changed (+116, -46 lines)
src/liger_kernel/ops/helion/fused_linear_cross_entropy.py

Lines changed: 116 additions & 46 deletions
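
The commit splits the former fused_linear_cross_entropy_fwd_bwd kernel in two: a forward kernel that returns (loss, lse) and a backward kernel that recomputes the logits tile by tile and accumulates grad_x and grad_w under per-block locks. For orientation, here is the same math in plain PyTorch: a minimal, unfused sketch with hypothetical function names and standard ignore_index semantics. Unlike the tiled kernels, it materializes the full [BT, V] logits matrix.

import torch


def reference_flce_fwd(x, weight, target, ignore_index=-100, reduction="mean"):
    logits = x.float() @ weight.float().T              # [BT, V]; the tiled kernel never builds this
    lse = torch.logsumexp(logits, dim=-1)              # [BT], returned for the backward pass
    valid = target != ignore_index
    rows = torch.arange(x.size(0), device=x.device)
    nll = torch.zeros_like(lse)
    nll[valid] = lse[valid] - logits[rows[valid], target[valid]]
    if reduction == "mean":
        return nll.sum() / valid.sum(), lse
    return (nll.sum() if reduction == "sum" else nll), lse


def reference_flce_bwd(x, weight, target, lse, ignore_index=-100, reduction="mean"):
    logits = x.float() @ weight.float().T
    grad_logits = torch.exp(logits - lse[:, None])     # softmax(logits), rebuilt from the saved logsumexp
    valid = target != ignore_index
    rows = torch.arange(x.size(0), device=x.device)
    grad_logits[rows[valid], target[valid]] -= 1.0     # softmax - one_hot(target)
    grad_logits[~valid] = 0.0                          # ignored rows contribute no gradient
    if reduction == "mean":
        grad_logits /= valid.sum()
    grad_x = grad_logits @ weight.float()              # [BT, H]
    grad_w = grad_logits.T @ x.float()                 # [V, H]
    return grad_x, grad_w

The per-row logsumexp is all the backward pass needs to turn the recomputed logits back into softmax probabilities, which is why the forward can stop producing grad_x and grad_w.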
@@ -1,12 +1,35 @@
+import math
+
 import helion
 import helion.language as hl
 import torch

 # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input
 config = helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[])

-@helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper])
-def fused_linear_cross_entropy_fwd_bwd(
+def helion_lock_acquire(lock_ptr, lock_index):
+    hl.inline_triton(
+        """
+        while tl.atomic_cas({0} + {1}, 0, 1, sem="acquire") == 1:
+            pass
+        """,
+        args=(lock_ptr, lock_index),
+        output_like=None,
+    )
+
+def helion_lock_release(lock_ptr, lock_index):
+    hl.inline_triton(
+        """
+        tl.atomic_xchg({0} + {1}, 0, sem="release")
+        """,
+        args=(lock_ptr, lock_index),
+        output_like=None,
+    )
+
+
+# @helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper])
+@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def fused_linear_cross_entropy_fwd(
     x: torch.Tensor,
     weight: torch.Tensor,
     target: torch.Tensor,
@@ -31,15 +54,15 @@ def fused_linear_cross_entropy_fwd_bwd(
     block_size_v = hl.register_block_size(V)

     nll = torch.zeros(BT, device=x.device, dtype=torch.float32)
-    grad_x = torch.zeros_like(x, dtype=torch.float32)
-    grad_w = torch.zeros_like(weight, dtype=torch.float32)
+    lse = torch.zeros(BT, device=x.device, dtype=torch.float32)

     # May be useful for splitting fwd and bwd
     # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32)
     # neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32)

     n_non_ignore = (target != ignore_index).sum().unsqueeze(0)

+    # forward
     for tile_bt in hl.tile(BT, block_size=block_size_bt):
         m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf")
         d_i = hl.zeros([tile_bt], dtype=torch.float32)
@@ -75,50 +98,86 @@ def fused_linear_cross_entropy_fwd_bwd(
             nll_tile /= n_non_ignore_value

         nll[tile_bt] = nll_tile
-        # gradients computation
-        for tile_v in hl.tile(V, block_size=block_size_v):
-            # Restore logits
-            acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32)
-            for tile_h in hl.tile(H, block_size=block_size_h):
-                x_tile = x[tile_bt, tile_h]
-                weight_tile = weight[tile_v, tile_h]
-                acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32)
-
-            # softmax(x_i) = exp(x_i) / sum(exp(x_i))
-            #              = exp(x_i) / log(exp(sum(x_i)))
-            #              = exp(x_i) / lse = exp(x_i - lse)
-            grad_logits_tile = torch.exp(acc - lse_tile[:, None])
-            offset = tile_v.index.unsqueeze(0)  # [1, tile_v]
-            mask = target_indices == offset  # [tile_bt, tile_v]
-            grad_logits_tile = grad_logits_tile - mask.float()
-            # handle out of bound values in grad_logits_tile
-            grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :])
-
-            if reduction == "mean":
-                grad_logits_tile /= n_non_ignore_value
-
-            for tile_h in hl.tile(H, block_size=block_size_h):
-                # grad_x = grad_logits @ weight
-                rhs_tile = weight[tile_v, tile_h]
-                partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32)
-                hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x)
-                # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h]
-                rhs_tile = x[tile_bt, tile_h]
-                partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32)
-                hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w)
-
+        lse[tile_bt] = lse_tile
+
     if reduction != "none":
         loss = nll.sum()
     else:
         loss = nll
+
+    return loss, lse
+
+@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def fused_linear_cross_entropy_bwd(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    target: torch.Tensor,
+    lse: torch.Tensor,
+    ignore_index: int = -100,
+    reduction: str = "mean",
+):
+    BT, H = x.size()
+    V = weight.size(0)
+    block_size_bt = hl.register_block_size(BT)
+    block_size_h = hl.register_block_size(H)
+    block_size_v = hl.register_block_size(V)
+    grad_x = torch.zeros_like(x, dtype=torch.float32)
+    grad_w = torch.zeros_like(weight, dtype=torch.float32)
+    n_non_ignore = (target != ignore_index).sum().unsqueeze(0)
+
+    num_block_bt = (BT + block_size_bt - 1) // block_size_bt
+    num_block_h = (H + block_size_h - 1) // block_size_h
+    num_block_v = (V + block_size_v - 1) // block_size_v
+    grad_x_lock = torch.zeros((num_block_bt, num_block_h), dtype=torch.int32, device=x.device)
+    grad_w_lock = torch.zeros((num_block_v, num_block_h), dtype=torch.int32, device=x.device)
+    # backward
+    for tile_bt, tile_v in hl.tile([BT, V], block_size=(block_size_bt, block_size_v)):
+        # Restore logits
+        acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32)
+        for tile_h in hl.tile(H, block_size=block_size_h):
+            x_tile = x[tile_bt, tile_h]
+            weight_tile = weight[tile_v, tile_h]
+            acc2 = hl.dot(x_tile, weight_tile.T, acc=acc2, out_dtype=torch.float32)
+
+        # softmax(x_i) = exp(x_i) / sum_j(exp(x_j))
+        #              = exp(x_i) / exp(lse), where lse = log(sum_j(exp(x_j)))
+        #              = exp(x_i - lse)
+        lse_tile = lse[tile_bt]
+        target_indices = target[tile_bt].unsqueeze(1)  # [tile_bt, 1]
+        if reduction == "mean":
+            n_non_ignore_value = hl.load(n_non_ignore, [0])
+
+        grad_logits_tile = torch.exp(acc2 - lse_tile[:, None])
+        offset = tile_v.index.unsqueeze(0)  # [1, tile_v]
+        mask = target_indices == offset  # [tile_bt, tile_v]
+        grad_logits_tile = grad_logits_tile - mask.float()
+        # handle out of bound values in grad_logits_tile
+        grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :])
+
+        if reduction == "mean":
+            grad_logits_tile /= n_non_ignore_value
+
+        for tile_h in hl.tile(H, block_size=block_size_h):
+            # grad_x = grad_logits @ weight
+            rhs_tile = weight[tile_v, tile_h]
+            partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32)
+            helion_lock_acquire(grad_x_lock, tile_bt.id * num_block_h + tile_h.id)
+            grad_x[tile_bt, tile_h] += partial_grad_x
+            helion_lock_release(grad_x_lock, tile_bt.id * num_block_h + tile_h.id)
+            # hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x)
+
+            # for tile_h in hl.tile(H, block_size=block_size_h):
+            # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h]
+            rhs_tile = x[tile_bt, tile_h]
+            partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32)
+            helion_lock_acquire(grad_w_lock, tile_v.id * num_block_h + tile_h.id)
+            grad_w[tile_v, tile_h] += partial_grad_w
+            helion_lock_release(grad_w_lock, tile_v.id * num_block_h + tile_h.id)
+            # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w)
+

-    # return format is not determined yet
-    return loss, dict(
-        {
-            "grad_x": grad_x,
-            "grad_w": grad_w,
-        }
-    )
+
+    return grad_x, grad_w


 class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function):
@@ -131,20 +190,30 @@ def forward(
         ignore_index=-100,
         reduction="mean",
     ):
-        loss, aux_output = fused_linear_cross_entropy_fwd_bwd(
+        loss, lse = fused_linear_cross_entropy_fwd(
             _input,
             weight,
             target,
             ignore_index,
             reduction,
         )
-        ctx.save_for_backward(aux_output["grad_x"], aux_output["grad_w"])
+        ctx.ignore_index = ignore_index
+        ctx.reduction = reduction
+        ctx.save_for_backward(_input, weight, target, lse)
         return loss

     @staticmethod
     def backward(ctx, grad_output):
         assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar"
-        grad_input, grad_weight = ctx.saved_tensors
+        _input, weight, target, lse = ctx.saved_tensors
+        grad_input, grad_weight = fused_linear_cross_entropy_bwd(
+            _input,
+            weight,
+            target,
+            lse,
+            ctx.ignore_index,
+            ctx.reduction,
+        )
         return grad_input * grad_output, grad_weight * grad_output, None, None, None


@@ -250,6 +319,7 @@ def forward(self, x, target):
     from helion._testing import run_example
     from functools import partial

+
     def fwd_bwd_fn(input, target, fn):
         loss = fn(input, target)
         loss.backward()
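
The new helion_lock_acquire/helion_lock_release helpers inline raw Triton. As a standalone illustration of the same idiom (the kernel name, shapes, and launch below are made up, not part of this commit), a block claims a lock word with tl.atomic_cas(..., sem="acquire"), performs a plain, non-atomic read-modify-write, and then clears the word with tl.atomic_xchg(..., sem="release"):

import torch
import triton
import triton.language as tl


@triton.jit
def locked_accumulate_kernel(out_ptr, lock_ptr, partial_ptr, n_cols, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    partial = tl.load(partial_ptr + pid * n_cols + cols, mask=mask, other=0.0)

    # acquire: spin until the lock word flips from 0 (free) to 1 (held by this block)
    while tl.atomic_cas(lock_ptr, 0, 1, sem="acquire") == 1:
        pass

    # critical section: a plain (non-atomic) read-modify-write on the shared buffer
    acc = tl.load(out_ptr + cols, mask=mask, other=0.0)
    tl.store(out_ptr + cols, acc + partial, mask=mask)

    # release: reset the lock word so the next block can enter
    tl.atomic_xchg(lock_ptr, 0, sem="release")


# Hypothetical launch: each program adds its row of `partials` into `out` under the lock.
out = torch.zeros(256, device="cuda")
partials = torch.randn(8, 256, device="cuda")
lock = torch.zeros(1, dtype=torch.int32, device="cuda")
locked_accumulate_kernel[(partials.shape[0],)](out, lock, partials, out.numel(), BLOCK=256)

The hl.atomic_add calls that the locks replace are left commented out in the diff; the lock serializes each tile's whole update into grad_x or grad_w rather than issuing per-element atomics.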

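Finally, a usage sketch of the autograd wrapper after the split; sizes and dtypes are placeholders, and the import path assumes the module matches the file location shown above:

import torch

from liger_kernel.ops.helion.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyHelionFunction,
)

# Placeholder sizes: BT = batch * seq_len, H = hidden size, V = vocab size
BT, H, V = 4096, 4096, 32000
x = torch.randn(BT, H, device="cuda", requires_grad=True)
weight = torch.randn(V, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# forward runs fused_linear_cross_entropy_fwd and saves (_input, weight, target, lse) for backward
loss = LigerFusedLinearCrossEntropyHelionFunction.apply(x, weight, target, -100, "mean")

# backward recomputes the logits tile by tile in fused_linear_cross_entropy_bwd
loss.backward()
print(x.grad.shape, weight.grad.shape)  # torch.Size([4096, 4096]), torch.Size([32000, 4096])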