Skip to content

Commit 1469367

Browse files
committed
remove fast_hadamard_transform dependency.
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
1 parent a3f7cf0 commit 1469367

File tree

6 files changed

+296
-14
lines changed

6 files changed

+296
-14
lines changed

auto_round/experimental/transform/apply.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,24 @@ def _apply_to_module(
5252

5353
# create transform as submodule
5454
transform_name = "forward_hadamard"
55-
transform = build_transform(**config.dict())
56-
# module.register_module(transform_name, transform)
5755

5856
if config.location == "input":
5957
from .triton.mxfp4 import mxfp4_forward_kernel_wrapper
6058

59+
transform = build_transform(
60+
**config.dict(),
61+
device="cpu",
62+
precision=module.dtype,
63+
location="input"
64+
)
65+
6166
def input_hook(_, args):
6267
input = args[0]
6368
# transform(input)
6469
orig_shape = input.shape
6570
x_flat = input.contiguous().flatten(end_dim=-2)
6671
qdq_input, _ = mxfp4_forward_kernel_wrapper(
67-
x_flat, transform.get_transform_matrix(input.device, input.dtype)
72+
x_flat, transform.weight
6873
)
6974
return qdq_input.reshape(orig_shape)
7075

@@ -77,6 +82,12 @@ def input_hook(_, args):
7782
# fuse transform into weight
7883
assert hasattr(module, "weight")
7984

85+
transform = build_transform(
86+
**config.dict(),
87+
device=module.weight.device,
88+
precision=module.weight.dtype,
89+
)
90+
8091
if config.need_calibration:
8192
# for training, the weight changes with every forward pass
8293
# for autoround tuning: patch wrapper linear qdq_weight func
@@ -93,7 +104,7 @@ def input_hook(_, args):
93104
# delattr(module, transform_name)
94105
# fuse transform into weight
95106
with torch.no_grad():
96-
getattr(module, "weight").copy_(transform(module.weight.to("cuda")).to(module.weight.device))
107+
getattr(module, "weight").copy_(transform(module.weight).to(module.weight.device))
97108

98109
else:
99110
# TODO: apply transform to output/q/k

auto_round/experimental/transform/transforms.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import torch
2020
import torch.nn as nn
21-
from fast_hadamard_transform import hadamard_transform
21+
22+
from auto_round.experimental.transform.utils.hadamard import deterministic_hadamard_matrix
23+
from auto_round.experimental.transform.utils.matrix import apply_transform_weight
2224

2325

2426
def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dict[str, Any]:
@@ -40,19 +42,43 @@ def remove_parametrizations(self) -> None:
4042

4143
class HadamardTransform(nn.Module):
4244

43-
def __init__(self, transform_block_size: int = 32):
45+
def __init__(
46+
self,
47+
transform_block_size: int = 32,
48+
device: torch.device = None,
49+
precision: torch.dtype = None,
50+
location: str = "weight",
51+
module_type: type[torch.nn.Module] = torch.nn.Linear,
52+
):
4453
super().__init__()
45-
self.dim = transform_block_size
46-
self.scale = 1 / math.sqrt(self.dim)
54+
self.size = transform_block_size
55+
self.scale = 1 / math.sqrt(self.size)
56+
self.location = location
57+
self.module_type = module_type
58+
self.weight = self._create_weight(self.size, device, precision)
59+
60+
def _create_weight(
61+
self,
62+
size: int,
63+
device: torch.device = None,
64+
precision: torch.dtype = None,
65+
) -> torch.nn.Parameter:
66+
data = deterministic_hadamard_matrix(size, precision, device) * self.scale
67+
# TODO: implement SpinQuant, which rotation matrix is learnable
68+
return nn.Parameter(data, requires_grad=False)
4769

48-
# @torch.no_grad()
4970
def forward(self, x: torch.Tensor):
5071
# Hadamard transform is it own inverse
51-
x_shape = x.shape
52-
return hadamard_transform(x.view(-1, self.dim), scale=self.scale).view(x_shape)
53-
54-
def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype = None):
55-
return hadamard_transform(torch.eye(self.dim, device=device, dtype=dtype), scale=1 / math.sqrt(self.dim))
72+
ori_shape = x.shape
73+
x = x.view(-1, self.size)
74+
return (
75+
apply_transform_weight(
76+
self.weight.to(device=x.device),
77+
x.to(dtype=self.weight.dtype),
78+
self.location,
79+
self.module_type,
80+
)
81+
).to(x.dtype).view(ori_shape)
5682

5783

5884
TRANSFORMS = {
@@ -64,3 +90,4 @@ def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype =
6490
def build_transform(transform_type: str, **transform_kwargs):
6591
transform = TRANSFORMS[transform_type]
6692
return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs))
93+

auto_round/experimental/transform/utils/__init__.py

Whitespace-only changes.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
2+
import math
3+
from pathlib import Path
4+
5+
import torch
6+
from safetensors import safe_open
7+
8+
9+
REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
10+
11+
12+
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
13+
14+
15+
# note that hadamard matrix multiplication reuses the code from
16+
# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/hadamard.py
17+
18+
19+
def deterministic_hadamard_matrix(
    size: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """
    Build the order-`size` Hadamard matrix using Sylvester's construction.

    `size` must be a positive power of 2.

    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501

    :param size: order of the matrix, must be a power of 2
    :param dtype: data type of the returned matrix
    :param device: device to construct the matrix on
    :raises ValueError: if `size` is not a positive power of 2
    :return: hadamard matrix of size `size`
    """
    if size <= 0:
        raise ValueError("Cannot construct deterministic hadamard of size <= 0")

    num_doublings = int(math.log2(size))
    if size != 2**num_doublings:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

    # Sylvester's construction: start from the 1x1 matrix [[1]] and double
    # the order each step via H_{2n} = [[H_n, H_n], [H_n, -H_n]].
    hadamard = torch.tensor([[1]], dtype=dtype, device=device)
    for _ in range(num_doublings):
        top = torch.cat((hadamard, hadamard), dim=1)
        bottom = torch.cat((hadamard, -hadamard), dim=1)
        hadamard = torch.cat((top, bottom), dim=0)

    return hadamard
49+
50+
51+
def random_hadamard_matrix(
    size: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cpu"),
    gen: torch.Generator | None = None,
) -> torch.Tensor:
    """
    Produce a randomized Hadamard matrix of order `size`. Unlike
    `deterministic_hadamard_matrix`, non powers of 2 are supported (through
    known base matrices) and the result is randomized via a seeded generator.

    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501

    :param size: the dimension of the hadamard matrix
    :param dtype: data type of matrix
    :param gen: optional generator for the random signs
    :param device: device to construct matrix on
    :return: randomly generated hadamard matrix
    """
    # Draw random 0/1 bits on cpu (generators are device-bound), then map
    # them to random -1/+1 signs on the target device.
    signs = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
    signs = signs.to(device=device) * 2 - 1
    # Randomize the hadamard by multiplying with a random sign diagonal.
    return _matmul_hadU(torch.diag(signs))
76+
77+
78+
def is_pow2(n: int) -> bool:
    """
    Check whether a number is a power of 2.

    :param n: number to check
    :return: True iff `n` is a power of 2
    """
    if n <= 0:
        return False
    # A power of two has exactly one bit set, so n & (n - 1) clears it to 0.
    return n & (n - 1) == 0
86+
87+
88+
def _fetch_hadamard_divisor(
    n: int,
    dtype: torch.dtype,
    device: torch.device = torch.device("cpu"),
    file_path: str = REPO_PATH,
) -> torch.Tensor | None:
    """
    Fetch a known hadamard matrix from the given safetensors file. The returned
    matrix has size `k` such that `n / k` is a power of two. Returns None when
    no stored matrix divides `n` that way.

    Note: the safetensors file is reopened on every call. This is technically
    inefficient but a very small runtime cost, and simpler than making callers
    manage the open-file context.

    :param n: size of the hadamard matrix being constructed
    :param dtype: data type to move the fetched hadamard to
    :param device: device to move the fetched hadamard to
    :return: a known hadamard matrix divisor of size `n` if one exists, else None
    """
    # safetensors cannot materialize tensors on a meta device; load on cpu then.
    open_device = device if device.type != "meta" else torch.device("cpu")
    with safe_open(file_path, framework="pt", device=str(open_device)) as file:
        # Prefer the largest stored base matrix (fewest butterfly doublings).
        for divisor in sorted((int(key) for key in file.keys()), reverse=True):
            if n % divisor == 0 and is_pow2(n // divisor):
                return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)

    return None
116+
117+
118+
def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
    """
    Multiply `X` by a Hadamard matrix of order `X.size(0)`, built from a known
    base matrix combined with repeated butterfly (sum/difference) passes.

    :param X: tensor whose leading dimension sets the hadamard order
        (callers pass a square sign-diagonal matrix)
    :raises ValueError: if no known base hadamard matrix divides the order
    :return: transformed tensor with the same shape as `X`
    """
    size = X.size(0)

    # Look up the largest stored base hadamard whose quotient with `size`
    # is a power of two; the remaining factor is handled by butterflies.
    hadK = _fetch_hadamard_divisor(size, X.dtype, device=X.device)
    if hadK is None:
        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
    K = hadK.size(0)

    # Double-buffered butterfly passes over the randomized sign diagonal:
    # each pass halves the block count, writing sums/differences into the
    # scratch buffer, then swaps the two buffers.
    acc = X.clone().view(-1, size, 1)
    scratch = acc.clone()
    while acc.shape[1] > K:
        acc = acc.view(acc.shape[0], acc.shape[1] // 2, 2, acc.shape[2])
        scratch = scratch.view(acc.shape)
        scratch[:, :, 0, :] = acc[:, :, 0, :] + acc[:, :, 1, :]
        scratch[:, :, 1, :] = acc[:, :, 0, :] - acc[:, :, 1, :]
        scratch = scratch.view(acc.shape[0], acc.shape[1], -1)
        acc, scratch = scratch, acc
    assert acc.shape[1] == K
    del scratch

    # Apply the base hadamard with a broadcast batched matmul rather than
    # explicitly repeating hadK per batch element (which could OOM).
    acc = hadK.view(1, K, K).to(acc) @ acc

    # NOTE(review): no 1/sqrt(size) normalization happens here — callers are
    # expected to scale the result themselves.
    return acc.view(X.shape)
1.37 MB
Binary file not shown.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import torch
2+
3+
4+
__all__ = ["apply_transform_weight"]
5+
6+
# note that apply_transform_weight reuses some code from
7+
# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/matrix.py
8+
9+
def apply_transform_weight(
    transform_weight: torch.Tensor,
    value: torch.Tensor,
    location: str,
    module_type: type[torch.nn.Module],
) -> torch.Tensor:
    """
    Using the transform location, apply the transform_weight to the
    given value wrt linear weights. For more info on input and output
    transforms, see `TransformLocation`

    The following explains how weights should be applied to values according to location

    let x be input activation
        W be weight,
        yh, xh, Wh be transformed output, input, weight

    note that
        y = (x W.T)     // torch.nn.Linear

    Choose values for yh, xh, and Wh which incorporate matrix transforms

    let V, Vi be transform matrices on input side
        U, Ui be transform matrices on output side

    pick xh = (x V)
         Wh = (U.T W Vi.T)
         yh = (y U)

    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh

    (xh) (Wh).T = (x V) (U.T W Vi.T).T
                = (x V) (Vi W.T U)      // transpose matrix product identity
                = (x W.T) U
                = y U
                = yh

    :param transform_weight: transform weight to apply
    :param value: value to apply transform_weight to
    :param location: determines how weight should be applied
    :param module_type: result of type(module), passed in to determine
        application of weight transform
    :raises NotImplementedError: if the (location, module_type) combination
        is not supported
    :return: value after transform_weight has been applied
    """
    if location == "input":
        # xh = (x V): right-multiply the activation by the transform
        return _multihead_matmul(value, transform_weight)

    if module_type == torch.nn.Linear:
        # Fuse into the linear weight; nn.Linear stores W as (out, in),
        # so the input-side fusion is W @ V.T
        return _multihead_matmul(value, transform_weight.T)

    # Previously this fell through and silently returned None; fail loudly
    # so unsupported module types surface immediately.
    raise NotImplementedError(
        f"Applying transforms to {module_type} at location {location} is not supported"
    )
59+
60+
61+
def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
62+
"""
63+
Performs A @ B for last two dims of two matrices A and B that possibly
64+
have different shapes, as is the case in multi-headed dimension. If
65+
shapes are different, this is equivalent to converting the last two dims
66+
of the smaller matrix into a block-diagonal matrix with the same shape as
67+
the last two dims of the larger matrix.
68+
69+
E.g. if A is half the size of B, this function will perform
70+
[[A ] @ B
71+
[ A]]
72+
73+
If B is a third of the size of A, this function will perform
74+
A @ [[B ]
75+
[ B ]
76+
[ B]]
77+
78+
This function will error out if the shapes are not evenly divisble
79+
80+
:param A: left-hand tensor
81+
:param B: right-hand tensor
82+
:return: result
83+
"""
84+
if A.shape[-1] > B.shape[-2]:
85+
head_dim = B.shape[-2]
86+
num_heads = A.shape[-1] // head_dim
87+
A = A.unflatten(-1, (num_heads, head_dim))
88+
return (A @ B).flatten(-2, -1)
89+
elif A.shape[-1] < B.shape[-2]:
90+
head_dim = A.shape[-1]
91+
num_heads = B.shape[-2] // head_dim
92+
B = B.unflatten(-2, (num_heads, head_dim))
93+
return (A @ B).flatten(-3, -2)
94+
else:
95+
return A @ B

0 commit comments

Comments
 (0)