Skip to content

[Triton/XPU] Support 4bit dequantization logic on Triton #1629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
166 changes: 166 additions & 0 deletions bitsandbytes/backends/triton/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from collections.abc import Sequence

import torch

from . import triton_kernels

# currently codes unused, kept for reference
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

n = A.numel()
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
out = torch.empty_like(A.flatten(), dtype=torch.uint8)

triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
out = out.reshape(A.shape)

return out, absmax.float()


def dequantize_blockwise(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

out = torch.empty_like(A, dtype=dtype, device=A.device)
triton_kernels.dequant_int8_blockwise(
A,
code,
absmax,
out,
blocksize,
)

return out


def dequantize_blockwise_inplace(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

triton_kernels.dequant_int8_blockwise(
A,
code,
absmax,
out,
blocksize,
)


def quantize_4bit(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()

# TODO: Support when weight matrix is not divisible by blocksize
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

blocks = -(n // -(blocksize * 2))

absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)

triton_kernels.quantize_4bit_blockwise_triton(
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
)
packed = out

if quant_storage != torch.uint8:
packed = out.squeeze().view(quant_storage).unsqueeze(1)

return packed, absmax.float()


def dequantize_4bit(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
# torch._check(
# A.dtype == torch.uint8,
# lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
# )
# Check if this is fine and fast
if A.dtype != torch.uint8:
A = A.squeeze().view(torch.uint8).unsqueeze(1)
Comment on lines +116 to +118
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH this was meant to support FSDP1; I'm not sure it's necessary to try to support other storage dtypes for FSDP2, or if it's worth considering at all. @Titus-von-Koeller may have more thought on that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this just to be able to pass tests. If there should be something more complex or better to skip if the type is not uint8 storage, let me know. I'm not sure about this approach with type casting.


out = torch.empty(shape, dtype=dtype, device=A.device)

triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out


def dequantize_4bit_inplace(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
if B.dtype != torch.uint8:
B = B.squeeze().view(torch.uint8).unsqueeze(1)

B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)

triton_kernels._dequantize_4bit_impl_passing_code(
B,
absmax,
blocksize,
code,
dtype=A.dtype,
out=B_dq_triton,
)

return torch.nn.functional.linear(
A,
B_dq_triton,
bias=None,
)
Loading