Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)

This commit is contained in:
AniZpZ
2025-06-18 04:45:10 +08:00
committed by GitHub
parent e726131523
commit 3eb4a800e8
3 changed files with 18 additions and 11 deletions

View File

@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
int* __restrict__ qzeros,
OutputT* __restrict__ output,
int group_size,
int qweight_cols) {
int qweight_cols,
int qweight_rows) {
#if CUDA_VERSION >= 12000
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y * blockDim.y + threadIdx.y;
if (col >= qweight_cols || row >= qweight_rows) return;
int group_idx = row / group_size;
int scale_offset = 8 * col + group_idx * qweight_cols * 8;
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
int x_num_threads = 16;
int y_num_threads = 16;
int x_blocks = qweight_cols / x_num_threads;
int y_blocks = qweight_rows / y_num_threads;
int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
if (scales.scalar_type() == at::ScalarType::Half) {
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
dequantize_weights<half>
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
} else {
auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
dequantize_weights<__nv_bfloat16>
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
}
return output;