fix accuracy issue (#4376)

This commit is contained in:
Yineng Zhang
2025-03-13 02:06:22 -07:00
committed by GitHub
parent cf721fdece
commit 2937387a50
4 changed files with 17 additions and 5 deletions

View File

@@ -49,6 +49,8 @@ __global__ void per_token_quant_fp8_kernel(
}
__syncthreads();
+const float scale_val = 1.0f / block_max;
// Quantize using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec;
@@ -57,7 +59,7 @@ __global__ void per_token_quant_fp8_kernel(
FP8_TYPE output_arr[vec_size];
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
-float val = fmaxf(fminf(static_cast<float>(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val);
#else