[1/3] fix dsv3 awq issue (#4556)

Co-authored-by: leoneo <1320612015@qq.com>
This commit is contained in:
AniZpZ
2025-03-22 16:07:17 +08:00
committed by GitHub
parent 38f25e87fc
commit 321ab756bc
2 changed files with 179 additions and 27 deletions

View File

@@ -3,6 +3,16 @@
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -68,32 +78,102 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#endif
}
__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = source;
// Define masks and constants
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX);
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX);
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX);
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX);
nv_bfloat162* res = reinterpret_cast<nv_bfloat162*>(h);
res[0] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&lo0),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[1] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&hi0),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[2] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&lo1),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[3] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&hi1),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return result;
#else
assert(false);
return {};
#endif
}
template <typename OutputT>
__global__ void __launch_bounds__(256) dequantize_weights(
int* __restrict__ qweight,
half* __restrict__ scales,
OutputT* __restrict__ scales,
int* __restrict__ qzeros,
half* __restrict__ output,
OutputT* __restrict__ output,
int group_size,
int qweight_cols) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y * blockDim.y + threadIdx.y;
uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);
int group_idx = row / group_size;
int scale_offset = 8 * col + group_idx * qweight_cols * 8;
uint4 loaded_scale = *(uint4*)(scales + scale_offset);
uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
// Handle different data types
if constexpr (std::is_same<OutputT, half>::value) {
// FP16 path
uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + group_idx * qweight_cols]);
uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
// Use PTX assembly for FP16 operations
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
*(uint4*)output_ptr = weight_fp16;
OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols;
*(uint4*)output_ptr = weight_fp16;
} else if constexpr (std::is_same<OutputT, __nv_bfloat16>::value) {
uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]);
uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]);
uint4 scale_raw = *reinterpret_cast<uint4*>(scales + scale_offset);
// Vectorized processing (each uint4 contains 4 nv_bfloat162)
nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);
nv_bfloat162* zero_vec = reinterpret_cast<nv_bfloat162*>(&zero_raw);
nv_bfloat162* scale_vec = reinterpret_cast<nv_bfloat162*>(&scale_raw);
// Single instruction dual-channel operation
#pragma unroll
for (int i = 0; i < 4; ++i) { // uint4 = 4 * nv_bfloat162
weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);
}
// Directly store to OutputT array (guaranteed contiguous memory)
OutputT* output_ptr = output + 8 * col + row * qweight_cols * 8;
static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");
*reinterpret_cast<uint4*>(output_ptr) = weight_raw;
}
}
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
@@ -112,16 +192,23 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_num_threads, y_num_threads);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
_qweight, _scales, _zeros, _output, group_size, qweight_cols);
if (scales.scalar_type() == at::ScalarType::Half) {
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
dequantize_weights<half>
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
} else {
auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
dequantize_weights<__nv_bfloat16>
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
}
return output;
}