From 7a4309cc8a56e7a2cffba82a5189b51fd5776259 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 29 Jul 2025 23:31:54 +0800 Subject: [PATCH] [sgl-kernel performace] fix fp8 quant kernels dispatch __nv_fp8_e4m3 bug to improve performance 10%-20% (#8499) Co-authored-by: Ke Bao --- sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 14 ++++----- .../csrc/gemm/per_token_group_quant_8bit.cu | 1 - sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 29 +++++++++---------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index b10dd96f5..6da13d079 100644 --- a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -44,10 +44,10 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output } } -template +template __global__ void per_tensor_quant_fp8_kernel( const T* __restrict__ input, - FP8_TYPE* __restrict__ output, + DST_DTYPE* __restrict__ output, const float* __restrict__ scale, const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; @@ -65,12 +65,12 @@ __global__ void per_tensor_quant_fp8_kernel( vec_t input_vec; input_vec.cast_load(input + i * VEC_SIZE); - FP8_TYPE output_arr[VEC_SIZE]; + DST_DTYPE output_arr[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { float val = fmax(fmin(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM - output_arr[j] = static_cast(val); + output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), @@ -84,7 +84,7 @@ __global__ void per_tensor_quant_fp8_kernel( for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { float val = fmax(-FP8_E4M3_MAX, fmin(static_cast(input[idx]) * scale_val, FP8_E4M3_MAX)); #ifndef USE_ROCM - output[idx] = static_cast(val); + output[idx] = static_cast(val); #else output[idx] = c10::Float8_e4m3fnuz( __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), @@ -113,9 +113,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch static_cast(input.data_ptr()), static_cast(output_s.data_ptr()), num_elements); } - per_tensor_quant_fp8_kernel<<>>( + per_tensor_quant_fp8_kernel<<>>( static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), static_cast(output_s.data_ptr()), num_elements); return true; diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index d818ddfb8..474164ce6 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 9367f1584..7b58f838f 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -12,10 +12,10 @@ static constexpr int kWarpSize = 32; // • One warp handles one token. // • Eight tokens per 256‑thread CTA. // --------------------------------------------------------------------------- -template +template __global__ void per_token_quant_fp8_kernel( const T* __restrict__ input, - FP8_TYPE* __restrict__ output_q, + DST_DTYPE* __restrict__ output_q, float* __restrict__ output_s, const int64_t hidden_dim, const int64_t num_tokens) { @@ -26,7 +26,7 @@ __global__ void per_token_quant_fp8_kernel( // Global tensors for this token const T* token_input = input + token_id * hidden_dim; - FP8_TYPE* token_output = output_q + token_id * hidden_dim; + DST_DTYPE* token_output = output_q + token_id * hidden_dim; float* token_scale = output_s + token_id; // @@ -62,14 +62,13 @@ __global__ void per_token_quant_fp8_kernel( for (int i = lane_id; i < num_vec_elems; i += kWarpSize) { vec_t input_vec; input_vec.cast_load(token_input + i * kVecSize); - FP8_TYPE output_arr[kVecSize]; + DST_DTYPE output_arr[kVecSize]; #pragma unroll for (uint32_t j = 0; j < kVecSize; ++j) { float val = static_cast(input_vec[j]) * scale_inv; val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX); - #ifndef USE_ROCM - output_arr[j] = static_cast(val); + output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), @@ -83,10 +82,10 @@ __global__ void per_token_quant_fp8_kernel( // --------------------------------------------------------------------------- // 2. Baseline kernel (1 token / CTA, CUB block reduce) // --------------------------------------------------------------------------- -template +template __global__ void per_token_quant_fp8_small_batch_kernel( const T* __restrict__ input, - FP8_TYPE* __restrict__ output_q, + DST_DTYPE* __restrict__ output_q, float* __restrict__ output_s, const int64_t hidden_dim, const int64_t num_tokens) { @@ -97,7 +96,7 @@ __global__ void per_token_quant_fp8_small_batch_kernel( const int block_dim = blockDim.x; const T* token_input = input + token_idx * hidden_dim; - FP8_TYPE* token_output = output_q + token_idx * hidden_dim; + DST_DTYPE* token_output = output_q + token_idx * hidden_dim; float max_value = 0.0f; @@ -135,12 +134,12 @@ __global__ void per_token_quant_fp8_small_batch_kernel( vec_t input_vec; input_vec.cast_load(token_input + i * VEC_SIZE); - FP8_TYPE output_arr[VEC_SIZE]; + DST_DTYPE output_arr[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM - output_arr[j] = static_cast(val); + output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), @@ -173,9 +172,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256 dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); dim3 block(THREADS); - per_token_quant_fp8_kernel<<>>( + per_token_quant_fp8_kernel<<>>( static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), static_cast(output_s.data_ptr()), hidden_dim, num_tokens); @@ -184,9 +183,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: constexpr int THREADS = 256; dim3 grid(num_tokens); dim3 block(THREADS); - per_token_quant_fp8_small_batch_kernel<<>>( + per_token_quant_fp8_small_batch_kernel<<>>( static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), static_cast(output_s.data_ptr()), hidden_dim, num_tokens);