
[RFC] Native Ascend NPU Support for Liger Kernel #954

@Ginray


1. Background & Motivation

Ascend NPU is supported as a PyTorch device backend and is natively compatible with ecosystems such as Transformers, FlagGems, and Llama Factory. We are also enabling Triton support for Ascend (repo: triton-ascend).

Liger Kernel’s wide adoption has led to growing user requests for NPU support. This proposal aims to advance a native adaptation for Ascend NPU; community input is welcome.

2. Proposed Implementation Steps

Adaptation proceeds in phases, with no breaking changes to existing devices.

2.1 Device Support Integration

Extend device detection logic to include NPU:

def infer_device():
    """Get current device name based on available devices"""
    if torch.cuda.is_available():  # Works for Nvidia/AMD
        return "cuda"
    elif is_npu_available():  # Ascend NPU check
        return "npu"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"

2.2 Operator Compatibility Guarantee

  • Natively running operators: some operators (e.g., geglu, fused_neighborhood_attention) already work on NPU with only basic device adaptation.
  • NPU-specific adjustments: other operators need BLOCK_SIZE tweaks to avoid overflowing the NPU's Unified Buffer (UB), as in the example below.

Example: geglu Operator Modification
Original code:

@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # GELU tanh approximation
    sqrt_2_over_pi = 0.7978845608028654
    a_cubed = a_row ** 3
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    c_row = 0.5 * a_row * (1 + tanh(tanh_arg)) * b_row
    tl.store(c + col_offsets, c_row, mask=mask)

NPU-adapted code (adds BLOCK_SIZE_SUB to prevent UB overflow):

@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, 
                               BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)
    base_offset = program_id * stride
    num_sub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)

    for sub_block_idx in range(num_sub_blocks):
        col_offsets = tl.arange(0, BLOCK_SIZE_SUB) + sub_block_idx * BLOCK_SIZE_SUB
        mask = col_offsets < n_cols
        
        a_row = tl.load(a + base_offset + col_offsets, mask=mask, other=0).to(tl.float32)
        b_row = tl.load(b + base_offset + col_offsets, mask=mask, other=0)
        
        # GELU tanh approximation (same as original)
        sqrt_2_over_pi = 0.7978845608028654
        a_cubed = a_row ** 3
        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
        c_row = 0.5 * a_row * (1 + tanh(tanh_arg)) * b_row
        
        tl.store(c + base_offset + col_offsets, c_row, mask=mask)

Note: for non-NPU devices, BLOCK_SIZE_SUB is set equal to BLOCK_SIZE, so the inner loop runs exactly once and the original behavior is preserved.
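
For illustration, the launcher-side change could look like the sketch below. select_sub_block_size and NPU_MAX_SUB_BLOCK are hypothetical names used only to show the idea; the actual sub-block limit would come from hardware properties or tuning.

NPU_MAX_SUB_BLOCK = 1024  # assumed placeholder, not a measured UB limit


def select_sub_block_size(block_size: int, device: str) -> int:
    """Pick a sub-block size that fits the NPU Unified Buffer (sketch)."""
    if device == "npu":
        return min(block_size, NPU_MAX_SUB_BLOCK)
    return block_size  # non-NPU devices: one pass, original behavior preserved


# Inside the geglu launcher (sketch), after BLOCK_SIZE is computed:
# BLOCK_SIZE_SUB = select_sub_block_size(BLOCK_SIZE, infer_device())
# _geglu_tanh_forward_kernel[(n_rows,)](
#     a, b, c, c.stride(-2),
#     n_cols=n_cols, BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE_SUB=BLOCK_SIZE_SUB,
# )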

2.3 Performance Optimization

  • Tune operator parameters for NPU performance.
  • Future plan: add NPU-affinity APIs to triton-ascend.
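
One way to tune these parameters is per-device autotuning with triton.autotune; the configuration values below are placeholders rather than measured NPU settings, and the kernel body is the NPU-adapted version shown in Section 2.2.

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # Placeholder configs; real values would come from benchmarking on NPU.
        triton.Config({"BLOCK_SIZE_SUB": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE_SUB": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE_SUB": 2048}, num_warps=8),
    ],
    key=["n_cols"],  # re-tune when the row width changes
)
@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr,
                               BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    ...  # body identical to the NPU-adapted kernel in Section 2.2

With autotuning in place, the launcher omits BLOCK_SIZE_SUB from the kernel call and lets Triton select a configuration.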

2.4 CI Integration

We are applying for NPU devices to be integrated into Liger Kernel’s native CI for continuous validation.

3. Code Organization Proposal

NPU modifications won’t break existing code. To improve readability, we are considering organizing device-specific operators into dedicated directories, following the multi-backend layout used by FlagGems; a possible layout is sketched below.
Reference: FlagGems multi-backend
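
A possible layout and dispatch pattern is shown below; the directory names and import paths under ops/npu are illustrative only and would be settled with the community.

# Hypothetical layout:
#   src/liger_kernel/ops/       -> default (CUDA/XPU) kernels, unchanged
#   src/liger_kernel/ops/npu/   -> NPU-tuned variants with the same public signatures
#
# Hypothetical dispatch at import time:
from liger_kernel.utils import infer_device

if infer_device() == "npu":
    from liger_kernel.ops.npu.geglu import geglu_forward  # NPU-tuned kernel (proposed)
else:
    from liger_kernel.ops.geglu import geglu_forward  # existing kernel, unchanged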

Seeking community feedback on this approach.

4. Conclusion

We aim to extend Liger Kernel to Ascend NPU while ensuring compatibility and performance. Your input will help refine this initiative.
