diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu index 3ad43e7c6..cbf39f041 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu @@ -2,20 +2,23 @@ #include #include +#include #include "utils.h" using FP8_TYPE = c10::Float8_e4m3fn; -__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) { - smem[tid] = fmaxf(smem[tid], smem[tid + 8]); - if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); - if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); - if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); - return smem[0]; +__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { + unsigned mask = 0xffff; + + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + return val; } -template +template __global__ void per_token_group_quant_fp8_kernel( const T* __restrict__ input, void* __restrict__ output_q, @@ -25,46 +28,53 @@ __global__ void per_token_group_quant_fp8_kernel( const float eps, const float fp8_min, const float fp8_max) { - const int groups_per_block = 16; - const int block_group_id = blockIdx.x * groups_per_block; - const int tid = threadIdx.x; - const int local_group_id = tid / 16; - const int local_tid = tid % 16; + const int threads_per_group = 16; + const int local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; - __shared__ float s_absmax[16][17]; + const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK; + const int block_group_offset = (block_group_id + local_group_id) * group_size; float local_absmax = eps; - if (block_group_id + local_group_id < num_groups) { - const T* group_input = input + (block_group_id + local_group_id) * group_size; - FP8_TYPE* group_output = static_cast(output_q) + (block_group_id + local_group_id) * group_size; - float* scale_output = output_s + block_group_id + local_group_id; + const T* group_input = input + block_group_offset; + FP8_TYPE* group_output = static_cast(output_q) + block_group_offset; + float* scale_output = output_s + (block_group_id + local_group_id); - for (int i = local_tid; i < group_size; i += 16) { - float val = static_cast(group_input[i]); + constexpr uint32_t vec_size = 16 / sizeof(T); + using vec_t = flashinfer::vec_t; + + const int32_t num_vec_elems = group_size / vec_size; + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); + +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); float abs_val = fabsf(val); local_absmax = fmaxf(local_absmax, abs_val); } + } - s_absmax[local_group_id][local_tid] = local_absmax; - __syncthreads(); + local_absmax = GroupReduceMax(local_absmax, lane_id); - if (local_tid < 8) { - GroupReduceMax(&s_absmax[local_group_id][0], local_tid); - } - __syncthreads(); + const float y_s = local_absmax / fp8_max; - const float group_absmax = s_absmax[local_group_id][0]; - const float y_s = group_absmax / fp8_max; + if (lane_id == 0) { + *scale_output = y_s; + } - if (local_tid == 0) { - *scale_output = y_s; - } + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); - for (int i = local_tid; i < group_size; i += 16) { - float val = static_cast(group_input[i]); +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max); - group_output[i] = FP8_TYPE(q_val); + group_output[i * vec_size + j] = FP8_TYPE(q_val); } } } @@ -85,21 +95,52 @@ void sgl_per_token_group_quant_fp8( CHECK_EQ(input.numel() % group_size, 0); - dim3 grid((num_groups + 15) / 16); - dim3 block(256); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + +#define LAUNCH_KERNEL(T, GPB) \ + do { \ + constexpr int GROUPS_PER_BLOCK = GPB; \ + dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK); \ + dim3 block(GROUPS_PER_BLOCK* THREADS_PER_GROUP); \ + per_token_group_quant_fp8_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + (float)eps, \ + (float)fp8_min, \ + (float)fp8_max); \ + } while (0) + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { - per_token_group_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), - output_q.data_ptr(), - static_cast(output_s.data_ptr()), - group_size, - num_groups, - (float)eps, - (float)fp8_min, - (float)fp8_max); + if (groups_per_block == 16) { + LAUNCH_KERNEL(scalar_t, 16); + } else if (groups_per_block == 8) { + LAUNCH_KERNEL(scalar_t, 8); + } else if (groups_per_block == 4) { + LAUNCH_KERNEL(scalar_t, 4); + } else if (groups_per_block == 2) { + LAUNCH_KERNEL(scalar_t, 2); + } else { + LAUNCH_KERNEL(scalar_t, 1); + } return true; }); + +#undef LAUNCH_KERNEL } diff --git a/sgl-kernel/tests/test_per_token_group_quant_fp8.py b/sgl-kernel/tests/test_per_token_group_quant_fp8.py index ddc11b86b..9fa7c9bd1 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_group_quant_fp8.py @@ -149,9 +149,9 @@ def sglang_per_token_group_quant_fp8( "batch_size, seq_len, group_size", list( itertools.product( - [1, 2, 4, 8, 16], # batch_size + [1, 2, 4, 8, 16, 32, 64, 128], # batch_size [64, 128, 256, 512, 1024, 2048], # seq_len - [64, 128, 256], # group_size + [16, 32, 64, 128, 256], # group_size ) ), )