Skip to content

Commit ac89705

Browse files
install torchao mps ops by default when running on Apple Silicon
1 parent 4d8bab5 commit ac89705

File tree

2 files changed

+13
-38
lines changed

2 files changed

+13
-38
lines changed

Diff for: install/install_torchao.sh

+8-1
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,16 @@ else
 fi
 echo "Using pip executable: $PIP_EXECUTABLE"
 
+if [[ $(uname -s) == "Darwin" && $(uname -m) == "arm64" ]]; then
+  echo "Building torchao experimental mps ops (Apple Silicon detected)"
+  APPLE_SILICON_DETECTED=1
+else
+  echo "NOT building torchao experimental mps ops (Apple Silicon NOT detected)"
+  APPLE_SILICON_DETECTED=0
+fi
 
 export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt)
 (
   set -x
-  USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
+  USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=${APPLE_SILICON_DETECTED} $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
 )

Diff for: torchchat/utils/quantize.py

+5-37
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
 from torchao.experimental.quant_api import (
     int8_dynamic_activation_intx_weight,
     IntxWeightEmbeddingQuantizer,
+    UIntxWeightOnlyLinearQuantizer,
 )
 from torchao.quantization.granularity import (
     PerGroup,
@@ -137,12 +138,12 @@ def quantize_model(
             group_size = q_kwargs["groupsize"]
             bit_width = q_kwargs["bitwidth"]
             has_weight_zeros = q_kwargs["has_weight_zeros"]
-            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
             weight_dtype = getattr(torch, f"int{bit_width}")
 
             try:
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -154,7 +155,7 @@
                 print("Encountered error during quantization: {e}")
                 print("Trying with PlainLayout")
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -946,38 +947,5 @@ def quantized_model(self) -> nn.Module:
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+    "linear:afpwx": UIntxWeightOnlyLinearQuantizer,
 }
-
-try:
-    import importlib.util
-    import os
-    import sys
-
-    torchao_build_path = f"{os.getcwd()}/torchao-build"
-
-    # Try loading quantizer
-    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
-        "torchao_experimental_quant_api",
-        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
-    )
-    torchao_experimental_quant_api = importlib.util.module_from_spec(
-        torchao_experimental_quant_api_spec
-    )
-    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
-    torchao_experimental_quant_api_spec.loader.exec_module(
-        torchao_experimental_quant_api
-    )
-    from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer
-    quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
-
-    # Try loading custom op
-    try:
-        libname = "libtorchao_ops_mps_aten.dylib"
-        libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
-        torch.ops.load_library(libpath)
-        print("Loaded torchao mps ops.")
-    except Exception as e:
-        print("Unable to load torchao mps ops library.")
-
-except Exception as e:
-    print("Unable to import torchao experimental quant_api with error: ", e)

0 commit comments

Comments
 (0)