Fix AWQ Dequant and Weight Loading of DeepSeek-V2 (#6842)
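This change fixes two problems with AWQ-quantized DeepSeek-V2 checkpoints. First, the weight loader always concatenated the cached `q_a_proj` and `kv_a_proj` weights along dim 0 when building `fused_qkv_a_proj_with_mqa`; AWQ and `moe_wna16` store packed weights with the output dimension last, so those formats need dim 1. Second, the AWQ dequantize kernel computed its launch grid with truncating division and skipped the trailing rows and columns of any matrix whose shape is not a multiple of the 16x16 thread block; the grid now rounds up and the kernel bounds-checks each thread. The dequant test matrix gains shapes that cover both cases.
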
@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
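
In fp16 checkpoints a linear weight is `(out_features, in_features)`, so the two low-rank projections fuse along dim 0. AWQ-style checkpoints store `qweight` as packed int32 with shape roughly `(in_features, out_features // pack_factor)`, which puts the output dimension on axis 1, hence `cat_dim = 1`. A minimal shape sketch (the sizes below are made up for illustration, not DeepSeek-V2's real dimensions):

    import torch

    # Hypothetical sizes, for shape intuition only.
    in_features, q_out, kv_out, pack_factor = 256, 64, 32, 8

    # fp16 layout: (out_features, in_features) -> fuse along dim 0.
    q_fp16 = torch.empty(q_out, in_features)
    kv_fp16 = torch.empty(kv_out, in_features)
    assert torch.cat([q_fp16, kv_fp16], dim=0).shape == (96, 256)

    # AWQ layout: int32-packed (in_features, out_features // pack_factor),
    # so the output dimension sits on axis 1 -> fuse along dim 1.
    q_awq = torch.empty(in_features, q_out // pack_factor, dtype=torch.int32)
    kv_awq = torch.empty(in_features, kv_out // pack_factor, dtype=torch.int32)
    assert torch.cat([q_awq, kv_awq], dim=1).shape == (256, 12)
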
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
+  if (col >= qweight_cols || row >= qweight_rows) return;
 
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
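
The early return pairs with the grid change in the next hunk: once the launcher rounds the block count up, the last block in each dimension can contain threads past the end of the matrix. With 16 threads per block and `qweight_cols = 72`, for example, the rounded-up grid launches 5 * 16 = 80 threads along x, and threads 72 through 79 must bail out here; without the guard they would read `scales`/`qzeros` and write `output` out of bounds.
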
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
 
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
 
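
`(n + t - 1) / t` is the standard ceil-division idiom: the old truncating division dropped the remainder, so any dimension that was not a multiple of 16 lost its tail and the corresponding weights were never dequantized. A quick sketch of the difference (plain Python mirroring the C++ expression):

    def ceil_div(a: int, b: int) -> int:
        # Round-up integer division; equals math.ceil(a / b) for positive ints.
        return (a + b - 1) // b

    # qweight_cols = 72 with 16 threads per block:
    assert 72 // 16 == 4          # old grid: 4 blocks, only 64 columns covered
    assert ceil_div(72, 16) == 5  # new grid: 5 blocks, all 72 columns covered
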
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   if (scales.scalar_type() == at::ScalarType::Half) {
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
-    dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
-    dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
 
   return output;
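
Both launches now forward `qweight_rows` so the kernel can apply its bounds check. For shape intuition: each int32 in `qweight` packs eight 4-bit values, so the dequantized output is `(qweight_rows, 8 * qweight_cols)`, which is also why `scale_offset` above scales by 8. A hedged usage sketch, assuming the op is exposed from `sgl_kernel` under the three-tensor signature shown in the hunk header (names and sizes here are illustrative):

    import torch
    from sgl_kernel import awq_dequantize  # assumed Python binding of the C++ op

    # 9 packed columns is not a multiple of 16, exercising the new bounds check.
    rows, packed_cols, group_size = 128, 9, 64
    qweight = torch.randint(0, torch.iinfo(torch.int32).max,
                            (rows, packed_cols), dtype=torch.int32, device="cuda")
    scales = torch.rand(rows // group_size, packed_cols * 8,
                        dtype=torch.float16, device="cuda")
    qzeros = torch.randint(0, torch.iinfo(torch.int32).max,
                           (rows // group_size, packed_cols),
                           dtype=torch.int32, device="cuda")

    out = awq_dequantize(qweight, scales, qzeros)
    assert out.shape == (rows, packed_cols * 8)  # (128, 72)
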
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            [3584, 18944, 128, 256, 512, 1024, 1536],
+            [448, 576, 4736, 16, 32, 64, 128, 72],
             [True, False],
         )
     ),
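
The added sizes are regression cases for this fix: 1536 matches DeepSeek-V2's `q_lora_rank`, and a column count of 72 (the packed width of a 576-wide output at 8 values per int32) is not a multiple of the 16-wide thread block, so those shapes only pass with the rounded-up grid and the in-kernel bounds check.
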
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
 
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,