From b3251e9f40b85159d52563b9ca8276fa0fa03703 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sat, 8 Mar 2025 21:47:35 +0800
Subject: [PATCH] refine quant kernel code style (#4211)

Move the duplicated block-level max reduction of the per-tensor and
per-token FP8 quant kernels into a shared blockReduceMax helper.

The helper rounds the warp count up, so a trailing partial warp
(blockDim.x not a multiple of WARP_SIZE) still contributes its maximum.
---
 .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 +-----------
 .../sgl-kernel/csrc/gemm/per_token_quant_fp8.cu  | 14 +-------------
 sgl-kernel/src/sgl-kernel/include/utils.h        | 24 ++++++++++++++++++++++++
 3 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
index ea222c001..a95d5ea72 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
     max_value = fmaxf(max_value, fabsf(val));
   }
 
-  static __shared__ float warpLevelMaxs[WARP_SIZE];
-  const int laneId = threadIdx.x % WARP_SIZE;
-  const int warpId = threadIdx.x / WARP_SIZE;
-
-  max_value = warpReduceMax(max_value);
-
-  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
-  __syncthreads();
-
-  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
-
-  if (warpId == 0) max_value = warpReduceMax(max_value);
+  max_value = blockReduceMax(max_value);
 
   if (tid == 0) {
     atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
index 1491af126..12616ff44 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
     max_value = fmaxf(max_value, fabsf(val));
   }
 
-  max_value = warpReduceMax(max_value);
-
-  static __shared__ float warpLevelMaxs[WARP_SIZE];
-  const int laneId = threadIdx.x % WARP_SIZE;
-  const int warpId = threadIdx.x / WARP_SIZE;
-
-  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
-  __syncthreads();
-
-  if (warpId == 0) {
-    max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
-    max_value = warpReduceMax(max_value);
-  }
+  max_value = blockReduceMax(max_value);
 
   __shared__ float block_max;
   if (tid == 0) {
diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h
index 79bf84671..c099bf5aa 100644
--- a/sgl-kernel/src/sgl-kernel/include/utils.h
+++ b/sgl-kernel/src/sgl-kernel/include/utils.h
@@ -124,4 +124,28 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
   max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
   return max_value;
 }
+
+// Block-wide max reduction layered on warpReduceMax.
+// The result is meant to be read from thread 0 (as the quant kernels do); threads
+// outside warp 0 do not receive the block maximum.
+// Missing warp slots are filled with 0, so inputs are expected to be non-negative
+// (callers pass absolute values). The shared scratch buffer holds WARP_SIZE entries,
+// i.e. at most WARP_SIZE warps per block.
+__device__ __forceinline__ float blockReduceMax(float max_value) {
+  static __shared__ float warpLevelMaxs[WARP_SIZE];
+  const int laneId = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+  // Round up so a trailing partial warp still contributes its maximum.
+  const unsigned int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
+
+  max_value = warpReduceMax(max_value);
+
+  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
+  __syncthreads();
+
+  max_value = (threadIdx.x < numWarps) ? warpLevelMaxs[laneId] : 0;
+  if (warpId == 0) max_value = warpReduceMax(max_value);
+
+  return max_value;
+}
 #endif