fix accuracy issue (#4376)
This commit is contained in:
@@ -49,6 +49,8 @@ __global__ void per_token_quant_fp8_kernel(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_val = 1.0f / block_max;
|
||||
|
||||
// Quantize using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
@@ -57,7 +59,7 @@ __global__ void per_token_quant_fp8_kernel(
|
||||
FP8_TYPE output_arr[vec_size];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
#ifndef USE_ROCM
|
||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user