from ..._ops import register_kernel
from ..utils import ipex_xpu, triton_available

- # With default torch, error:
- # NotImplementedError: The operator 'aten::_int_mm' for XPU
+ # _int_mm is available in torch starting from version 2.7,
+ # but it does not yet have an XPU implementation.
if ipex_xpu and torch.__version__ >= (2, 7):

    @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
@@ -18,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
        ).reshape(*A.shape[:-1], B.shape[0])
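
For context, the hunk shows only the tail of this registered kernel. Below is a minimal reconstruction sketch of what the full body plausibly computes, inferred from the visible tail line and the documented `torch._int_mm` semantics (int8 × int8 → int32); treat it as an assumption, not the verbatim source:

```python
import torch

# Sketch (assumption): flatten A's batch dims, run the int8 matmul,
# then restore the original leading shape.
def _(A: torch.Tensor, B: torch.Tensor):
    return torch._int_mm(
        A.reshape(-1, A.shape[-1]),  # collapse batch dims -> (M, K)
        B.t(),                       # B stored as (N, K); transpose -> (K, N)
    ).reshape(*A.shape[:-1], B.shape[0])  # int32 result, shape (..., N)
```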
+ # IPEX should be faster on XPU, so check whether it is available first.
if ipex_xpu:
    @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
    ...
        return out.reshape(shape)
elif triton_available:
- # IPEX should be faster for xpu, so at first checking if it is available.
    from ..triton import ops as triton_ops
    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
    ...
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
-     warnings.warn("XPU available, but nor ipex or trtion package is found.")
+     warnings.warn("XPU available but no ipex or triton packages found.")
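
Once this backend module is imported (e.g. via `import bitsandbytes`), the registered ops become reachable through the `torch.ops` namespace. A hypothetical usage sketch; the shapes, values, and availability of an "xpu" device are illustrative assumptions:

```python
import torch
import bitsandbytes  # importing registers the bitsandbytes::* ops

# Illustrative int8 GEMM dispatched to the registered XPU kernel.
A = torch.randint(-128, 128, (4, 16), dtype=torch.int8, device="xpu")
B = torch.randint(-128, 128, (8, 16), dtype=torch.int8, device="xpu")
out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
print(out.shape, out.dtype)  # expected: torch.Size([4, 8]) torch.int32
```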