From 6a384d5c012e424e5baf9891efa5465088e807dc Mon Sep 17 00:00:00 2001 From: Chunan Zeng Date: Sat, 22 Mar 2025 00:37:57 -0700 Subject: [PATCH] Speed up per token and per tensor quant by 15% (#4639) --- sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 26 +++++++-------- sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 35 ++++++++++---------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index a95d5ea72..b10dd96f5 100644 --- a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -54,42 +54,40 @@ __global__ void per_tensor_quant_fp8_kernel( const int grid_size = blockDim.x * gridDim.x; const float scale_val = 1.0f / (*scale); - constexpr uint32_t vec_size = 16 / sizeof(T); - using vec_t = flashinfer::vec_t; + // We want to store 128 bits of data at a time. 16 = 128 / 8 bits + // Load is already vectorized, so 16 elements work for T. + const uint32_t VEC_SIZE = 16; + using vec_t = flashinfer::vec_t; - const int32_t num_vec_elems = num_elements / vec_size; + const int32_t num_vec_elems = num_elements / VEC_SIZE; for (int32_t i = gid; i < num_vec_elems; i += grid_size) { vec_t input_vec; - input_vec.cast_load(input + i * vec_size); + input_vec.cast_load(input + i * VEC_SIZE); - FP8_TYPE output_arr[vec_size]; + FP8_TYPE output_arr[VEC_SIZE]; #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { + for (uint32_t j = 0; j < VEC_SIZE; ++j) { float val = fmax(fmin(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( - __hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), + __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), c10::Float8_e4m3fnuz::from_bits()); #endif } - -#pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - output[i * vec_size + j] = output_arr[j]; - } + *(uint4*)(output + i * VEC_SIZE) = *(uint4*)output_arr; } - const int32_t remaining_start = num_vec_elems * vec_size; + const int32_t remaining_start = num_vec_elems * VEC_SIZE; for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { float val = fmax(-FP8_E4M3_MAX, fmin(static_cast(input[idx]) * scale_val, FP8_E4M3_MAX)); #ifndef USE_ROCM output[idx] = static_cast(val); #else output[idx] = c10::Float8_e4m3fnuz( - __hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), + __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), c10::Float8_e4m3fnuz::from_bits()); #endif } diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 9c3b67768..db09483ce 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -24,17 +24,19 @@ __global__ void per_token_quant_fp8_kernel( float max_value = 0.0f; - constexpr uint32_t vec_size = 16 / sizeof(T); - using vec_t = flashinfer::vec_t; - const int32_t num_vec_elems = hidden_dim / vec_size; + // We want to store 128 bits of data at a time. 16 = 128 / 8 bits + // Load is already vectorized, so 16 elements work for T. + const uint32_t VEC_SIZE = 16; + using vec_t = flashinfer::vec_t; + const int32_t num_vec_elems = hidden_dim / VEC_SIZE; // Find max using vectorized loads for (int32_t i = tid; i < num_vec_elems; i += block_dim) { vec_t input_vec; - input_vec.cast_load(token_input + i * vec_size); + input_vec.cast_load(token_input + i * VEC_SIZE); #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { + for (uint32_t j = 0; j < VEC_SIZE; ++j) { float val = static_cast(input_vec[j]); max_value = fmaxf(max_value, fabsf(val)); } @@ -42,24 +44,24 @@ __global__ void per_token_quant_fp8_kernel( max_value = blockReduceMax(max_value); - __shared__ float block_max; + __shared__ float scale; if (tid == 0) { - block_max = max_value / FP8_E4M3_MAX; - output_s[token_idx] = block_max; + scale = max_value / FP8_E4M3_MAX; + output_s[token_idx] = scale; } __syncthreads(); - const float scale_val = 1.0f / block_max; + const float scale_inv = 1.0f / scale; // Quantize using vectorized loads for (int32_t i = tid; i < num_vec_elems; i += block_dim) { vec_t input_vec; - input_vec.cast_load(token_input + i * vec_size); + input_vec.cast_load(token_input + i * VEC_SIZE); - FP8_TYPE output_arr[vec_size]; + FP8_TYPE output_arr[VEC_SIZE]; #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM output_arr[j] = static_cast(val); #else @@ -69,10 +71,7 @@ __global__ void per_token_quant_fp8_kernel( #endif } -#pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - token_output[i * vec_size + j] = output_arr[j]; - } + *(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr; } } @@ -85,7 +84,7 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: const int64_t num_tokens = input_sizes[0]; const int64_t hidden_dim = input_sizes[1]; - TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim); + TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim); const int block_size = 256; const int num_blocks = num_tokens;