24 changes: 23 additions & 1 deletion setup.py
@@ -24,6 +24,11 @@ def get_default_dependencies():
return [
"torch>=2.6.0",
]
elif platform == "npu":
return [
"torch_npu==2.6.0",
"triton-ascend"
]


def get_optional_dependencies():
@@ -67,7 +72,21 @@ def is_xpu_available():
return False


def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
def is_ascend_available() -> bool:
"""Best-effort Ascend detection.

Checks for common Ascend environment variables and a possible `npu-smi`
utility if present.
"""
try:
subprocess.run(["npu-smi", "info"], check=True, capture_output=True)  # suppress output; the exit status alone confirms the tool is present
return True
except (subprocess.SubprocessError, FileNotFoundError):
pass
return False


def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu", "npu"]:
"""
Detect the accelerator platform (NVIDIA, AMD, Intel XPU, or Ascend NPU) without a torch dependency.
"""
@@ -86,6 +105,9 @@ def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
if is_xpu_available():
print("Intel GPU detected")
return "xpu"
elif is_ascend_available():
print("Ascend NPU detected")
return "npu"
else:
print("No GPU detected")
return "cpu"
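Since the `npu-smi` probe shells out on every call, an environment-variable hint could cheaply complement it. A minimal sketch, assuming the Ascend CANN toolkit's usual variables such as `ASCEND_HOME_PATH` / `ASCEND_TOOLKIT_HOME` (names vary by install and are not taken from this PR):

import os
import subprocess


def ascend_toolkit_hint() -> bool:
    # Hypothetical helper: the CANN setup script exports variables like these,
    # so their presence suggests an Ascend install even before probing npu-smi.
    # Toolkit presence does not guarantee a visible device; treat it as a hint.
    return any(os.environ.get(v) for v in ("ASCEND_HOME_PATH", "ASCEND_TOOLKIT_HOME"))


def is_ascend_available() -> bool:
    if ascend_toolkit_hint():
        return True
    try:
        subprocess.run(["npu-smi", "info"], check=True, capture_output=True)
        return True
    except (subprocess.SubprocessError, FileNotFoundError):
        return False

Whether build-time detection should treat toolkit presence as sufficient is a policy choice; the sketch errs on the permissive side.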
3 changes: 2 additions & 1 deletion src/liger_kernel/ops/cross_entropy.py
@@ -10,8 +10,9 @@
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
7 changes: 5 additions & 2 deletions src/liger_kernel/ops/dyt.py
@@ -7,8 +7,10 @@
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import infer_device
from liger_kernel.utils import get_npu_multi_processor_count
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
elif device == "xpu":
NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count

elif device == "npu":
NUM_SMS = get_npu_multi_processor_count()
da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
6 changes: 5 additions & 1 deletion src/liger_kernel/ops/fused_add_rms_norm.py
@@ -9,8 +9,10 @@
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import torch_to_triton_dtype
from liger_kernel.utils import get_npu_multi_processor_count
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
elif S.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
elif S.device.type == "npu":
sm_count = get_npu_multi_processor_count()

# fp32 for numerical stability especially.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
3 changes: 2 additions & 1 deletion src/liger_kernel/ops/geglu.py
@@ -7,8 +7,9 @@
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
3 changes: 2 additions & 1 deletion src/liger_kernel/ops/group_norm.py
@@ -6,8 +6,9 @@

from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
3 changes: 2 additions & 1 deletion src/liger_kernel/ops/layer_norm.py
@@ -8,8 +8,9 @@
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
6 changes: 5 additions & 1 deletion src/liger_kernel/ops/poly_norm.py
@@ -7,8 +7,10 @@
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import get_npu_multi_processor_count
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
@@ -290,6 +292,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_multi_processor_count()

# Allocate or reuse gradients
if in_place is True:
6 changes: 5 additions & 1 deletion src/liger_kernel/ops/rms_norm.py
@@ -21,8 +21,10 @@
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import torch_to_triton_dtype
from liger_kernel.utils import get_npu_multi_processor_count
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0"):
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
@@ -449,6 +451,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_multi_processor_count()

# fp32 for numerical stability especially.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
5 changes: 5 additions & 0 deletions src/liger_kernel/ops/utils.py
@@ -78,6 +78,11 @@ def get_amp_custom_fwd_bwd() -> Callable:
functools.partial(torch.amp.custom_fwd, device_type=device),
functools.partial(torch.amp.custom_bwd, device_type=device),
)
try:
if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
except Exception:
pass
Comment on lines +81 to +85
Collaborator: What exception could it possibly be?

return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


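On the review question above about what the broad `except` could catch: with the `hasattr`/`getattr` guards already in place, little besides an oddly initialized `torch.npu.amp` seems possible, so the check can be expressed without a try/except at all. A sketch under the assumption that `torch_npu`, once imported, attaches an `amp` module mirroring `torch.cuda.amp` (the helper name is hypothetical):

import torch


def _npu_amp_custom_fwd_bwd():
    """Return (custom_fwd, custom_bwd) from torch.npu.amp, or None if unavailable."""
    npu_amp = getattr(getattr(torch, "npu", None), "amp", None)
    fwd = getattr(npu_amp, "custom_fwd", None)
    bwd = getattr(npu_amp, "custom_bwd", None)
    # Both decorators must exist and be callable; otherwise the caller falls back to CUDA AMP.
    if callable(fwd) and callable(bwd):
        return fwd, bwd
    return None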
23 changes: 23 additions & 0 deletions src/liger_kernel/utils.py
@@ -18,12 +18,35 @@ def infer_device():
"""
if torch.cuda.is_available(): # Works for both Nvidia and AMD
return "cuda"
# Use Ascend NPU if available (torch.npu)
elif is_npu_available():
return "npu"
# XPU (Intel) if available
elif torch.xpu.is_available():
return "xpu"
else:
return "cpu"


def is_npu_available() -> bool:
"""Detect Ascend NPU availability."""
try:
from transformers.utils import is_torch_npu_available

return is_torch_npu_available()
except Exception:
return False


def get_npu_multi_processor_count() -> int:
"""Return a heuristic multi-processor count for NPU."""
NPU_MULTI_PROCESSOR_COUNT = 48
if is_npu_available():
return NPU_MULTI_PROCESSOR_COUNT
# Reasonable default to avoid division by zero
return 1

Comment on lines +41 to +48
Collaborator: Is torch.npu going to support get_device_properties so we can get these numbers programmatically? If that's the case, I suggest using that method instead and leaving the magic number as a fallback. WDYT?


def transformers_version_dispatch(
required_version: str,
before_fn,
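Following the review suggestion above, a query-first variant could keep the constant only as a fallback. A sketch, assuming torch_npu mirrors the torch.cuda API closely enough to offer `torch.npu.get_device_properties`, and that a `multi_processor_count`-style field may or may not be present (neither is verified here):

import torch

NPU_MULTI_PROCESSOR_COUNT = 48  # heuristic fallback used by the PR


def get_npu_multi_processor_count() -> int:
    # is_npu_available() is the helper defined above in this module.
    if not is_npu_available():
        return 1  # safe default so callers never divide by zero
    try:
        props = torch.npu.get_device_properties(torch.npu.current_device())
        return int(getattr(props, "multi_processor_count", NPU_MULTI_PROCESSOR_COUNT))
    except (AttributeError, RuntimeError):
        return NPU_MULTI_PROCESSOR_COUNT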
4 changes: 4 additions & 0 deletions test/conftest.py
@@ -1,11 +1,15 @@
import pytest
import torch

from liger_kernel.utils import is_npu_available


@pytest.fixture(autouse=True)
def clear_gpu_cache():
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_npu_available():
torch.npu.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
5 changes: 5 additions & 0 deletions test/utils.py
@@ -55,6 +55,9 @@ def set_seed(seed=42):
# If you are using XPU
torch.xpu.manual_seed(seed)
torch.xpu.manual_seed_all(seed)
elif device == "npu":
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)

# Python hash seed
os.environ["PYTHONHASHSEED"] = str(seed)
@@ -258,6 +261,8 @@ def supports_bfloat16():
return torch.cuda.get_device_capability() >= (8, 0) # Ampere and newer
elif device == "xpu":
return True
elif device == "npu":
return True
else:
return False

Expand Down