diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py index bdd0c2b2e..c56df30f5 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py @@ -186,7 +186,7 @@ configs = list(itertools.product(batch_size_range, seq_len_range, group_size_ran def benchmark(batch_size, seq_len, group_size, provider): dtype = torch.bfloat16 device = torch.device("cuda") - hidden_dim = group_size * 2 + hidden_dim = 7168 x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) diff --git a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu index 5afe03801..e5a14602a 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu @@ -2,17 +2,18 @@ #include #include +#include #include "utils.h" using FP8_TYPE = c10::Float8_e4m3fn; -__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) { - smem[tid] = fmaxf(smem[tid], smem[tid + 8]); - if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); - if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); - if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); - return smem[0]; +__device__ __forceinline__ float GroupReduce(float val, const int tid) { + val = fmaxf(val, __shfl_xor_sync(0xffff, val, 8)); + val = fmaxf(val, __shfl_xor_sync(0xffff, val, 4)); + val = fmaxf(val, __shfl_xor_sync(0xffff, val, 2)); + val = fmaxf(val, __shfl_xor_sync(0xffff, val, 1)); + return val; } template @@ -21,54 +22,60 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo const int num_groups, const float eps, const float fp8_min, const float fp8_max) { const int groups_per_block = 16; + const int local_group_id = threadIdx.x / 16; + const int lane_id = threadIdx.x % 16; + const int block_group_id = blockIdx.x * groups_per_block; - const int tid = threadIdx.x; - const int local_group_id = tid / 16; // Each 16 threads handle one group - const int local_tid = tid % 16; // Thread ID within the group + const int block_group_offset = (block_group_id + local_group_id) * group_size; - __shared__ float s_absmax[16][17]; // Use 17 instead of 16 to avoid bank conflicts + __shared__ float s_absmax[16]; - // Local maximum value for each thread float local_absmax = eps; - // Ensure this block doesn't process out-of-bounds groups - if (block_group_id + local_group_id < num_groups) { - // Calculate input/output pointers for current group - const T* group_input = input + (block_group_id + local_group_id) * group_size; - FP8_TYPE* group_output = static_cast(output_q) + (block_group_id + local_group_id) * group_size; - float* scale_output = output_s + block_group_id + local_group_id; + const T* group_input = input + block_group_offset; + FP8_TYPE* group_output = static_cast(output_q) + block_group_offset; + float* scale_output = output_s + block_group_id + local_group_id; - // Calculate local maximum absolute value - for (int i = local_tid; i < group_size; i += 16) { - float val = static_cast(group_input[i]); + constexpr uint32_t vec_size = 16 / sizeof(T); + using vec_t = flashinfer::vec_t; + + const int32_t num_vec_elems = group_size / vec_size; + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); + +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); float abs_val = fabsf(val); local_absmax = fmaxf(local_absmax, abs_val); } + } - // Store in shared memory - s_absmax[local_group_id][local_tid] = local_absmax; - __syncthreads(); + local_absmax = GroupReduce(local_absmax, lane_id); - // Perform reduction within each group - if (local_tid < 8) { - GroupReduce(&s_absmax[local_group_id][0], local_tid); - } - __syncthreads(); + if (lane_id == 0) { + s_absmax[local_group_id] = local_absmax; + } + __syncthreads(); - // Get the maximum value for this group - const float group_absmax = s_absmax[local_group_id][0]; - const float y_s = group_absmax / fp8_max; + const float group_absmax = s_absmax[local_group_id]; + const float y_s = group_absmax / fp8_max; - // Only the first thread in each group writes the scale - if (local_tid == 0) { - *scale_output = y_s; - } + if (lane_id == 0) { + *scale_output = y_s; + } - // Quantize the data - for (int i = local_tid; i < group_size; i += 16) { - float val = static_cast(group_input[i]); + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); + +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max); - group_output[i] = FP8_TYPE(q_val); + group_output[i * vec_size + j] = FP8_TYPE(q_val); } } } @@ -83,9 +90,8 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, CHECK_EQ(input.numel() % group_size, 0); - // Each block processes 16 groups, adjust grid size accordingly dim3 grid((num_groups + 15) / 16); - dim3 block(256); // Keep 256 threads, each 16 threads handle one group + dim3 block(256); cudaStream_t stream = at::cuda::getCurrentCUDAStream();