Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)
sgl-kernel/benchmark/bench_awq_dequant.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import itertools

import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)


def calculate_diff(qweight_row: int, qweight_col: int):
    """Calculate difference between VLLM and SGLang implementations."""
    device = torch.device("cuda")
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ):
        print("✅ All implementations match")
    else:
        print(f"❌ Implementations differ (mean abs diff: {output_diff:.6f})")


qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

configs = list(itertools.product(qweight_row_range, qweight_cols_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["qweight_row", "qweight_col"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="awq-dequantize-performance",
        args={},
    )
)
def benchmark(qweight_row, qweight_col, provider):
    device = torch.device("cuda")
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        fn = lambda: vllm_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )
    elif provider == "sglang":
        fn = lambda: sglang_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # do_bench reports milliseconds; convert to microseconds for the plot.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(qweight_row=3584, qweight_col=448)
    benchmark.run(print_data=True)
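The shapes the benchmark constructs follow the AWQ int4 packing that the kernel below assumes: every int32 in qweight and qzeros packs eight 4-bit values along the output dimension, and scales/qzeros carry one row per quantization group (the benchmark collapses this to a single group by setting group_size = qweight_row). A minimal sketch of the shape relationships, using a group size of 128 as an illustrative assumption rather than anything taken from the commit:

```python
qweight_rows, qweight_cols, group_size = 3584, 448, 128  # illustrative AWQ shapes

PACKED_PER_INT32 = 8  # eight 4-bit values per int32
out_cols = qweight_cols * PACKED_PER_INT32                  # fp16 output width
scales_shape = (qweight_rows // group_size, out_cols)       # one fp16 scale per output column per group
qzeros_shape = (qweight_rows // group_size, qweight_cols)   # zero points, still int4-packed

assert (out_cols, scales_shape, qzeros_shape) == (3584, (28, 3584), (28, 448))
```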
sgl-kernel/csrc/gemm/awq_kernel.cu (new file, 127 lines)
@@ -0,0 +1,127 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <torch/all.h>

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

  // This is the half2 {1024, 1024} represented as an integer.
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-64, -64} represented as an integer.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#else
  assert(false);
  return {};
#endif
}

__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,
    half* __restrict__ scales,
    int* __restrict__ qzeros,
    half* __restrict__ output,
    int group_size,
    int qweight_cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
  uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);

  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);

  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
  *(uint4*)output_ptr = weight_fp16;
}

torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
  int qweight_rows = qweight.size(0);
  int qweight_cols = qweight.size(1);
  int group_size = qweight_rows / scales.size(0);

  int x_num_threads = 16;
  int y_num_threads = 16;
  int x_blocks = qweight_cols / x_num_threads;
  int y_blocks = qweight_rows / y_num_threads;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

  auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

  auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
  auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
  auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
  auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_num_threads, y_num_threads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
      _qweight, _scales, _zeros, _output, group_size, qweight_cols);

  return output;
}
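The lop3/magic-number sequence in dequantize_s4_to_fp16x2 is the standard fp16 bit trick: 0x6400 is the bit pattern of fp16 1024.0, and at that exponent the mantissa's least-significant bit has weight 1.0, so OR-ing a 4-bit nibble into the low bits yields exactly 1024 + nibble; subtracting the magic constant then recovers the nibble as an fp16 value. A minimal PyTorch sketch of the trick for a single nibble (an illustration, not part of the commit):

```python
import torch


def int4_via_fp16_magic(v: int) -> float:
    # 0x6400 is fp16 1024.0; at that exponent the mantissa LSB has weight 1.0,
    # so OR-ing in a 4-bit value v yields the bit pattern of (1024 + v).
    bits = torch.tensor([0x6400 | (v & 0xF)], dtype=torch.int16)
    as_fp16 = bits.view(torch.float16)  # reinterpret the raw bits as fp16
    return float(as_fp16 - 1024.0)      # subtract the magic number -> v


assert all(int4_via_fp16_magic(v) == float(v) for v in range(16))
```

The kernel applies the same idea to eight nibbles at once: each lop3.b32 masks two nibbles into the two fp16 lanes of a register together with the 0x64006400 magic, and the sub.f16x2 / fma.rn.f16x2 instructions remove the bias (the fma path also folds the implicit right shift of the high nibbles into the 1/16 multiplier).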
@@ -75,6 +75,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From csrc/gemm
   */
  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

  m.def(
      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
      "bias) -> Tensor");
@@ -112,6 +112,7 @@ void apply_rope_pos_ids_cos_sin_cache(
/*
 * From csrc/gemm
 */
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
    silu_and_mul,
)
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cublas_grouped_gemm,
    fp8_blockwise_scaled_mm,
@@ -4,6 +4,12 @@ import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.int8_scaled_mm(
        mat_a,
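For context, a minimal usage sketch of the wrapper added above, assuming a CUDA device, an sgl-kernel build that contains this commit, and a group size of 128 (a typical AWQ setting, not something fixed by the commit):

```python
import torch
from sgl_kernel import awq_dequantize

rows, cols, group_size = 3584, 448, 128  # illustrative shapes
qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (rows, cols), dtype=torch.int32, device="cuda"
)
scales = torch.rand(rows // group_size, cols * 8, dtype=torch.float16, device="cuda")
qzeros = torch.randint(
    0, torch.iinfo(torch.int32).max, (rows // group_size, cols), dtype=torch.int32, device="cuda"
)

weights = awq_dequantize(qweight, scales, qzeros)
# Each packed int32 column expands into eight fp16 columns: (3584, 448) -> (3584, 3584).
assert weights.shape == (rows, cols * 8) and weights.dtype == torch.float16
```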
@@ -150,6 +150,7 @@ sources = [
    "csrc/elementwise/rope.cu",
    "csrc/gemm/bmm_fp8.cu",
    "csrc/gemm/cublas_grouped_gemm.cu",
    "csrc/gemm/awq_kernel.cu",
    "csrc/gemm/fp8_gemm_kernel.cu",
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "csrc/gemm/int8_gemm_kernel.cu",
sgl-kernel/tests/test_awq_dequant.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import itertools

import pytest
import torch
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)


@pytest.mark.parametrize(
    "qweight_row,qweight_col",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int,
    qweight_col: int,
):
    device = torch.device("cuda")

    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations
    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare results
    torch.testing.assert_close(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    # Run the tests in this file directly
    pytest.main([__file__])