From 2c8fd99363e65343667c8316ce8a870b1bfd2bf7 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 3 Apr 2025 09:29:59 +0800 Subject: [PATCH] [sgl-kernel] per token group quant support COLUMN MAJOR (#4817) --- .../bench_per_token_group_quant_8bit.py | 10 +- .../csrc/gemm/per_token_group_quant_8bit.cu | 68 +++-- .../tests/test_per_token_group_quant_8bit.py | 254 ++++++++++++++---- 3 files changed, 252 insertions(+), 80 deletions(-) diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index b03369c1d..5cee72ebb 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit( def calculate_diff(batch_size, seq_len, group_size, dst_dtype): device = torch.device("cuda") - hidden_dim = group_size * 2 + hidden_dim = 7168 - x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) + x = torch.randn( + batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16 + ) x_q_triton, x_s_triton = triton_per_token_group_quant_8bit( x.clone(), group_size, dst_dtype @@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider): device = torch.device("cuda") hidden_dim = 7168 - x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) + x = torch.randn( + batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16 + ) quantiles = [0.5, 0.2, 0.8] diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index 57a5ab8ad..b374fd3e2 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -16,7 +16,7 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { return val; } -template +template __global__ void per_token_group_quant_8bit_kernel( const T* __restrict__ input, void* __restrict__ output_q, @@ -26,19 +26,30 @@ __global__ void per_token_group_quant_8bit_kernel( const int groups_per_block, const float eps, const float min_8bit, - const float max_8bit) { + const float max_8bit, + const int scale_num_rows = 0, + const int scale_stride = 0) { const int threads_per_group = 16; const int local_group_id = threadIdx.x / threads_per_group; const int lane_id = threadIdx.x % threads_per_group; const int block_group_id = blockIdx.x * groups_per_block; - const int block_group_offset = (block_group_id + local_group_id) * group_size; + const int global_group_id = block_group_id + local_group_id; + const int block_group_offset = global_group_id * group_size; float local_absmax = eps; const T* group_input = input + block_group_offset; DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; - float* scale_output = output_s + (block_group_id + local_group_id); + float* scale_output; + + if constexpr (IS_COLUMN_MAJOR) { + const int row_idx = global_group_id / scale_num_rows; + const int col_idx = global_group_id % scale_num_rows; + scale_output = output_s + (col_idx * scale_stride + row_idx); + } else { + scale_output = output_s + global_group_id; + } constexpr uint32_t vec_size = 16 / sizeof(T); using vec_t = flashinfer::vec_t; @@ -88,11 +99,11 @@ void sgl_per_token_group_quant_8bit( double max_8bit) { CHECK_INPUT(input); CHECK_INPUT(output_q); - CHECK_INPUT(output_s); const int num_groups = input.numel() / group_size; CHECK_EQ(input.numel() % group_size, 0); + CHECK_EQ(output_s.dim(), 2); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -114,20 +125,39 @@ void sgl_per_token_group_quant_8bit( const int num_blocks = num_groups / groups_per_block; const int num_threads = groups_per_block * THREADS_PER_GROUP; -#define LAUNCH_KERNEL(T, DST_DTYPE) \ - do { \ - dim3 grid(num_blocks); \ - dim3 block(num_threads); \ - per_token_group_quant_8bit_kernel<<>>( \ - static_cast(input.data_ptr()), \ - output_q.data_ptr(), \ - static_cast(output_s.data_ptr()), \ - group_size, \ - num_groups, \ - groups_per_block, \ - (float)eps, \ - (float)min_8bit, \ - (float)max_8bit); \ + const bool is_column_major = output_s.stride(0) < output_s.stride(1); + const int scale_num_rows = output_s.size(1); + const int scale_stride = output_s.stride(1); + +#define LAUNCH_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + if (is_column_major) { \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit, \ + scale_num_rows, \ + scale_stride); \ + } else { \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit); \ + } \ } while (0) DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py index b628a6a42..083ca1cad 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py +++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py @@ -9,12 +9,12 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_ from sglang.srt.utils import is_hip -is_hip_ = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn @triton.jit -def _per_token_group_quant_8bit( +def _per_token_group_quant_fp8( # Pointers to inputs and output y_ptr, y_q_ptr, @@ -25,15 +25,16 @@ def _per_token_group_quant_8bit( N, # Avoid to divide zero eps, - # Information for 8bit data type (int8 or fp8_type_) - max_8bit, - min_8bit, + # Information for float8 + fp8_min, + fp8_max, # Meta-parameters BLOCK: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. - This function converts the tensor values into 8bit values. + + This function converts the tensor values into float8 values. """ # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) @@ -47,8 +48,57 @@ def _per_token_group_quant_8bit( y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / max_8bit - y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty) + y_s = _absmax / fp8_max + y_s_inv = 1.0 / y_s + y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) @@ -57,17 +107,22 @@ def _per_token_group_quant_8bit( def triton_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, - dst_dtype: torch.dtype, eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. + Args: x: The input tenosr with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. - dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + dtype: The dype of output tensor. + Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ @@ -76,41 +131,79 @@ def triton_per_token_group_quant_8bit( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - if dst_dtype == torch.int8: - iinfo = torch.iinfo(dst_dtype) - max_8bit = iinfo.max - min_8bit = iinfo.min + if dtype == torch.int8: + finfo = torch.iinfo(dtype) else: - finfo = torch.finfo(dst_dtype) - max_8bit = finfo.max - min_8bit = finfo.min + finfo = torch.finfo(dtype) - x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) + fp8_max = finfo.max + + if _is_hip: + if dtype == torch.int8: + fp8_max = 127.0 + else: + fp8_max = 224.0 + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) M = x.numel() // group_size N = group_size - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = torch.empty( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = torch.empty( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_8bit[(M,)]( - x, - x_q, - x_s, - group_size, - N, - eps, - max_8bit, - min_8bit, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) return x_q, x_s @@ -118,28 +211,48 @@ def triton_per_token_group_quant_8bit( def sglang_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, - dst_dtype: torch.dtype, eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, ): assert ( x.shape[-1] % group_size == 0 ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = torch.empty( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = torch.empty( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) - if dst_dtype == torch.int8: - iinfo = torch.iinfo(dst_dtype) + if dtype == torch.int8: + iinfo = torch.iinfo(dtype) int8_max = iinfo.max int8_min = iinfo.min sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) else: - f8_info = torch.finfo(dst_dtype) + f8_info = torch.finfo(dtype) fp8_max = f8_info.max fp8_min = f8_info.min sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) @@ -148,30 +261,55 @@ def sglang_per_token_group_quant_8bit( @pytest.mark.parametrize( - "batch_size, seq_len, group_size, dst_dtype", + "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned", list( itertools.product( - [1, 2, 4, 8, 16, 32, 64, 128], # batch_size - [64, 128, 256, 512, 1024, 2048], # seq_len - [16, 32, 64, 128, 256], # group_size + [127, 128, 512, 1024, 4096, 8192], # num_tokens + [256, 512, 1024, 2048, 4096], # hidden_dim + [8, 16, 32, 64, 128], # group_size [torch.int8, fp8_type_], # dtype + [False, True], # column_major_scales + [False, True], # scale_tma_aligned ) ), ) -def test_per_token_group_quant_compare_implementations( - batch_size, seq_len, group_size, dst_dtype +def test_per_token_group_quant_with_column_major( + num_tokens, + hidden_dim, + group_size, + dst_dtype, + column_major_scales, + scale_tma_aligned, ): - x = torch.randn( - (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16 + if not column_major_scales and scale_tma_aligned: + return + + x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16) + + x_q_triton, x_s_triton = triton_per_token_group_quant_8bit( + x, + group_size, + eps=1e-10, + dtype=dst_dtype, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, ) - x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit( + x, + group_size, + eps=1e-10, + dtype=dst_dtype, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + ) assert torch.allclose( x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 ) - assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5) + assert torch.allclose( + x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5 + ) if __name__ == "__main__":