Skip to content

Commit 1f179b9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1469367 commit 1f179b9

File tree

5 files changed

+23
-21
lines changed

5 files changed

+23
-21
lines changed

auto_round/experimental/transform/apply.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,14 @@ def _apply_to_module(
5656
if config.location == "input":
5757
from .triton.mxfp4 import mxfp4_forward_kernel_wrapper
5858

59-
transform = build_transform(
60-
**config.dict(),
61-
device="cpu",
62-
precision=module.dtype,
63-
location="input"
64-
)
59+
transform = build_transform(**config.dict(), device="cpu", precision=module.dtype, location="input")
6560

6661
def input_hook(_, args):
6762
input = args[0]
6863
# transform(input)
6964
orig_shape = input.shape
7065
x_flat = input.contiguous().flatten(end_dim=-2)
71-
qdq_input, _ = mxfp4_forward_kernel_wrapper(
72-
x_flat, transform.weight
73-
)
66+
qdq_input, _ = mxfp4_forward_kernel_wrapper(x_flat, transform.weight)
7467
return qdq_input.reshape(orig_shape)
7568

7669
# for fused transform + quantization kernel

auto_round/experimental/transform/transforms.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,17 @@ def forward(self, x: torch.Tensor):
7272
ori_shape = x.shape
7373
x = x.view(-1, self.size)
7474
return (
75-
apply_transform_weight(
76-
self.weight.to(device=x.device),
77-
x.to(dtype=self.weight.dtype),
78-
self.location,
79-
self.module_type,
75+
(
76+
apply_transform_weight(
77+
self.weight.to(device=x.device),
78+
x.to(dtype=self.weight.dtype),
79+
self.location,
80+
self.module_type,
81+
)
8082
)
81-
).to(x.dtype).view(ori_shape)
83+
.to(x.dtype)
84+
.view(ori_shape)
85+
)
8286

8387

8488
TRANSFORMS = {
@@ -90,4 +94,3 @@ def forward(self, x: torch.Tensor):
9094
def build_transform(transform_type: str, **transform_kwargs):
9195
transform = TRANSFORMS[transform_type]
9296
return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs))
93-
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# # Copyright (C) 2026 Intel Corporation
2+
# # SPDX-License-Identifier: Apache-2.0

auto_round/experimental/transform/utils/hadamard.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
# # Copyright (C) 2026 Intel Corporation
2+
# # SPDX-License-Identifier: Apache-2.0
13

24
import math
35
from pathlib import Path
46

57
import torch
68
from safetensors import safe_open
79

8-
910
REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
1011

1112

@@ -25,7 +26,7 @@ def deterministic_hadamard_matrix(
2526
Construct an n-by-n Hadamard matrix, using Sylvester's construction.
2627
`n` must be a power of 2.
2728
28-
Adapated from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501
29+
Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501
2930
3031
:param size: order of the matrix, must be a power of 2
3132
:param dtype: data type of matrix
@@ -59,7 +60,7 @@ def random_hadamard_matrix(
5960
`deterministic_hadamard_matrix` in that this function supports non powers of 2
6061
and randomization using a seeded generator
6162
62-
Adapated from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501
63+
Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501
6364
Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/ # noqa: E501
6465
6566
:param size: The dimension of the hamadard matrix

auto_round/experimental/transform/utils/matrix.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
import torch
1+
# # Copyright (C) 2026 Intel Corporation
2+
# # SPDX-License-Identifier: Apache-2.0
23

4+
import torch
35

46
__all__ = ["apply_transform_weight"]
57

68
# note that apply_transform_weight reuses some code from
79
# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/matrix.py
810

11+
912
def apply_transform_weight(
1013
transform_weight: torch.Tensor,
1114
value: torch.Tensor,
@@ -75,7 +78,7 @@ def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
7578
[ B ]
7679
[ B]]
7780
78-
This function will error out if the shapes are not evenly divisble
81+
This function will error out if the shapes are not evenly divisible
7982
8083
:param A: left-hand tensor
8184
:param B: right-hand tensor

0 commit comments

Comments
 (0)