1. Background & Motivation
Ascend NPU is supported as a PyTorch device backend out of the box and is natively compatible with ecosystems such as Transformers, FlagGems, and Llama Factory. We are also enabling Triton support (repo: triton-ascend).
Liger Kernel’s wide adoption has led to growing user requests for NPU support. This proposal aims to advance native adaptation, with community input welcome.
2. Proposed Implementation Steps
Adaptation proceeds in phases, with no breaking changes to existing devices.
2.1 Device Support Integration
Extend device detection logic to include NPU:
```python
def infer_device():
    """Get current device name based on available devices."""
    if torch.cuda.is_available():  # Works for Nvidia/AMD
        return "cuda"
    elif is_npu_available():  # Ascend NPU check
        return "npu"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"
```
2.2 Operator Compatibility Guarantee
- Native-run operators: some operators (e.g., `geglu`, `fused_neighborhood_attention`) already work on NPU with only basic device adaptation.
- NPU-specific adjustments: some operators need `BLOCK_SIZE` tweaks to avoid Unified Buffer (UB) overflow.
Example: geglu Operator Modification
Original code:
```python
@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # GELU tanh approximation
    sqrt_2_over_pi = 0.7978845608028654
    a_cubed = a_row ** 3
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    c_row = 0.5 * a_row * (1 + tanh(tanh_arg)) * b_row
    tl.store(c + col_offsets, c_row, mask=mask)
```

NPU-adapted code (add `BLOCK_SIZE_SUB` for UB overflow prevention):
```python
@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr,
                               BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)
    base_offset = program_id * stride

    num_sub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)
    for sub_block_idx in range(num_sub_blocks):
        col_offsets = tl.arange(0, BLOCK_SIZE_SUB) + sub_block_idx * BLOCK_SIZE_SUB
        mask = col_offsets < n_cols
        a_row = tl.load(a + base_offset + col_offsets, mask=mask, other=0).to(tl.float32)
        b_row = tl.load(b + base_offset + col_offsets, mask=mask, other=0)

        # GELU tanh approximation (same as original)
        sqrt_2_over_pi = 0.7978845608028654
        a_cubed = a_row ** 3
        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
        c_row = 0.5 * a_row * (1 + tanh(tanh_arg)) * b_row
        tl.store(c + base_offset + col_offsets, c_row, mask=mask)
```

Note: `BLOCK_SIZE_SUB = BLOCK_SIZE` for non-NPU devices, so the loop degenerates to a single iteration and behavior is unchanged there.
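To make this concrete, a host-side launch wrapper could pick `BLOCK_SIZE_SUB` per device at launch time. The sketch below is illustrative only: the helper name `_pick_block_size_sub` and the 2048-element cap are assumptions, not measured UB limits.

```python
import torch
import triton


def _pick_block_size_sub(block_size: int, device_type: str) -> int:
    # Illustrative cap; the real value should be derived from the Unified
    # Buffer capacity and the element dtype on the target NPU.
    npu_max_sub_block = 2048
    if device_type == "npu":
        return min(block_size, npu_max_sub_block)
    return block_size  # non-NPU devices: BLOCK_SIZE_SUB == BLOCK_SIZE


def geglu_tanh_forward(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    ori_shape = a.shape
    a = a.reshape(-1, ori_shape[-1])
    b = b.reshape(-1, ori_shape[-1])
    c = torch.empty_like(a)
    n_cols = ori_shape[-1]
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    BLOCK_SIZE_SUB = _pick_block_size_sub(BLOCK_SIZE, a.device.type)
    _geglu_tanh_forward_kernel[(a.shape[0],)](
        a, b, c, a.stride(0), n_cols,
        BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE_SUB=BLOCK_SIZE_SUB,
    )
    return c.reshape(ori_shape)
```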
2.3 Performance Optimization
- Tune operator parameters (e.g., `BLOCK_SIZE` choices) for NPU performance.
- Future plan: add NPU-affinity APIs to `triton-ascend`.
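As one illustration of what parameter tuning could look like, `BLOCK_SIZE_SUB` might eventually be selected by `triton.autotune` instead of being hard-coded. Whether `triton-ascend` supports this decorator in the same form is an assumption, and the candidate values below are placeholders; the sub-block loop is shown on a trivial row copy to keep the sketch short.

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE_SUB": bs}) for bs in (256, 512, 1024, 2048)],
    key=["n_cols"],
)
@triton.jit
def _row_copy_kernel(x, y, stride, n_cols, BLOCK_SIZE: tl.constexpr,
                     BLOCK_SIZE_SUB: tl.constexpr):
    # Same sub-block structure as the adapted geglu kernel above.
    row = tl.program_id(0).to(tl.int64)
    base = row * stride
    for sub_block_idx in range(tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)):
        col_offsets = tl.arange(0, BLOCK_SIZE_SUB) + sub_block_idx * BLOCK_SIZE_SUB
        mask = col_offsets < n_cols
        vals = tl.load(x + base + col_offsets, mask=mask, other=0)
        tl.store(y + base + col_offsets, vals, mask=mask)
```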
2.4 CI Integration
We are applying for NPU devices to be integrated into Liger Kernel's native CI so that NPU support can be validated continuously.
3. Code Organization Proposal
NPU modifications will not break existing code. To improve readability, we are considering organizing device-specific operators into dedicated directories, following FlagGems' approach:
Reference: FlagGems multi-backend layout
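One possible layout (directory and file names below are illustrative only, not a final proposal):

```
src/liger_kernel/ops/
├── geglu.py              # existing device-agnostic kernels
├── ...
└── backends/
    └── ascend/
        ├── geglu.py      # NPU-adapted variants (e.g., BLOCK_SIZE_SUB loop)
        └── ...
```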
Seeking community feedback on this approach.
4. Conclusion
We aim to extend Liger Kernel to Ascend NPU while preserving compatibility and performance on existing devices. Your input will help refine this initiative.