diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index ef50957e2..a72a1a3d0 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -12,6 +12,39 @@ from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn +# Get correct FP8 E4M3 maximum value +if _is_hip: + FP8_E4M3_MAX = 224.0 # ROCM uses 224.0 +else: + # For CUDA, get the actual max value from the type + FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max) + + +def torch_per_token_quant_fp8( + input: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pure PyTorch reference implementation for per-token FP8 quantization.""" + device = input.device + dtype = input.dtype + + # Find max absolute value per token (row) - exactly like CUDA kernel + max_vals = torch.abs(input).max(dim=1)[0] # [num_tokens] + + # Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX + scales = max_vals / FP8_E4M3_MAX # [num_tokens] + + # No special zero handling - directly compute 1.0 / scale like CUDA kernel + scale_inv = 1.0 / scales # [num_tokens] + + # Quantize: input * scale_inv, then clamp to FP8 range + quantized_float = input * scale_inv.unsqueeze(1) # Broadcast scale_inv + quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX) + + # Convert to FP8 - use more explicit conversion + quantized_fp8 = quantized_float.to(fp8_type_) + + return quantized_fp8, scales + def vllm_per_token_quant_fp8( input: torch.Tensor, @@ -29,53 +62,100 @@ def sglang_per_token_quant_fp8( return output, scale -def calculate_diff(batch_size: int, seq_len: int): - """Calculate difference between VLLM and SGLang implementations.""" +def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int): + """Compare Torch reference, VLLM, and SGLang implementations.""" device = torch.device("cuda") - x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device) + x = torch.rand( + (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device + ) + # Get all three implementations + torch_out, torch_scale = torch_per_token_quant_fp8(x) vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) - scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() - output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item() + print(f"\n=== Comparison for hidden_dim={hidden_dim} ===") - if torch.allclose( - vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): - print("✅ All implementations match") - else: - print("❌ Implementations differ") + # Compare scales + torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item() + torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item() + vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() + + print(f"Scale differences:") + print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}") + print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}") + print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}") + + # Compare outputs + torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item() + torch_sglang_out_diff = ( + torch.abs(torch_out.float() - sglang_out.float()).mean().item() + ) + vllm_sglang_out_diff = ( + torch.abs(vllm_out.float() - sglang_out.float()).mean().item() + ) + + print(f"Output differences:") + print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}") + print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}") + print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}") + + # Check tolerances + rtol, atol = 1e-3, 1e-5 + + torch_vllm_match = torch.allclose( + torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol + ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol) + torch_sglang_match = torch.allclose( + torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol + ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol) + + if hidden_dim == 1368: + rtol = 1e-2 + # we found vllm sglang has diff when hidden dim is not dividable by 16 + # and we believe SGLang is closer to Torch implementation + + vllm_sglang_match = torch.allclose( + vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol + ) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol) + + print(f"Matches (rtol={rtol}, atol={atol}):") + print(f" Torch vs VLLM: {'✅' if torch_vllm_match else '❌'}") + print(f" Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}") + print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}") batch_size_range = [16, 32, 64, 128] seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096] +hidden_dim_range = [1368, 2048, 4096] -configs = list(itertools.product(batch_size_range, seq_len_range)) +configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range)) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], + x_names=["batch_size", "seq_len", "hidden_dim"], x_vals=configs, line_arg="provider", - line_vals=["vllm", "sglang"], - line_names=["VLLM", "SGL Kernel"], - styles=[("blue", "-"), ("green", "-")], + line_vals=["torch", "vllm", "sglang"], + line_names=["Torch Reference", "VLLM", "SGL Kernel"], + styles=[("red", "-"), ("blue", "-"), ("green", "-")], ylabel="us", plot_name="per-token-dynamic-quant-fp8-performance", args={}, ) ) -def benchmark_quantization(batch_size, seq_len, provider): +def benchmark_quantization(batch_size, seq_len, hidden_dim, provider): dtype = torch.float16 device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] - if provider == "vllm": + if provider == "torch": + fn = lambda: torch_per_token_quant_fp8(x.clone()) + elif provider == "vllm": fn = lambda: vllm_per_token_quant_fp8(x.clone()) elif provider == "sglang": fn = lambda: sglang_per_token_quant_fp8(x.clone()) @@ -86,5 +166,12 @@ def benchmark_quantization(batch_size, seq_len, provider): if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=4096) + # Test various hidden dimensions for correctness + test_dims = [1368, 2048, 4096] + + for dim in test_dims: + calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim) + + print("\n" + "=" * 60) + print("Starting performance benchmark...") benchmark_quantization.run(print_data=True) diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 7b58f838f..a3c60ad5b 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -75,14 +75,21 @@ __global__ void per_token_quant_fp8_kernel( c10::Float8_e4m3fnuz::from_bits()); #endif } - *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + if constexpr (kVecSize == 16) { + *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + } else { + // Use element-wise copy for vector size 8 to ensure correctness + for (int k = 0; k < kVecSize; ++k) { + token_output[i * kVecSize + k] = output_arr[k]; + } + } } } // --------------------------------------------------------------------------- // 2. Baseline kernel (1 token / CTA, CUB block reduce) // --------------------------------------------------------------------------- -template +template __global__ void per_token_quant_fp8_small_batch_kernel( const T* __restrict__ input, DST_DTYPE* __restrict__ output_q, @@ -100,19 +107,17 @@ __global__ void per_token_quant_fp8_small_batch_kernel( float max_value = 0.0f; - // We want to store 128 bits of data at a time. 16 = 128 / 8 bits - // Load is already vectorized, so 16 elements work for T. - const uint32_t VEC_SIZE = 16; - using vec_t = flashinfer::vec_t; - const int32_t num_vec_elems = hidden_dim / VEC_SIZE; + // Use template parameter for vector size + using vec_t = flashinfer::vec_t; + const int32_t num_vec_elems = hidden_dim / kVecSize; // Find max using vectorized loads for (int32_t i = tid; i < num_vec_elems; i += block_dim) { vec_t input_vec; - input_vec.cast_load(token_input + i * VEC_SIZE); + input_vec.cast_load(token_input + i * kVecSize); #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { + for (uint32_t j = 0; j < kVecSize; ++j) { float val = static_cast(input_vec[j]); max_value = fmaxf(max_value, fabsf(val)); } @@ -132,11 +137,11 @@ __global__ void per_token_quant_fp8_small_batch_kernel( // Quantize using vectorized loads for (int32_t i = tid; i < num_vec_elems; i += block_dim) { vec_t input_vec; - input_vec.cast_load(token_input + i * VEC_SIZE); + input_vec.cast_load(token_input + i * kVecSize); - DST_DTYPE output_arr[VEC_SIZE]; + DST_DTYPE output_arr[kVecSize]; #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { + for (uint32_t j = 0; j < kVecSize; ++j) { float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); #ifndef USE_ROCM output_arr[j] = static_cast(val); @@ -147,7 +152,14 @@ __global__ void per_token_quant_fp8_small_batch_kernel( #endif } - *(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr; + if constexpr (kVecSize == 16) { + *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + } else { + // Use element-wise copy for vector size 8 to ensure correctness + for (int k = 0; k < kVecSize; ++k) { + token_output[i * kVecSize + k] = output_arr[k]; + } + } } } @@ -158,13 +170,14 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: const auto input_sizes = input.sizes(); const int64_t num_tokens = input_sizes[0]; const int64_t hidden_dim = input_sizes[1]; - TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim); + TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Hard-code sm_count int sm_count = 132; constexpr int TOKENS_PER_CTA = 8; const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA); + const bool use_vec16 = (hidden_dim % 16 == 0); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { if (use_warp_kernel) { @@ -172,23 +185,43 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256 dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); dim3 block(THREADS); - per_token_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), - static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - hidden_dim, - num_tokens); + + if (use_vec16) { + per_token_quant_fp8_kernel<<>>( + static_cast(input.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } else { + per_token_quant_fp8_kernel<<>>( + static_cast(input.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } } else { // -------- baseline ----------------------------------------------------- constexpr int THREADS = 256; dim3 grid(num_tokens); dim3 block(THREADS); - per_token_quant_fp8_small_batch_kernel<<>>( - static_cast(input.data_ptr()), - static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - hidden_dim, - num_tokens); + + if (use_vec16) { + per_token_quant_fp8_small_batch_kernel<<>>( + static_cast(input.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } else { + per_token_quant_fp8_small_batch_kernel<<>>( + static_cast(input.data_ptr()), + static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } } return true; }); diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py index 80efd06e7..40ec9d897 100644 --- a/sgl-kernel/tests/test_per_token_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_quant_fp8.py @@ -36,7 +36,7 @@ def sglang_per_token_quant_fp8( @pytest.mark.parametrize( "num_tokens,hidden_dim", - list(itertools.product([128, 256, 512], [512, 2048, 4096])), + list(itertools.product([128, 256, 512], [512, 1368, 2048, 4096])), ) def test_per_token_quant_compare_implementations( num_tokens: int,