From 4068e012924b5c59fc3eadde99279b604846c5bf Mon Sep 17 00:00:00 2001 From: Qingquan Song Date: Wed, 12 Mar 2025 21:19:05 -0700 Subject: [PATCH] Fix per token fp8 quant precision (#4362) --- sgl-kernel/benchmark/bench_per_token_quant_fp8.py | 6 +----- sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 4 +--- sgl-kernel/tests/test_per_token_quant_fp8.py | 8 +++----- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index ed0bfc78b..8d4e68bd1 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -22,10 +22,9 @@ def vllm_per_token_quant_fp8( def sglang_per_token_quant_fp8( input: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32) + scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32) output = torch.empty_like(input, device=input.device, dtype=fp8_type_) sgl_per_token_quant_fp8(input, output, scale) - return output, scale @@ -37,9 +36,6 @@ def calculate_diff(batch_size: int, seq_len: int): vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) - scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() - output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item() - if torch.allclose( vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 9c3b67768..971fb305c 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -49,8 +49,6 @@ __global__ void per_token_quant_fp8_kernel( } __syncthreads(); - const float scale_val = 1.0f / block_max; - // Quantize using vectorized loads for (int32_t i = tid; i < num_vec_elems; i += block_dim) { vec_t input_vec; @@ -59,7 +57,7 @@ __global__ void per_token_quant_fp8_kernel( FP8_TYPE output_arr[vec_size]; #pragma unroll for (uint32_t j = 0; j < vec_size; ++j) { - float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); + float val = fmaxf(fminf(static_cast(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM output_arr[j] = static_cast(val); #else diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py index 20b2722fc..2b0f63e7f 100644 --- a/sgl-kernel/tests/test_per_token_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_quant_fp8.py @@ -21,18 +21,16 @@ def vllm_per_token_quant_fp8( def sglang_per_token_quant_fp8( input: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32) + scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32) output = torch.empty_like(input, device=input.device, dtype=fp8_type_) sgl_per_token_quant_fp8(input, output, scale) - scale = scale.reshape(-1, 1) - return output, scale @pytest.mark.parametrize( "num_tokens,hidden_dim", - list(itertools.product([128, 256, 512], [512, 2048, 4096])), + list(itertools.product([32, 64, 128, 256, 512], [128, 256, 512, 2048, 4096])), ) def test_per_token_quant_compare_implementations( num_tokens: int, @@ -44,7 +42,7 @@ def test_per_token_quant_compare_implementations( vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) - torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5) torch.testing.assert_close( vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3 )