diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py new file mode 100644 index 000000000..bdd0c2b2e --- /dev/null +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py @@ -0,0 +1,209 @@ +import itertools +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl +from sgl_kernel import sgl_per_token_group_quant_fp8 + +from sglang.srt.utils import get_device_core_count, get_device_name, is_hip + +is_hip_ = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Collums of input + N, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def triton_per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +def sglang_per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, +): + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) + + return x_q, x_s + + +def calculate_diff(batch_size, seq_len, group_size): + dtype = torch.float16 + device = torch.device("cuda") + hidden_dim = group_size * 2 + + x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size) + + if torch.allclose( + x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 + ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [1, 2, 4, 8, 16, 32, 64] +seq_len_range = [64, 128, 256, 512, 1024, 2048] +group_size_range = [128] # For DeepSeek V3/R1 + +configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "group_size"], + x_vals=configs, + line_arg="provider", + line_vals=["triton", "sglang"], + line_names=["Triton", "SGL Kernel"], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="per-token-group-quant-fp8-performance", + args={}, + ) +) +def benchmark(batch_size, seq_len, group_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + hidden_dim = group_size * 2 + + x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "triton": + fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size) + elif provider == "sglang": + fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + + calculate_diff(batch_size=4, seq_len=128, group_size=64) + + benchmark.run(print_data=True) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 5352fae5c..9890b647f 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -100,6 +100,7 @@ sources = [ "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu", "src/sgl-kernel/csrc/eagle_utils.cu", "src/sgl-kernel/csrc/speculative_sampling.cu", + "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 314416a4f..50135dc15 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -29,6 +29,7 @@ from sgl_kernel.ops import ( register_graph_buffers, rmsnorm, sampling_scaling_penalties, + sgl_per_token_group_quant_fp8, silu_and_mul, top_k_renorm_prob, top_k_top_p_sampling_from_probs, @@ -65,4 +66,5 @@ __all__ = [ "tree_speculative_sampling_target_only", "build_tree_kernel_efficient", "build_tree_kernel", + "sgl_per_token_group_quant_fp8", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu new file mode 100644 index 000000000..894d1d332 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu @@ -0,0 +1,100 @@ +#include +#include + +#include + +#include "utils.h" + +using FP8_TYPE = c10::Float8_e4m3fn; + +__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) { + if (tid < 8) { + smem[tid] = fmaxf(smem[tid], smem[tid + 8]); + if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); + if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); + if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); + } + return smem[0]; +} + +template +__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q, + float* __restrict__ output_s, const int group_size, + const int num_groups, const float eps, const float fp8_min, + const float fp8_max) { + const int groups_per_block = 16; + const int block_group_id = blockIdx.x * groups_per_block; + const int tid = threadIdx.x; + const int local_group_id = tid / 16; // Each 16 threads handle one group + const int local_tid = tid % 16; // Thread ID within the group + + __shared__ float s_absmax[16][17]; // Use 17 instead of 16 to avoid bank conflicts + + // Local maximum value for each thread + float local_absmax = eps; + + // Ensure this block doesn't process out-of-bounds groups + if (block_group_id + local_group_id < num_groups) { + // Calculate input/output pointers for current group + const T* group_input = input + (block_group_id + local_group_id) * group_size; + FP8_TYPE* group_output = static_cast(output_q) + (block_group_id + local_group_id) * group_size; + float* scale_output = output_s + block_group_id + local_group_id; + + // Calculate local maximum absolute value + for (int i = local_tid; i < group_size; i += 16) { + float val = static_cast(group_input[i]); + float abs_val = fabsf(val); + local_absmax = fmaxf(local_absmax, abs_val); + } + + // Store in shared memory + s_absmax[local_group_id][local_tid] = local_absmax; + __syncthreads(); + + // Perform reduction within each group + if (local_tid < 8) { + WarpReduce(&s_absmax[local_group_id][0], local_tid); + } + __syncthreads(); + + // Get the maximum value for this group + const float group_absmax = s_absmax[local_group_id][0]; + const float y_s = group_absmax / fp8_max; + + // Only the first thread in each group writes the scale + if (local_tid == 0) { + *scale_output = y_s; + } + + // Quantize the data + for (int i = local_tid; i < group_size; i += 16) { + float val = static_cast(group_input[i]); + float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max); + group_output[i] = FP8_TYPE(q_val); + } + } +} + +void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, + int64_t group_size, double eps, double fp8_min, double fp8_max) { + CHECK_INPUT(input); + CHECK_INPUT(output_q); + CHECK_INPUT(output_s); + + const int num_groups = input.numel() / group_size; + + CHECK_EQ(input.numel() % group_size, 0); + + // Each block processes 16 groups, adjust grid size accordingly + dim3 grid((num_groups + 15) / 16); + dim3 block(256); // Keep 256 threads, each 16 threads handle one group + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { + per_token_group_quant_fp8_kernel<<>>( + static_cast(input.data_ptr()), output_q.data_ptr(), static_cast(output_s.data_ptr()), + group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max); + return true; + }); +} diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 266935ad9..6876138e4 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -143,3 +143,7 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, int64_t depth, int64_t draft_token_num); + +// sgl_per_token_group_quant_fp8 +void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, + double eps, double fp8_min, double fp8_max); diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 55f83bddd..e3cd7a96d 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -579,3 +579,17 @@ def build_tree_kernel( depth, draft_token_num, ) + + +def sgl_per_token_group_quant_fp8( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + eps: float, + fp8_min: float, + fp8_max: float, +) -> None: + torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8( + input, output_q, output_s, group_size, eps, fp8_min, fp8_max + ) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index e964b3268..677191778 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -153,6 +153,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, " "int topk, int depth, int draft_token_num) -> ()"); m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel); + + // per_token_group_quant_fp8 + m.def( + "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, float fp8_min, float fp8_max) -> ()"); + m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); } REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/test_per_token_group_quant_fp8.py b/sgl-kernel/tests/test_per_token_group_quant_fp8.py new file mode 100644 index 000000000..ddc11b86b --- /dev/null +++ b/sgl-kernel/tests/test_per_token_group_quant_fp8.py @@ -0,0 +1,173 @@ +import itertools +from typing import Any, Dict, List, Optional, Tuple + +import pytest +import torch +import triton +import triton.language as tl +from sgl_kernel import sgl_per_token_group_quant_fp8 + +from sglang.srt.utils import get_device_core_count, get_device_name, is_hip + +is_hip_ = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Collums of input + N, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def triton_per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +def sglang_per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, +): + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) + + return x_q, x_s + + +@pytest.mark.parametrize( + "batch_size, seq_len, group_size", + list( + itertools.product( + [1, 2, 4, 8, 16], # batch_size + [64, 128, 256, 512, 1024, 2048], # seq_len + [64, 128, 256], # group_size + ) + ), +) +def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size): + x = torch.randn( + (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16 + ) + + x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size) + + assert torch.allclose( + x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 + ) + assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__])