
Commit a1826d6

enable test for ipex
1 parent: db8d603

File tree

4 files changed: +30 -32 lines

bitsandbytes/backends/triton/ops.py
bitsandbytes/backends/utils.py
bitsandbytes/backends/xpu/ops.py
tests/test_modules.py


bitsandbytes/backends/triton/ops.py

Lines changed: 2 additions & 8 deletions
@@ -2,20 +2,14 @@
 
 import torch
 
+from . import triton_kernels
+
 # currently codes unused, kept for reference
 # Should be the same for quant/dequant
 # from bitsandbytes.functional import get_4bit_type
 # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
 # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
 
-try:
-    from . import triton_kernels
-
-    triton_available = True
-except ImportError as e:
-    print("Import error:", e)
-    triton_available = False
-
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)

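With this change, triton/ops.py imports triton_kernels unconditionally and no longer keeps its own availability flag; the probe moves to backends/utils.py (next file). A minimal consumer sketch, assuming only the import paths shown in the diff, of how callers now gate on the shared flag:

    # Gate the module import on the shared flag so that triton/ops.py
    # itself can import triton_kernels unconditionally.
    from bitsandbytes.backends.utils import triton_available

    if triton_available:
        from bitsandbytes.backends.triton import ops as triton_ops
        # triton_ops.quantize_blockwise / dequantize_4bit / ... are usable here
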
bitsandbytes/backends/utils.py

Lines changed: 9 additions & 0 deletions
@@ -10,6 +10,15 @@
 ipex_cpu = None
 ipex_xpu = None
 
+try:
+    import triton  # noqa: F401
+    import triton.language as tl  # noqa: F401
+
+    triton_available = True
+except ImportError as e:
+    triton_available = False
+
+
 _NF4_QUANT_TABLE = torch.tensor(
     [
         -1.0,

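The probe imports both triton and triton.language, so the flag is True only when the full package is importable. For comparison, a sketch of the same check written with importlib.util.find_spec, which avoids executing the package body for the top-level probe; this is an alternative pattern, not what the commit does:

    import importlib.util

    # find_spec returns None when the package is absent; note that a dotted
    # name still imports the parent package, hence the short-circuit `and`.
    triton_available = (
        importlib.util.find_spec("triton") is not None
        and importlib.util.find_spec("triton.language") is not None
    )
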
bitsandbytes/backends/xpu/ops.py

Lines changed: 15 additions & 22 deletions
@@ -4,10 +4,10 @@
 import torch
 
 from ..._ops import register_kernel
-from ..utils import ipex_xpu
+from ..utils import ipex_xpu, triton_available
 
-# With default torch, error:
-# NotImplementedError: The operator 'aten::_int_mm' for XPU
+# _int_mm is available in torch starting from version 2.7,
+# but it currently has no XPU implementation.
 if ipex_xpu and torch.__version__ >= (2, 7):
 
     @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
@@ -18,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
        ).reshape(*A.shape[:-1], B.shape[0])
 
 
+# IPEX should be faster on XPU, so check for it first.
 if ipex_xpu:
 
     @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
@@ -52,23 +53,15 @@ def _(
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
 
         return out.reshape(shape)
-else:
-    # IPEX should be faster for xpu, so at first checking if it is available.
-    try:
-        from ..triton import ops as triton_ops
-
-        triton_available = True
-    except ImportError as e:
-        print("Import error:", e)
-        triton_available = False
+elif triton_available:
+    from ..triton import ops as triton_ops
 
-    if triton_available:
-        register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
-        register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace)
-        register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise)
-        register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
-        register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
-        register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
-        register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
-    else:
-        warnings.warn("XPU available, but trtion package is missing.")
+    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
+    register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace)
+    register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise)
+    register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
+    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
+    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
+    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
+else:
+    warnings.warn("XPU available but no ipex or triton packages found.")

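Registration now reads as a single if/elif/else chain: IPEX kernels win when present, the Triton ops are the fallback, and the warning fires only when neither backend exists. A condensed sketch of that precedence, with a hypothetical helper name (pick_xpu_backend is not part of the codebase):

    # Dispatch order encoded by the if/elif/else above.
    def pick_xpu_backend(ipex_xpu: bool, triton_available: bool) -> str | None:
        if ipex_xpu:
            return "ipex"    # preferred: IPEX should be faster on XPU
        if triton_available:
            return "triton"  # fallback: Triton kernel implementations
        return None          # neither installed -> warnings.warn(...)

    assert pick_xpu_backend(True, True) == "ipex"  # IPEX takes precedence
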
tests/test_modules.py

Lines changed: 4 additions & 2 deletions
@@ -5,6 +5,8 @@
 from torch import nn
 
 import bitsandbytes as bnb
+from bitsandbytes.backends.utils import triton_available
+from bitsandbytes.functional import ipex_xpu
 from tests.helpers import get_available_devices, id_formatter
 
 
@@ -287,8 +289,8 @@ def test_linear_kbit_fp32_bias(device, module):
 def test_kbit_backprop(device, module):
     if device == "cpu":
         pytest.xfail("Test is not yet supported on CPU")
-    if device == "xpu":
-        pytest.xfail("Missing int8_double_quant implementation XPU")
+    if device == "xpu" and module == bnb.nn.Linear8bitLt and not ipex_xpu and triton_available:
+        pytest.xfail("Missing int8_double_quant implementation in Triton for XPU")
 
     b = 16
     dim1 = 36

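The xfail guard is narrowed to the configuration that actually lacks the kernel: Linear8bitLt on XPU with the Triton path active (no IPEX). For reference, a self-contained sketch of the imperative-xfail pattern used here; the test name and variables below are hypothetical:

    import pytest

    def test_backend_specific_op():
        device = "xpu"             # hypothetical device under test
        has_native_kernel = False  # hypothetical capability probe
        if device == "xpu" and not has_native_kernel:
            # Unlike @pytest.mark.xfail, pytest.xfail() marks the expected
            # failure at runtime, once the unsupported combination is seen.
            pytest.xfail("int8_double_quant not implemented for this backend")
        assert True  # real assertions would exercise the op here
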