Skip to content

Commit 1ad2e6c

Browse files
committed
revert cpu changes
1 parent df73d3e commit 1ad2e6c

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,8 +671,9 @@ def to(self, *args, **kwargs):
671671
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
672672
if device.type != "cpu" or self.data.dtype != torch.int8:
673673
return self._quantize(device)
674-
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
675-
self.CB = self.data
674+
elif self.data.dtype == torch.int8:
675+
if device.type == "cpu" or (device.type == "xpu" and ipex_xpu):
676+
self.CB = self.data
676677

677678
new_param = Int8Params(
678679
super().to(device=device, dtype=dtype, non_blocking=non_blocking),

tests/test_functional.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,10 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
137137
abserr = sum(diffs) / len(diffs)
138138
relerr = sum(reldiffs) / len(reldiffs)
139139
if signed:
140-
threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
141140
assert abserr < 0.0036
142141
assert relerr < 0.015
143142
else:
144-
assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
143+
assert abserr < 0.00175 if (device in "cpu") or (device in "xpu" and F.ipex_xpu) else 0.0023
145144
assert relerr < 0.012
146145
assert A2.dtype == dtype
147146

@@ -172,8 +171,10 @@ def test_blockwise_cpu_large(self, hidden, blocksize):
172171
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
173172
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
174173
def test_few_bit_quant(self, device, bits, method):
175-
if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)):
176-
pytest.skip("CPU/XPU implementation only supports 8 bits")
174+
if device in "cpu" and bits != 8:
175+
pytest.skip("CPU implementation only supports 8 bits")
176+
if device in "xpu" and bits != 8 and F.ipex_xpu:
177+
pytest.skip("XPU ipex implementation only supports 8 bits")
177178

178179
abserrs = []
179180
relerrs = []

0 commit comments

Comments
 (0)