Skip to content

Commit 515cbbd

Browse files
authored
Fix zero scale in fp8 quantization (#3652)
* fix zero scale in fp8 quant * symmetry
1 parent 4b3109b commit 515cbbd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/turbomind/kernels/quantization.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ __global__ void quant_symm_row(
2626
for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {
2727
Array<T, vec_size> vec;
2828
Ldg(vec, src + ti * src_ld + di);
29-
auto absmax = static_cast<Tscale>(find_absmax<threads>(vec));
29+
auto absmax = fmaxf(static_cast<Tscale>(find_absmax<threads>(vec)), 1e-8f);
3030
const Tscale scale = absmax / qmax;
3131
const Tscale inv_scale = qmax / absmax;
3232
if (threadIdx.x % threads == 0) {
@@ -179,7 +179,7 @@ __global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale
179179

180180
absmax = BlockReduce{temp_storage}.Reduce(absmax, [](auto a, auto b) { return __hmax(a, b); });
181181
if (threadIdx.x == 0) {
182-
auto maxval = static_cast<Tscale>(absmax);
182+
auto maxval = fmaxf(static_cast<Tscale>(absmax), 1e-8f);
183183
scales[blockIdx.x * gridDim.y + blockIdx.y] = maxval / qmax;
184184
shared_inv_scale = qmax / maxval;
185185
}

0 commit comments

Comments
 (0)