From ad55f1718262d30d1a8eb4ca16d3b12952bdc712 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Fri, 7 Mar 2025 10:05:43 +0800
Subject: [PATCH] [quant kernel] sgl-kernel support per_tensor_quant fp8 (#3786)

---
 .../benchmark/bench_per_tensor_quant_fp8.py   |  98 +++++++++++
 sgl-kernel/setup.py                           |   1 +
 sgl-kernel/src/sgl-kernel/__init__.py         |   1 +
 .../csrc/gemm/per_tensor_quant_fp8.cu         | 163 ++++++++++++++++++
 .../src/sgl-kernel/include/sgl_kernels_ops.h  |   1 +
 sgl-kernel/src/sgl-kernel/ops/gemm.py         |   9 +
 sgl-kernel/src/sgl-kernel/torch_extension.cc  |   3 +
 sgl-kernel/tests/test_per_tensor_quant_fp8.py |  67 +++++++
 8 files changed, 343 insertions(+)
 create mode 100644 sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
 create mode 100644 sgl-kernel/tests/test_per_tensor_quant_fp8.py

diff --git a/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
new file mode 100644
index 000000000..7a07efd93
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
@@ -0,0 +1,98 @@
+import itertools
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import triton
+import triton.testing
+from sgl_kernel import sgl_per_tensor_quant_fp8
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+
+def vllm_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return ops.scaled_fp8_quant(input, scale)
+
+
+def sglang_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Use the module-level fp8 dtype (e4m3fnuz on ROCm, e4m3fn otherwise).
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    is_static = True
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        is_static = False
+    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
+
+    return output, scale
+
+
+def calculate_diff(batch_size: int, seq_len: int):
+    """Calculate difference between VLLM and SGLang implementations."""
+    device = torch.device("cuda")
+    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
+
+    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
+    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
+
+    scale_diff = torch.abs(vllm_scale - sglang_scale).item()
+    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+
+    if torch.allclose(
+        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
+        print("✅ All implementations match")
+    else:
+        print(f"❌ Implementations differ (scale diff: {scale_diff}, mean output diff: {output_diff})")
+
+
+batch_size_range = [16, 32, 64, 128]
+seq_len_range = [64, 128, 256, 512, 1024, 2048]
+
+configs = list(itertools.product(batch_size_range, seq_len_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "seq_len"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=["vllm", "sglang"],
+        line_names=["VLLM", "SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-")],
+        ylabel="us",
+        plot_name="per-tensor-quant-fp8-performance",
+        args={},
+    )
+)
+def benchmark(batch_size, seq_len, provider):
+    dtype = torch.float16
+    device = torch.device("cuda")
+
+    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "vllm":
+        fn = lambda: vllm_scaled_fp8_quant(x.clone())
+    elif provider == "sglang":
+        fn = lambda: sglang_scaled_fp8_quant(x.clone())
+
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    calculate_diff(batch_size=4, seq_len=4096)
+    benchmark.run(print_data=True)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 472d3c725..4abf44019 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -106,6 +106,7 @@ sources = [
     "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
+    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
     "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
     "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index eef55cafc..417191a67 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -27,6 +27,7 @@ from sgl_kernel.ops.gemm import (
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
+    sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
 )
 from sgl_kernel.ops.moe import moe_align_block_size
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
new file mode 100644
index 000000000..d9cabd783
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
@@ -0,0 +1,163 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+#include <cmath>
+#include <cstdint>
+#include <flashinfer/vec_dtypes.cuh>
+
+#include "utils.h"
+
+#define WARP_SIZE 32
+
+#ifndef USE_ROCM
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#include <c10/util/Float8_e4m3fnuz.h>
+
+#include "amd/quant_utils.cuh"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
+__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+  float old;
+  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+  return old;
+}
+
+__device__ __forceinline__ float warpReduceMax(float max_value) {
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  return max_value;
+}
+
+template <typename T>
+__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
+                                         const int64_t num_elements) {
+  float max_value = 0.0f;
+  unsigned int tid = threadIdx.x;
+  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int grid_size = blockDim.x * gridDim.x;
+
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+
+  const int32_t num_vec_elems = num_elements / vec_size;
+
+  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
+    vec_t input_vec;
+    input_vec.cast_load(input + i * vec_size);
+
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      max_value = fmaxf(max_value, fabsf(val));
+    }
+  }
+
+  const int32_t remaining_start = num_vec_elems * vec_size;
+  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
+    float val = static_cast<float>(input[idx]);
+    max_value = fmaxf(max_value, fabsf(val));
+  }
+
+  static __shared__ float warpLevelMaxs[WARP_SIZE];
+  const int laneId = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+
+  max_value = warpReduceMax(max_value);
+
+  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
+  __syncthreads();
+
+  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+
+  if (warpId == 0) max_value = warpReduceMax(max_value);
+
+  if (tid == 0) {
+    atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
+  }
+}
+
+template <typename T>
+__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output,
+                                            const float* __restrict__ scale, const int64_t num_elements) {
+  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int grid_size = blockDim.x * gridDim.x;
+  const float scale_val = 1.0f / (*scale);
+
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+
+  const int32_t num_vec_elems = num_elements / vec_size;
+
+  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
+    vec_t input_vec;
+    input_vec.cast_load(input + i * vec_size);
+
+    FP8_TYPE output_arr[vec_size];
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+#ifndef USE_ROCM
+      output_arr[j] = static_cast<FP8_TYPE>(val);
+#else
+      output_arr[j] = c10::Float8_e4m3fnuz(
+          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+          c10::Float8_e4m3fnuz::from_bits());
+#endif
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      output[i * vec_size + j] = output_arr[j];
+    }
+  }
+
+  const int32_t remaining_start = num_vec_elems * vec_size;
+  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
+    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
+#ifndef USE_ROCM
+    output[idx] = static_cast<FP8_TYPE>(val);
+#else
+    output[idx] = c10::Float8_e4m3fnuz(
+        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+        c10::Float8_e4m3fnuz::from_bits());
+#endif
+  }
+}
+
+void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(output_q);
+  CHECK_INPUT(output_s);
+
+  const int block_size = 256;
+  const int num_elements = input.numel();
+  const int num_blocks = min((num_elements + block_size - 1) / block_size, 1024);
+
+  dim3 grid(num_blocks);
+  dim3 block(block_size);
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
+    if (is_static == false) {
+      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
+          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
+    }
+
+    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()), num_elements);
+    return true;
+  });
+}
diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
index fcc2c6139..5cd4417b1 100644
--- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
+++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -92,6 +92,7 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
                                       const torch::Dtype& out_dtype);
 void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
                                    double eps, double fp8_min, double fp8_max);
+void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
                          const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
                          int64_t cublas_handle, int64_t cuda_stream);
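
The two CUDA kernels above implement the usual per-tensor FP8 recipe: in dynamic mode the absmax kernel reduces max|x| over the whole tensor and atomically writes scale = absmax / FP8_E4M3_MAX, and the quant kernel then stores clamp(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX) as FP8. The snippet below is a minimal PyTorch-only sketch of that scheme, for illustration only; it is not part of the patch, the helper name is hypothetical, and it assumes the CUDA e4m3 dtype.

from typing import Optional, Tuple

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # the ROCm path above caps this at 224.0


def per_tensor_quant_fp8_reference(
    x: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch of what the kernels above compute."""
    if scale is None:
        # Dynamic mode: derive the scale from the tensor-wide absmax.
        scale = x.abs().amax().float() / FP8_MAX
    # Static and dynamic modes share the same scale-clamp-cast step.
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale.reshape(1)
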
diff --git a/sgl-kernel/src/sgl-kernel/ops/gemm.py b/sgl-kernel/src/sgl-kernel/ops/gemm.py
index 1084753c3..b21cb8550 100644
--- a/sgl-kernel/src/sgl-kernel/ops/gemm.py
+++ b/sgl-kernel/src/sgl-kernel/ops/gemm.py
@@ -91,6 +91,15 @@ def sgl_per_token_group_quant_fp8(
     )
 
 
+def sgl_per_tensor_quant_fp8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    is_static: bool,
+) -> None:
+    torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+
+
 def cublas_grouped_gemm(
     inputs: List[torch.Tensor],
     weights: List[torch.Tensor],
diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc
index a7578d393..61eba3d8c 100644
--- a/sgl-kernel/src/sgl-kernel/torch_extension.cc
+++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc
@@ -90,6 +90,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
 
+  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
+  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
+
   m.def(
       "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
diff --git a/sgl-kernel/tests/test_per_tensor_quant_fp8.py b/sgl-kernel/tests/test_per_tensor_quant_fp8.py
new file mode 100644
index 000000000..70b05af5d
--- /dev/null
+++ b/sgl-kernel/tests/test_per_tensor_quant_fp8.py
@@ -0,0 +1,67 @@
+import itertools
+from typing import Optional, Tuple
+
+import pytest
+import torch
+from sgl_kernel import sgl_per_tensor_quant_fp8
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+
+def vllm_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return ops.scaled_fp8_quant(input, scale)
+
+
+def sglang_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Use the module-level fp8 dtype (e4m3fnuz on ROCm, e4m3fn otherwise).
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    is_static = True
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        is_static = False
+    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
+
+    return output, scale
+
+
+@pytest.mark.parametrize(
+    "num_tokens,hidden_dim",
+    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
+)
+def test_per_tensor_quant_compare_implementations(
+    num_tokens: int,
+    hidden_dim: int,
+):
+    device = torch.device("cuda")
+    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
+
+    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
+    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
+
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(
+        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+    )
+
+    scale = torch.rand(1, dtype=torch.float32, device=device)
+    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
+    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
+
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(
+        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
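
For reference, a minimal way to exercise the new op directly on the CUDA path. This is illustrative only; the shapes are arbitrary, the zero-initialized scale plus is_static=False convention mirrors the wrappers above, and on ROCm the output dtype would be torch.float8_e4m3fnuz.

import torch
from sgl_kernel import sgl_per_tensor_quant_fp8

x = torch.randn(128, 4096, dtype=torch.float16, device="cuda")

# Dynamic quantization: the kernel fills `scale` from the tensor-wide absmax.
q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scale = torch.zeros(1, dtype=torch.float32, device="cuda")
sgl_per_tensor_quant_fp8(x, q, scale, False)

# Static quantization: reuse a precomputed scale and skip the absmax pass.
q_static = torch.empty_like(x, dtype=torch.float8_e4m3fn)
sgl_per_tensor_quant_fp8(x, q_static, scale, True)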