Add sgl_per_token_quant_fp8 (#4089)
This commit is contained in:
93
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
93
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def vllm_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between VLLM and SGLang implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
|
||||
|
||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||
|
||||
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
|
||||
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||
|
||||
print(f"Scale difference: {scale_diff}")
|
||||
print(f"Output difference: {output_diff}")
|
||||
|
||||
if torch.allclose(
|
||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
fn = lambda: vllm_per_token_quant_fp8(x.clone())
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
||||
@@ -106,6 +106,7 @@ sources = [
|
||||
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
|
||||
|
||||
@@ -29,6 +29,7 @@ from sgl_kernel.ops.gemm import (
|
||||
int8_scaled_mm,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.ops.moe import moe_align_block_size
|
||||
from sgl_kernel.ops.sampling import (
|
||||
|
||||
143
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
143
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
@@ -0,0 +1,143 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
|
||||
#include "amd/quant_utils.cuh"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
#endif
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
||||
return max_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s, const int64_t hidden_dim,
|
||||
const int64_t num_tokens) {
|
||||
const int token_idx = blockIdx.x;
|
||||
|
||||
if (token_idx >= num_tokens) return;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int block_dim = blockDim.x;
|
||||
|
||||
const T* token_input = input + token_idx * hidden_dim;
|
||||
FP8_TYPE* token_output = output_q + token_idx * hidden_dim;
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
for (int i = tid; i < hidden_dim; i += block_dim) {
|
||||
float val = static_cast<float>(token_input[i]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
|
||||
max_value = warpReduceMax(max_value);
|
||||
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
||||
__syncthreads();
|
||||
|
||||
if (warpId == 0) {
|
||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
max_value = warpReduceMax(max_value);
|
||||
}
|
||||
|
||||
__shared__ float block_max;
|
||||
if (tid == 0) {
|
||||
block_max = max_value / FP8_E4M3_MAX;
|
||||
output_s[token_idx] = block_max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_val = 1.0f / block_max;
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
|
||||
const int32_t num_vec_elems = hidden_dim / vec_size;
|
||||
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(token_input + i * vec_size);
|
||||
|
||||
FP8_TYPE output_arr[vec_size];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
#ifndef USE_ROCM
|
||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
output_arr[j] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
token_output[i * vec_size + j] = output_arr[j];
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t remaining_start = num_vec_elems * vec_size;
|
||||
for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
|
||||
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
|
||||
#ifndef USE_ROCM
|
||||
token_output[idx] = static_cast<FP8_TYPE>(val);
|
||||
#else
|
||||
token_output[idx] = c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
CHECK_INPUT(output_s);
|
||||
|
||||
const auto input_sizes = input.sizes();
|
||||
const int64_t num_tokens = input_sizes[0];
|
||||
const int64_t hidden_dim = input_sizes[1];
|
||||
|
||||
const int block_size = 128;
|
||||
const int num_blocks = num_tokens;
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(block_size);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()), hidden_dim, num_tokens);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
@@ -160,3 +160,6 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r
|
||||
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
// sgl_per_token_quant_fp8
|
||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
|
||||
@@ -118,3 +118,11 @@ def cublas_grouped_gemm(
|
||||
cublas_handle,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s)
|
||||
|
||||
@@ -171,6 +171,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
|
||||
55
sgl-kernel/tests/test_per_token_quant_fp8.py
Normal file
55
sgl-kernel/tests/test_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def vllm_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
scale = scale.reshape(-1, 1)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,hidden_dim",
|
||||
list(itertools.product([128, 256, 512], [512, 2048, 4096])),
|
||||
)
|
||||
def test_per_token_quant_compare_implementations(
|
||||
num_tokens: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
||||
|
||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||
|
||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the specific test function directly
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user