Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
+  if (col >= qweight_cols || row >= qweight_rows) return;
 
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
 
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
 
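Taken together, the two hunks above change the launch geometry from floor division to ceiling division and add a matching in-kernel bounds check, so weight shapes that are not multiples of the 16x16 thread block no longer lose their tail rows and columns. A minimal standalone sketch of that pattern (the toy kernel, names, and shapes below are illustrative assumptions, not the repo's code):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel illustrating the guard added above: every thread first checks
// that it maps to a valid (row, col) before touching memory.
__global__ void toy_fill(float* out, int cols, int rows) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (col >= cols || row >= rows) return;  // masks the partial-block overhang
  out[row * cols + col] = 1.0f;
}

int main() {
  const int cols = 72, rows = 1536;  // 72 is not a multiple of 16
  float* d_out;
  cudaMalloc(&d_out, sizeof(float) * cols * rows);
  cudaMemset(d_out, 0, sizeof(float) * cols * rows);

  dim3 threads(16, 16);
  // Floor division (72 / 16 = 4) would cover only columns 0..63; rounding up
  // launches a fifth, partially filled block and the guard does the rest.
  dim3 blocks((cols + threads.x - 1) / threads.x,
              (rows + threads.y - 1) / threads.y);
  toy_fill<<<blocks, threads>>>(d_out, cols, rows);
  cudaDeviceSynchronize();

  float last = 0.0f;
  cudaMemcpy(&last, d_out + (rows * cols - 1), sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("last element = %.1f\n", last);  // 1.0 only with ceil-div sizing
  cudaFree(d_out);
  return 0;
}
```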
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   if (scales.scalar_type() == at::ScalarType::Half) {
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
-    dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
-    dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
 
   return output;
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            [3584, 18944, 128, 256, 512, 1024, 1536],
+            [448, 576, 4736, 16, 32, 64, 128, 72],
             [True, False],
         )
     ),
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
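A note on the new test shapes: every dimension in the original parametrization (3584, 18944, 128, 256, 512, 1024 rows; 448, 576, 4736, 16, 32, 64, 128 columns) is a multiple of the 16-thread block dimension, so the truncation bug was never exercised. The added column value 72 is not: the old floor-division sizing launched 72 / 16 = 4 blocks and left the last 8 columns undequantized, while (72 + 15) / 16 = 5 blocks plus the new bounds check cover the full width.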