[sgl-kernel performace] fix fp8 quant kernels dispatch __nv_fp8_e4m3 bug to improve performance 10%-20% (#8499)

Co-authored-by: Ke Bao <ispobaoke@gmail.com>
This commit is contained in:
Xiaoyu Zhang
2025-07-29 23:31:54 +08:00
committed by GitHub
parent 813670660c
commit 7a4309cc8a
3 changed files with 21 additions and 23 deletions

View File

@@ -44,10 +44,10 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
}
}
template <typename T>
template <typename T, typename DST_DTYPE>
__global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output,
DST_DTYPE* __restrict__ output,
const float* __restrict__ scale,
const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -65,12 +65,12 @@ __global__ void per_tensor_quant_fp8_kernel(
vec_t input_vec;
input_vec.cast_load(input + i * VEC_SIZE);
FP8_TYPE output_arr[VEC_SIZE];
DST_DTYPE output_arr[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val);
output_arr[j] = static_cast<DST_DTYPE>(val);
#else
output_arr[j] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
@@ -84,7 +84,7 @@ __global__ void per_tensor_quant_fp8_kernel(
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
output[idx] = static_cast<FP8_TYPE>(val);
output[idx] = static_cast<DST_DTYPE>(val);
#else
output[idx] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
@@ -113,9 +113,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
}
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
per_tensor_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
num_elements);
return true;