Fix per token fp8 quant precision (#4362)

Qingquan Song
2025-03-12 21:19:05 -07:00
committed by GitHub
parent 817d43705c
commit 4068e01292
3 changed files with 5 additions and 13 deletions


@@ -49,8 +49,6 @@ __global__ void per_token_quant_fp8_kernel(
   }
   __syncthreads();
 
-  const float scale_val = 1.0f / block_max;
-
   // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
@@ -59,7 +57,7 @@ __global__ void per_token_quant_fp8_kernel(
     FP8_TYPE output_arr[vec_size];
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
 #else
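
Context on the change (not part of the diff): the old path precomputed scale_val = 1.0f / block_max and multiplied by it, which rounds twice, once when forming the reciprocal and again on the multiply. The new path is a single correctly-rounded division; for example, an element equal to block_max divides to exactly 1.0f but may not multiply back to exactly 1.0f through the rounded reciprocal. The standalone host-side sketch below illustrates that difference; the scanned block_max values are arbitrary and only for demonstration.

// Standalone host-side sketch (not from the commit) of the double-rounding issue:
// x * (1.0f / block_max) rounds twice, x / block_max rounds once.
#include <cstdio>

int main() {
  int mismatches = 0;
  int checked = 0;
  // Arbitrary scan of hypothetical per-token maxima; compare the two paths
  // for an element equal to block_max, where the division is exactly 1.0f.
  for (float block_max = 0.5f; block_max < 448.0f; block_max += 0.37f) {
    const float scale_val = 1.0f / block_max;     // old path: rounded reciprocal
    const float via_mul = block_max * scale_val;  // second rounding on the multiply
    const float via_div = block_max / block_max;  // new path: exactly 1.0f
    if (via_mul != via_div) {
      ++mismatches;
    }
    ++checked;
  }
  printf("%d of %d block_max values give block_max * (1.0f / block_max) != 1.0f\n",
         mismatches, checked);
  return 0;
}

The trade-off is a hardware divide per element instead of a multiply, which is usually negligible for a bandwidth-bound quantization kernel.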