4
4
import torch
5
5
6
6
from ..._ops import register_kernel
7
- from ..utils import ipex_xpu
7
+ from ..utils import ipex_xpu , triton_available
8
8
9
- # With default torch, error:
10
- # NotImplementedError: The operator 'aten::_int_mm' for XPU
9
+ # _int_mm is available in torch starting from version 2.7,
10
+ # but it currently doesn't have an XPU implementation.
11
11
if ipex_xpu and torch .__version__ >= (2 , 7 ):
12
12
13
13
@register_kernel ("bitsandbytes::int8_linear_matmul" , "xpu" )
@@ -18,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
18
18
).reshape (* A .shape [:- 1 ], B .shape [0 ])
19
19
20
20
21
+ # IPEX should be faster on XPU, so check whether it is available first.
21
22
if ipex_xpu :
22
23
23
24
@register_kernel ("bitsandbytes::dequantize_nf4_ipex" , "xpu" )
@@ -52,23 +53,15 @@ def _(
52
53
raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { out .dtype } " )
53
54
54
55
return out .reshape (shape )
55
- else :
56
- # IPEX should be faster for xpu, so at first checking if it is available.
57
- try :
58
- from ..triton import ops as triton_ops
59
-
60
- triton_available = True
61
- except ImportError as e :
62
- print ("Import error:" , e )
63
- triton_available = False
56
+ elif triton_available :
57
+ from ..triton import ops as triton_ops
64
58
65
- if triton_available :
66
- register_kernel ("bitsandbytes::quantize_blockwise" , "xpu" )(triton_ops .quantize_blockwise )
67
- register_kernel ("bitsandbytes::dequantize_blockwise.out" , "xpu" )(triton_ops .dequantize_blockwise_inplace )
68
- register_kernel ("bitsandbytes::dequantize_blockwise" , "xpu" )(triton_ops .dequantize_blockwise )
69
- register_kernel ("bitsandbytes::quantize_4bit" , "xpu" )(triton_ops .quantize_4bit )
70
- register_kernel ("bitsandbytes::dequantize_4bit.out" , "xpu" )(triton_ops .dequantize_4bit_inplace )
71
- register_kernel ("bitsandbytes::dequantize_4bit" , "xpu" )(triton_ops .dequantize_4bit )
72
- register_kernel ("bitsandbytes::gemv_4bit" , "xpu" )(triton_ops .gemv_4bit )
73
- else :
74
- warnings .warn ("XPU available, but trtion package is missing." )
59
+ register_kernel ("bitsandbytes::quantize_blockwise" , "xpu" )(triton_ops .quantize_blockwise )
60
+ register_kernel ("bitsandbytes::dequantize_blockwise.out" , "xpu" )(triton_ops .dequantize_blockwise_inplace )
61
+ register_kernel ("bitsandbytes::dequantize_blockwise" , "xpu" )(triton_ops .dequantize_blockwise )
62
+ register_kernel ("bitsandbytes::quantize_4bit" , "xpu" )(triton_ops .quantize_4bit )
63
+ register_kernel ("bitsandbytes::dequantize_4bit.out" , "xpu" )(triton_ops .dequantize_4bit_inplace )
64
+ register_kernel ("bitsandbytes::dequantize_4bit" , "xpu" )(triton_ops .dequantize_4bit )
65
+ register_kernel ("bitsandbytes::gemv_4bit" , "xpu" )(triton_ops .gemv_4bit )
66
+ else :
67
+ warnings .warn ("XPU available but no ipex or triton packages found." )
0 commit comments