+import math
+
 import helion
 import helion.language as hl
 import torch
 
 # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input
 config = helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[])
 
-@helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper])
-def fused_linear_cross_entropy_fwd_bwd(
+def helion_lock_acquire(lock_ptr, lock_index):
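+    # Spin-lock acquire: loop on an atomic compare-and-swap until the flag at
+    # lock_ptr + lock_index flips from 0 to 1; sem="acquire" orders the memory
+    # accesses of the critical section after the lock is taken.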
+    hl.inline_triton(
+        """
+        while tl.atomic_cas({0} + {1}, 0, 1, sem="acquire") == 1:
+            pass
+        """,
+        args=(lock_ptr, lock_index),
+        output_like=None,
+    )
+
+def helion_lock_release(lock_ptr, lock_index):
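+    # Spin-lock release: atomically reset the flag to 0; sem="release" makes
+    # the holder's writes visible to the next block that acquires the lock.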
+    hl.inline_triton(
+        """
+        tl.atomic_xchg({0} + {1}, 0, sem="release")
+        """,
+        args=(lock_ptr, lock_index),
+        output_like=None,
+    )
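+# Usage pattern (see the backward kernel below): guard a non-atomic
+# read-modify-write of a shared output tile, e.g.
+#   helion_lock_acquire(lock, idx)
+#   out[tile_i, tile_j] += partial
+#   helion_lock_release(lock, idx)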
+
+
+# @helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper])
+@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
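+# autotune_effort="none" presumably skips autotuning for now; the hand-tuned
+# `config` above is kept commented out for reference.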
+def fused_linear_cross_entropy_fwd(
     x: torch.Tensor,
     weight: torch.Tensor,
     target: torch.Tensor,
@@ -31,15 +54,15 @@ def fused_linear_cross_entropy_fwd_bwd(
     block_size_v = hl.register_block_size(V)
 
     nll = torch.zeros(BT, device=x.device, dtype=torch.float32)
-    grad_x = torch.zeros_like(x, dtype=torch.float32)
-    grad_w = torch.zeros_like(weight, dtype=torch.float32)
+    lse = torch.zeros(BT, device=x.device, dtype=torch.float32)
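+    # lse stores the per-row logsumexp so the separate backward kernel can
+    # rebuild softmax(logits) without rerunning the reduction over V.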
 
     # May be useful for splitting fwd and bwd
     # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32)
     # neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32)
 
     n_non_ignore = (target != ignore_index).sum().unsqueeze(0)
 
+    # forward
     for tile_bt in hl.tile(BT, block_size=block_size_bt):
         m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf")
         d_i = hl.zeros([tile_bt], dtype=torch.float32)
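+        # m_i / d_i: running max and running sum of exponentials for the online
+        # softmax (logsumexp) accumulated over the vocabulary tiles.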
@@ -75,50 +98,86 @@ def fused_linear_cross_entropy_fwd_bwd(
             nll_tile /= n_non_ignore_value
 
         nll[tile_bt] = nll_tile
-        # gradients computation
-        for tile_v in hl.tile(V, block_size=block_size_v):
-            # Restore logits
-            acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32)
-            for tile_h in hl.tile(H, block_size=block_size_h):
-                x_tile = x[tile_bt, tile_h]
-                weight_tile = weight[tile_v, tile_h]
-                acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32)
-
-            # softmax(x_i) = exp(x_i) / sum(exp(x_i))
-            #             = exp(x_i) / log(exp(sum(x_i)))
-            #             = exp(x_i) / lse = exp(x_i - lse)
-            grad_logits_tile = torch.exp(acc - lse_tile[:, None])
-            offset = tile_v.index.unsqueeze(0)  # [1, tile_v]
-            mask = target_indices == offset  # [tile_bt, tile_v]
-            grad_logits_tile = grad_logits_tile - mask.float()
-            # handle out of bound values in grad_logits_tile
-            grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :])
-
-            if reduction == "mean":
-                grad_logits_tile /= n_non_ignore_value
-
-            for tile_h in hl.tile(H, block_size=block_size_h):
-                # grad_x = grad_logits @ weight
-                rhs_tile = weight[tile_v, tile_h]
-                partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32)
-                hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x)
-                # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h]
-                rhs_tile = x[tile_bt, tile_h]
-                partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32)
-                hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w)
-
+        lse[tile_bt] = lse_tile
+
     if reduction != "none":
         loss = nll.sum()
     else:
         loss = nll
+
+    return loss, lse
+
+@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def fused_linear_cross_entropy_bwd(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    target: torch.Tensor,
+    lse: torch.Tensor,
+    ignore_index: int = -100,
+    reduction: str = "mean",
+):
+    BT, H = x.size()
+    V = weight.size(0)
+    block_size_bt = hl.register_block_size(BT)
+    block_size_h = hl.register_block_size(H)
+    block_size_v = hl.register_block_size(V)
+    grad_x = torch.zeros_like(x, dtype=torch.float32)
+    grad_w = torch.zeros_like(weight, dtype=torch.float32)
+    n_non_ignore = (target != ignore_index).sum().unsqueeze(0)
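+    # fp32 gradient buffers; partial tile products are accumulated into them
+    # under the locks below.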
+
+    num_block_bt = (BT + block_size_bt - 1) // block_size_bt
+    num_block_h = (H + block_size_h - 1) // block_size_h
+    num_block_v = (V + block_size_v - 1) // block_size_v
+    grad_x_lock = torch.zeros((num_block_bt, num_block_h), dtype=torch.int32, device=x.device)
+    grad_w_lock = torch.zeros((num_block_v, num_block_h), dtype=torch.int32, device=x.device)
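+    # One int32 flag per output tile: grad_x_lock[i, j] guards the (tile_bt i,
+    # tile_h j) block of grad_x and grad_w_lock[i, j] the (tile_v i, tile_h j)
+    # block of grad_w; flags are indexed as row_id * num_block_h + col_id.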
+    # backward
+    for tile_bt, tile_v in hl.tile([BT, V], block_size=(block_size_bt, block_size_v)):
+        # Restore logits
+        acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32)
+        for tile_h in hl.tile(H, block_size=block_size_h):
+            x_tile = x[tile_bt, tile_h]
+            weight_tile = weight[tile_v, tile_h]
+            acc2 = hl.dot(x_tile, weight_tile.T, acc=acc2, out_dtype=torch.float32)
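+        # The [tile_bt, tile_v] logits tile is recomputed from x @ weight.T so
+        # the forward pass never has to materialize the full [BT, V] logits.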
+
+        # softmax(x_i) = exp(x_i) / sum_j(exp(x_j))
+        #             = exp(x_i) / exp(logsumexp(x))
+        #             = exp(x_i - lse)
+        lse_tile = lse[tile_bt]
+        target_indices = target[tile_bt].unsqueeze(1)  # [tile_bt, 1]
+        if reduction == "mean":
+            n_non_ignore_value = hl.load(n_non_ignore, [0])
+
+        grad_logits_tile = torch.exp(acc2 - lse_tile[:, None])
+        offset = tile_v.index.unsqueeze(0)  # [1, tile_v]
+        mask = target_indices == offset  # [tile_bt, tile_v]
+        grad_logits_tile = grad_logits_tile - mask.float()
+        # handle out of bound values in grad_logits_tile
+        grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :])
+
+        if reduction == "mean":
+            grad_logits_tile /= n_non_ignore_value
+
+        for tile_h in hl.tile(H, block_size=block_size_h):
+            # grad_x = grad_logits @ weight
+            rhs_tile = weight[tile_v, tile_h]
+            partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32)
+            helion_lock_acquire(grad_x_lock, tile_bt.id * num_block_h + tile_h.id)
+            grad_x[tile_bt, tile_h] += partial_grad_x
+            helion_lock_release(grad_x_lock, tile_bt.id * num_block_h + tile_h.id)
+            # hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x)
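+            # The locked += replaces hl.atomic_add (kept commented out above),
+            # presumably to accumulate each tile with a single read-modify-write
+            # instead of per-element atomics.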
+
+        # for tile_h in hl.tile(H, block_size=block_size_h):
+            # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h]
+            rhs_tile = x[tile_bt, tile_h]
+            partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32)
+            helion_lock_acquire(grad_w_lock, tile_v.id * num_block_h + tile_h.id)
+            grad_w[tile_v, tile_h] += partial_grad_w
+            helion_lock_release(grad_w_lock, tile_v.id * num_block_h + tile_h.id)
+            # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w)
+
 
-    # return format is not determined yet
-    return loss, dict(
-        {
-            "grad_x": grad_x,
-            "grad_w": grad_w,
-        }
-    )
+
+    return grad_x, grad_w
 
 
 class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function):
@@ -131,20 +190,30 @@ def forward(
         ignore_index=-100,
         reduction="mean",
     ):
-        loss, aux_output = fused_linear_cross_entropy_fwd_bwd(
+        loss, lse = fused_linear_cross_entropy_fwd(
             _input,
             weight,
             target,
             ignore_index,
             reduction,
         )
-        ctx.save_for_backward(aux_output["grad_x"], aux_output["grad_w"])
+        ctx.ignore_index = ignore_index
+        ctx.reduction = reduction
+        ctx.save_for_backward(_input, weight, target, lse)
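+        # weight and target are saved as well since the backward kernel needs
+        # them to recompute the logits; saving (inputs, lse) defers the gradient
+        # computation so grad_x/grad_w are not materialized during forward.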
         return loss
 
     @staticmethod
     def backward(ctx, grad_output):
         assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar"
-        grad_input, grad_weight = ctx.saved_tensors
+        _input, weight, target, lse = ctx.saved_tensors
+        grad_input, grad_weight = fused_linear_cross_entropy_bwd(
+            _input,
+            weight,
+            target,
+            lse,
+            ctx.ignore_index,
+            ctx.reduction,
+        )
         return grad_input * grad_output, grad_weight * grad_output, None, None, None
 
 
@@ -250,6 +319,7 @@ def forward(self, x, target):
 from helion._testing import run_example
 from functools import partial
 
+
 def fwd_bwd_fn(input, target, fn):
     loss = fn(input, target)
     loss.backward()
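+    # Runs one forward + backward step; presumably passed to run_example so the
+    # Helion kernels can be compared end to end against a reference module.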