[1/3] fix dsv3 awq issue (#4556)
Co-authored-by: leoneo <1320612015@qq.com>
This commit is contained in:
@@ -3,6 +3,16 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Thin wrapper around the PTX `lop3.b32` instruction.
//
// `lut` is passed to the instruction as its immediate operand, which is why it
// is a template parameter: it has to be a compile-time constant. The three
// 32-bit inputs are forwarded unchanged and the 32-bit result is returned.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int combined;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(combined)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return combined;
}
|
||||||
|
|
||||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||||
@@ -68,32 +78,102 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expands the eight 4-bit values packed in `source` into eight bfloat16
// values, returned as four packed bfloat16 pairs inside a uint4.
//
// Output pair i (i = 0..3) holds nibble i in its low 16 bits and nibble i + 4
// in its high 16 bits, because the mask 0x000f000f selects bits 0-3 and
// 16-19 of the word after it has been shifted right by 4 * i.
//
// Requires SM80+. On any other compilation target the function asserts and
// returns a value-initialized uint4.
__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Bit patterns, each replicated in both 16-bit halves of the word:
  //   MASK keeps one nibble per half,
  //   EX   is bfloat16 128.0 (OR-ing a nibble v into it yields 128 + v),
  //   MUL  is bfloat16 1.0,
  //   ADD  is bfloat16 -128.0 (removes the bias again).
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;
  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC300C300;

  // lop3 truth table for (a & b) | c.
  constexpr int kAndThenOr = (0xf0 & 0xcc) | 0xaa;

  const nv_bfloat162 unit = *reinterpret_cast<const nv_bfloat162*>(&MUL);
  const nv_bfloat162 unbias = *reinterpret_cast<const nv_bfloat162*>(&ADD);

  const uint32_t packed = source;
  uint4 unpacked;
  nv_bfloat162* pairs = reinterpret_cast<nv_bfloat162*>(&unpacked);

#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Splice nibbles i and i + 4 into the mantissa of bfloat16 128.0 ...
    int biased = lop3<kAndThenOr>(packed >> (4 * i), MASK, EX);
    // ... then compute (128 + v) * 1 - 128 = v for both halves at once.
    pairs[i] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&biased), unit, unbias);
  }

  return unpacked;
#else
  assert(false);
  return {};
#endif
}
|
||||||
|
|
||||||
|
// Dequantizes AWQ-packed 4-bit weights: out = (weight - zero) * scale.
//
// Launch layout: 2D grid; the x dimension indexes packed int32 columns
// (0 .. qweight_cols - 1) and the y dimension indexes rows. Each thread
// unpacks one int32 of `qweight` (eight 4-bit values) and writes eight
// consecutive OutputT elements with a single 16-byte store.
//
// Template parameter:
//   OutputT       half or __nv_bfloat16 (anything else is rejected at compile
//                 time). The bfloat16 path relies on dequantize_s4_to_bf16x2,
//                 which is only implemented for SM80+.
// Parameters:
//   qweight       packed weights, indexed as [row][col] with qweight_cols
//                 int32 per row.
//   scales        per-group scales, indexed as [row / group_size][8 * col].
//   qzeros        packed zero points, indexed as [row / group_size][col].
//   output        destination, indexed as [row][8 * col].
//   group_size    number of consecutive rows sharing one scale / zero row.
//   qweight_cols  number of packed int32 columns per row.
//
// Preconditions: `scales` and `output` are accessed as uint4, so both base
// pointers must be 16-byte aligned. The row count is not passed in, so the
// grid's y extent must cover exactly the rows of `qweight`.
template <typename OutputT>
__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,
    OutputT* __restrict__ scales,
    int* __restrict__ qzeros,
    OutputT* __restrict__ output,
    int group_size,
    int qweight_cols) {
  constexpr bool kIsHalf = std::is_same<OutputT, half>::value;
  constexpr bool kIsBf16 = std::is_same<OutputT, __nv_bfloat16>::value;
  // Without this an unsupported OutputT would compile to a kernel that
  // silently leaves `output` untouched.
  static_assert(kIsHalf || kIsBf16, "dequantize_weights supports only half and __nv_bfloat16");
  static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");

  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Threads past the last packed column would alias elements of the next row.
  if (col >= qweight_cols) {
    return;
  }

  // 64-bit offsets: 8 * row * qweight_cols overflows int for large matrices.
  const int64_t packed_cols = qweight_cols;
  const int64_t group_idx = row / group_size;
  const int64_t weight_idx = col + row * packed_cols;
  const int64_t zero_idx = col + group_idx * packed_cols;

  // One 16-byte load fetches the eight scales of this thread; both data-type
  // paths share it.
  const uint4 scale_raw = *reinterpret_cast<const uint4*>(scales + 8 * zero_idx);

  uint4 weight_raw;
  if constexpr (kIsHalf) {
    // FP16 path: packed half2 arithmetic through PTX.
    const uint4 zero_raw = dequantize_s4_to_fp16x2(qzeros[zero_idx]);
    weight_raw = dequantize_s4_to_fp16x2(qweight[weight_idx]);

    uint32_t* weight_vec = reinterpret_cast<uint32_t*>(&weight_raw);
    const uint32_t* zero_vec = reinterpret_cast<const uint32_t*>(&zero_raw);
    const uint32_t* scale_vec = reinterpret_cast<const uint32_t*>(&scale_raw);

#pragma unroll
    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * half2
      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_vec[i]) : "r"(weight_vec[i]), "r"(zero_vec[i]));
      asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_vec[i]) : "r"(weight_vec[i]), "r"(scale_vec[i]));
    }
  } else {
    // BF16 path: packed nv_bfloat162 intrinsics.
    const uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[zero_idx]);
    weight_raw = dequantize_s4_to_bf16x2(qweight[weight_idx]);

    nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);
    const nv_bfloat162* zero_vec = reinterpret_cast<const nv_bfloat162*>(&zero_raw);
    const nv_bfloat162* scale_vec = reinterpret_cast<const nv_bfloat162*>(&scale_raw);

#pragma unroll
    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * nv_bfloat162
      weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);
    }
  }

  // Eight dequantized values leave in a single 16-byte store.
  *reinterpret_cast<uint4*>(output + 8 * weight_idx) = weight_raw;
}
|
||||||
|
|
||||||
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
|
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
|
||||||
@@ -112,16 +192,23 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
|
|||||||
at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
|
at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
|
||||||
|
|
||||||
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
|
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
|
||||||
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
|
|
||||||
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
|
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
|
||||||
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
|
|
||||||
|
|
||||||
dim3 num_blocks(x_blocks, y_blocks);
|
dim3 num_blocks(x_blocks, y_blocks);
|
||||||
dim3 threads_per_block(x_num_threads, y_num_threads);
|
dim3 threads_per_block(x_num_threads, y_num_threads);
|
||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
|
||||||
_qweight, _scales, _zeros, _output, group_size, qweight_cols);
|
if (scales.scalar_type() == at::ScalarType::Half) {
|
||||||
|
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
|
||||||
|
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
|
||||||
|
dequantize_weights<half>
|
||||||
|
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
|
||||||
|
} else {
|
||||||
|
auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
|
||||||
|
auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
|
||||||
|
dequantize_weights<__nv_bfloat16>
|
||||||
|
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
|
||||||
|
}
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,57 @@ from sgl_kernel import awq_dequantize
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_awq_order(t: torch.Tensor):
    """Reorder the columns of ``t`` within every group of eight and keep the low 4 bits.

    Within each consecutive group of 8 columns, output position ``j`` takes the
    input column at offset ``[0, 4, 1, 5, 2, 6, 3, 7][j]``. Every element is
    then masked with ``0xF``.

    Args:
        t: 2-D tensor whose last dimension is a multiple of 8.

    Returns:
        A tensor of the same shape with permuted columns and masked values.
    """
    nibbles_per_word = 32 // 4
    gather_order = [0, 4, 1, 5, 2, 6, 3, 7]

    # Build the flat column index: identity -> rows of 8 -> permute -> flatten.
    column_index = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    column_index = column_index.view(-1, nibbles_per_word)[:, gather_order].view(-1)

    return t[:, column_index] & 0xF
|
||||||
|
|
||||||
|
|
||||||
|
# Expected layouts (R rows, C output columns, G rows per group):
#   qweight - [R     , C // 8], int32
#   scales  - [R // G, C     ], floating point (the test uses float16 and bfloat16)
#   qzeros  - [R // G, C // 8], int32
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    """Pure-PyTorch reference for AWQ 4-bit dequantization.

    Each int32 in ``qweight`` / ``qzeros`` is split into eight 4-bit values,
    the values are reordered with ``reverse_awq_order``, and the result is
    ``(weight - zero) * scale`` with zeros and scales repeated ``group_size``
    times along dim 0.

    Args:
        qweight: packed weights.
        scales: per-group scales.
        qzeros: packed zero points.
        group_size: rows per quantization group; ``-1`` means one group that
            spans all rows of ``qweight``.

    Returns:
        The dequantized weight tensor.
    """
    bits = 4
    nibble_mask = (2**bits) - 1
    rows_per_group = qweight.shape[0] if group_size == -1 else group_size

    # One shift amount per nibble inside a 32-bit word: 0, 4, ..., 28.
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # [rows, cols] int32 -> [rows, cols * 8] reordered 4-bit values.
        nibbles = torch.bitwise_right_shift(
            packed[:, :, None], shifts[None, None, :]
        ).to(torch.int8)
        nibbles = nibbles.view(packed.shape[0], -1)
        return torch.bitwise_and(reverse_awq_order(nibbles), nibble_mask)

    zero_points = _unpack(qzeros)
    int_weights = _unpack(qweight)

    # Broadcast the per-group rows to one row per weight row.
    expanded_scales = scales.repeat_interleave(rows_per_group, dim=0)
    expanded_zeros = zero_points.repeat_interleave(rows_per_group, dim=0)

    return (int_weights - expanded_zeros) * expanded_scales
|
||||||
|
|
||||||
|
|
||||||
def vllm_awq_dequantize(
|
def vllm_awq_dequantize(
|
||||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -20,16 +71,17 @@ def sglang_awq_dequantize(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"qweight_row,qweight_col",
|
"qweight_row,qweight_col,is_bf16_act",
|
||||||
list(
|
list(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
[3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
|
[3584, 18944, 128, 256, 512, 1024],
|
||||||
|
[448, 576, 4736, 16, 32, 64, 128],
|
||||||
|
[True, False],
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_awq_dequant_compare_implementations(
|
def test_awq_dequant_compare_implementations(
|
||||||
qweight_row: int,
|
qweight_row: int, qweight_col: int, is_bf16_act: bool
|
||||||
qweight_col: int,
|
|
||||||
):
|
):
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
||||||
@@ -43,7 +95,12 @@ def test_awq_dequant_compare_implementations(
|
|||||||
group_size = qweight_row
|
group_size = qweight_row
|
||||||
scales_row = qweight_row // group_size
|
scales_row = qweight_row // group_size
|
||||||
scales_col = qweight_col * 8
|
scales_col = qweight_col * 8
|
||||||
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
|
||||||
|
if is_bf16_act:
|
||||||
|
scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
|
||||||
|
else:
|
||||||
|
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
||||||
|
|
||||||
qzeros = torch.randint(
|
qzeros = torch.randint(
|
||||||
0,
|
0,
|
||||||
torch.iinfo(torch.int32).max,
|
torch.iinfo(torch.int32).max,
|
||||||
@@ -53,13 +110,21 @@ def test_awq_dequant_compare_implementations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run both implementations
|
# Run both implementations
|
||||||
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
|
||||||
|
torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
|
||||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||||
)
|
)
|
||||||
|
if not is_bf16_act:
|
||||||
|
torch.testing.assert_close(
|
||||||
|
vllm_out.to(torch.float32),
|
||||||
|
sglang_out.to(torch.float32),
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user