From 0c8dab9e67b1fe0d274a27af03540b2ce5525a37 Mon Sep 17 00:00:00 2001
From: Yuan Luo
Date: Wed, 23 Jul 2025 21:22:59 +0800
Subject: [PATCH] [sgl-kernel] Opt per_token_quant_fp8 with warp reduce (#8130)

Co-authored-by: luoyuan.luo
---
 sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 122 +++++++++++++++++---
 1 file changed, 106 insertions(+), 16 deletions(-)

diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
index db09483ce..9367f1584 100644
--- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -1,18 +1,95 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cub/block/block_reduce.cuh>
-#include <cmath>
 #include <flashinfer/vec_dtypes.cuh>
 
 #include "utils.h"
 
-template <typename T>
+static constexpr int kWarpSize = 32;
+
+// ---------------------------------------------------------------------------
+// 1. Warp-local, no shared memory
+//    * One warp handles one token.
+//    * Eight tokens per 256-thread CTA (with the default kTokensPerCTA = 8).
+// ---------------------------------------------------------------------------
+template <typename T, int kTokensPerCTA = 8, int kVecSize = 16>
 __global__ void per_token_quant_fp8_kernel(
     const T* __restrict__ input,
     FP8_TYPE* __restrict__ output_q,
     float* __restrict__ output_s,
     const int64_t hidden_dim,
     const int64_t num_tokens) {
+  const int warp_id = threadIdx.x / kWarpSize;        // 0..kTokensPerCTA-1
+  const int lane_id = threadIdx.x & (kWarpSize - 1);  // 0..31
+  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
+  if (token_id >= num_tokens) return;
+
+  // Global tensors for this token
+  const T* token_input = input + token_id * hidden_dim;
+  FP8_TYPE* token_output = output_q + token_id * hidden_dim;
+  float* token_scale = output_s + token_id;
+
+  //
+  // Pass 1: warp reduce to find the max absolute value over this token's
+  // hidden_dim
+  //
+  float max_value = 0.f;
+  using vec_t = flashinfer::vec_t<float, kVecSize>;
+  const int32_t num_vec_elems = hidden_dim / kVecSize;
+
+  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * kVecSize);
+
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
+    }
+  }
+
+  // warpReduceMax() leaves the warp-wide maximum in every lane, so the scale
+  // can stay in a register. A single __shared__ scale would race between the
+  // warps of this CTA, since each warp handles a different token.
+  const float warp_max = warpReduceMax(max_value);
+  const float scale = warp_max / FP8_E4M3_MAX;
+  // Write the per-token scale once
+  if (lane_id == 0) {
+    token_scale[0] = scale;
+  }
+  const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
+
+  //
+  // Pass 2: quantize and write back
+  //
+  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * kVecSize);
+    FP8_TYPE output_arr[kVecSize];
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      float val = static_cast<float>(input_vec[j]) * scale_inv;
+      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+
+#ifndef USE_ROCM
+      output_arr[j] = static_cast<FP8_TYPE>(val);
+#else
+      output_arr[j] = c10::Float8_e4m3fnuz(
+          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+          c10::Float8_e4m3fnuz::from_bits());
+#endif
+    }
+    // kVecSize FP8 values are exactly 16 bytes: one uint4 store per iteration.
+    *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+  }
+}
+
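+// NOTE: warpReduceMax() is assumed to come from "utils.h", which this diff
+// does not show. A minimal sketch of such a helper (an assumption, not the
+// repository's actual implementation), valid when all 32 lanes participate:
+//
+//   __device__ __forceinline__ float warpReduceMax(float v) {
+//     #pragma unroll
+//     for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+//       v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
+//     }
+//     return v;  // butterfly pattern: every lane ends up with the maximum
+//   }
+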
+// ---------------------------------------------------------------------------
+// 2. Baseline kernel (1 token / CTA, CUB block reduce)
+// ---------------------------------------------------------------------------
+template <typename T>
+__global__ void per_token_quant_fp8_small_batch_kernel(
+    const T* __restrict__ input,
+    FP8_TYPE* __restrict__ output_q,
+    float* __restrict__ output_s,
+    const int64_t hidden_dim,
+    const int64_t num_tokens) {
   const int token_idx = blockIdx.x;
   if (token_idx >= num_tokens) return;
@@ -79,28 +156,41 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
-
   const auto input_sizes = input.sizes();
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];
-
   TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
 
-  const int block_size = 256;
-  const int num_blocks = num_tokens;
-
-  dim3 grid(num_blocks);
-  dim3 block(block_size);
-
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
+  // Heuristic: use the warp kernel only when there are enough tokens to fill
+  // the GPU with at least two waves of CTAs. The SM count is hard-coded for
+  // H100-class GPUs (132 SMs).
+  int sm_count = 132;
+  constexpr int TOKENS_PER_CTA = 8;
+  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        static_cast<scalar_t*>(input.data_ptr()),
-        static_cast<FP8_TYPE*>(output_q.data_ptr()),
-        static_cast<float*>(output_s.data_ptr()),
-        hidden_dim,
-        num_tokens);
+    if (use_warp_kernel) {
+      // -------- warp-local ---------------------------------------------------
+      constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
+      dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
+      dim3 block(THREADS);
+      per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA><<<grid, block, 0, stream>>>(
+          static_cast<scalar_t*>(input.data_ptr()),
+          static_cast<FP8_TYPE*>(output_q.data_ptr()),
+          static_cast<float*>(output_s.data_ptr()),
+          hidden_dim,
+          num_tokens);
+    } else {
+      // -------- baseline -----------------------------------------------------
+      constexpr int THREADS = 256;
+      dim3 grid(num_tokens);
+      dim3 block(THREADS);
+      per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>(
+          static_cast<scalar_t*>(input.data_ptr()),
+          static_cast<FP8_TYPE*>(output_q.data_ptr()),
+          static_cast<float*>(output_s.data_ptr()),
+          hidden_dim,
+          num_tokens);
+    }
     return true;
   });
 }
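
Follow-up note (not part of the patch): sm_count is hard-coded to 132, the
SM count of H100/H200-class GPUs. If a runtime query is preferred, a sketch
using the standard CUDA device-attribute API could replace the constant,
keeping 132 only as a fallback (this is an assumption, not code from the
repository):

    int sm_count = 132;  // fallback for the H100-class default above
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount,
                           at::cuda::current_device());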