diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py
new file mode 100644
index 000000000..19055d2c5
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py
@@ -0,0 +1,93 @@
+import itertools
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.testing
+from sgl_kernel import sgl_per_token_quant_fp8
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+
+def vllm_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
+
+
+def sglang_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    sgl_per_token_quant_fp8(input, output, scale)
+
+    return output, scale
+
+
+def calculate_diff(batch_size: int, seq_len: int):
+    """Calculate difference between VLLM and SGLang implementations."""
+    device = torch.device("cuda")
+    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
+
+    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
+    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
+
+    scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
+    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+
+    print(f"Scale difference: {scale_diff}")
+    print(f"Output difference: {output_diff}")
+
+    if torch.allclose(
+        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
+        print("✅ All implementations match")
+    else:
+        print("❌ Implementations differ")
+
+
+batch_size_range = [16, 32, 64, 128]
+seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
+
+configs = list(itertools.product(batch_size_range, seq_len_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "seq_len"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=["vllm", "sglang"],
+        line_names=["VLLM", "SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-")],
+        ylabel="us",
+        plot_name="per-token-dynamic-quant-fp8-performance",
+        args={},
+    )
+)
+def benchmark_quantization(batch_size, seq_len, provider):
+    dtype = torch.float16
+    device = torch.device("cuda")
+
+    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "vllm":
+        fn = lambda: vllm_per_token_quant_fp8(x.clone())
+    elif provider == "sglang":
+        fn = lambda: sglang_per_token_quant_fp8(x.clone())
+
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    calculate_diff(batch_size=4, seq_len=4096)
+    benchmark_quantization.run(print_data=True)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 4abf44019..545ff1bfc 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -106,6 +106,7 @@ sources = [
     "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
+    "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
     "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
     "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 417191a67..ab7f673b0 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -29,6 +29,7 @@ from sgl_kernel.ops.gemm import (
     int8_scaled_mm,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
+    sgl_per_token_quant_fp8,
 )
 from sgl_kernel.ops.moe import moe_align_block_size
 from sgl_kernel.ops.sampling import (
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
new file mode 100644
index 000000000..be272e065
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -0,0 +1,143 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include <cmath>
+#include <cstdint>
+#include <flashinfer/vec_dtypes.cuh>
+
+#include "utils.h"
+
+#define WARP_SIZE 32
+
+#ifndef USE_ROCM
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#include <c10/util/Float8_e4m3fnuz.h>
+
+#include "amd/quant_utils.cuh"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
+__device__ __forceinline__ float warpReduceMax(float max_value) {
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  return max_value;
+}
+
+template <typename T>
+__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
+                                           float* __restrict__ output_s, const int64_t hidden_dim,
+                                           const int64_t num_tokens) {
+  const int token_idx = blockIdx.x;
+
+  if (token_idx >= num_tokens) return;
+
+  const int tid = threadIdx.x;
+  const int block_dim = blockDim.x;
+
+  const T* token_input = input + token_idx * hidden_dim;
+  FP8_TYPE* token_output = output_q + token_idx * hidden_dim;
+
+  float max_value = 0.0f;
+
+  for (int i = tid; i < hidden_dim; i += block_dim) {
+    float val = static_cast<float>(token_input[i]);
+    max_value = fmaxf(max_value, fabsf(val));
+  }
+
+  max_value = warpReduceMax(max_value);
+
+  static __shared__ float warpLevelMaxs[WARP_SIZE];
+  const int laneId = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+
+  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
+  __syncthreads();
+
+  if (warpId == 0) {
+    max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+    max_value = warpReduceMax(max_value);
+  }
+
+  __shared__ float block_max;
+  if (tid == 0) {
+    block_max = max_value / FP8_E4M3_MAX;
+    output_s[token_idx] = block_max;
+  }
+  __syncthreads();
+
+  const float scale_val = 1.0f / block_max;
+
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+
+  const int32_t num_vec_elems = hidden_dim / vec_size;
+
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * vec_size);
+
+    FP8_TYPE output_arr[vec_size];
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+#ifndef USE_ROCM
+      output_arr[j] = static_cast<FP8_TYPE>(val);
+#else
+      output_arr[j] = c10::Float8_e4m3fnuz(
+          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+          c10::Float8_e4m3fnuz::from_bits());
+#endif
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      token_output[i * vec_size + j] = output_arr[j];
+    }
+  }
+
+  const int32_t remaining_start = num_vec_elems * vec_size;
+  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
+    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
+#ifndef USE_ROCM
+    token_output[idx] = static_cast<FP8_TYPE>(val);
+#else
+    token_output[idx] = c10::Float8_e4m3fnuz(
+        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+        c10::Float8_e4m3fnuz::from_bits());
+#endif
+  }
+}
+
+void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(output_q);
+  CHECK_INPUT(output_s);
+
+  const auto input_sizes = input.sizes();
+  const int64_t num_tokens = input_sizes[0];
+  const int64_t hidden_dim = input_sizes[1];
+
+  const int block_size = 128;
+  const int num_blocks = num_tokens;
+
+  dim3 grid(num_blocks);
+  dim3 block(block_size);
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
+    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()), hidden_dim, num_tokens);
+    return true;
+  });
+}
diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
index 5cd4417b1..f5ebffb12 100644
--- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
+++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -160,3 +160,6 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r
 void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                 torch::Tensor new_kv);
+
+// sgl_per_token_quant_fp8
+void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
diff --git a/sgl-kernel/src/sgl-kernel/ops/gemm.py b/sgl-kernel/src/sgl-kernel/ops/gemm.py
index b21cb8550..883894e96 100644
--- a/sgl-kernel/src/sgl-kernel/ops/gemm.py
+++ b/sgl-kernel/src/sgl-kernel/ops/gemm.py
@@ -118,3 +118,11 @@ def cublas_grouped_gemm(
         cublas_handle,
         get_cuda_stream(),
     )
+
+
+def sgl_per_token_quant_fp8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+) -> None:
+    torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s)
diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc
index 61eba3d8c..a8ee87707 100644
--- a/sgl-kernel/src/sgl-kernel/torch_extension.cc
+++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc
@@ -171,6 +171,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
       "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
+
+  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
+  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
 }
 
 REGISTER_EXTENSION(_kernels)
diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py
new file mode 100644
index 000000000..20b2722fc
--- /dev/null
+++ b/sgl-kernel/tests/test_per_token_quant_fp8.py
@@ -0,0 +1,55 @@
+import itertools
+from typing import Optional, Tuple
+
+import pytest
+import torch
+from sgl_kernel import sgl_per_token_quant_fp8
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+
+def vllm_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
+
+
+def sglang_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+
+    sgl_per_token_quant_fp8(input, output, scale)
+    scale = scale.reshape(-1, 1)
+
+    return output, scale
+
+
+@pytest.mark.parametrize(
+    "num_tokens,hidden_dim",
+    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
+)
+def test_per_token_quant_compare_implementations(
+    num_tokens: int,
+    hidden_dim: int,
+):
+    device = torch.device("cuda")
+    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
+
+    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
+    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
+
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(
+        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+    )
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
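
Usage sketch (not part of the patch): a minimal example of calling the new op from Python, mirroring the wrapper used in the benchmark and test above; the tensor shapes and the non-ROCm fp8 dtype are illustrative assumptions.

    import torch
    from sgl_kernel import sgl_per_token_quant_fp8

    # Per-token dynamic quantization: one float32 scale per row, fp8 values per element.
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    scale = torch.zeros(x.size(0), device=x.device, dtype=torch.float32)  # filled in place by the kernel
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)  # use torch.float8_e4m3fnuz on ROCm
    sgl_per_token_quant_fp8(x, out, scale)
    # out.float() * scale.view(-1, 1) approximately reconstructs x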