[dev] support AWQ/GPTQ quantization for dense models

This commit is contained in:
Li Wei
2025-12-24 13:45:55 +08:00
parent 75d0bdae2f
commit 6546323c71
5 changed files with 412 additions and 2 deletions

View File

@@ -9,7 +9,7 @@ blake3==1.0.5
cachetools==6.1.0
cbor2==5.7.0
cloudpickle==3.1.1
compressed-tensors==0.10.2
compressed-tensors==0.11.0
diskcache==5.6.3
gguf==0.17.1
mistral_common==1.8.3

View File

@@ -16,4 +16,6 @@
#
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq

View File

@@ -0,0 +1,128 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Li Wei, Pan Xiakai, You Zeyu
# Email: liwei157@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
"""Convert AWQ-packed int4 weights to Kunlun XPU format.
Input: packed[N, K], dtype=int32, saved as AWQ order
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order
"""
N, K = packed.shape
self.align_type = 1 if K % 8 == 0 else 0
assert num_bits == 4, "Only int4 supported now"
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
if self.align_type == 0: # NORMAL MODE
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
packed_kunlun = (unpacked_kunlun << shifts).sum(
dim=-1, dtype=torch.int32
) # [N, K]
elif self.align_type == 1: # FAST MODEL
# Unpack AWQ order
unpacked_awq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 36, 4, 33, 1, 37, 5,
34, 2, 38, 6, 35, 3, 39, 7,
40, 8, 44, 12, 41, 9, 45, 13,
42, 10, 46, 14, 43, 11, 47, 15,
48, 16, 52, 20, 49, 17, 53, 21,
50, 18, 54, 22, 51, 19, 55, 23,
56, 24, 60, 28, 57, 25, 61, 29,
58, 26, 62, 30, 59, 27, 63, 31
]
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
else:
raise NotImplementedError
return packed_kunlun
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qweight.data)
if layer.qweight.data.dtype == torch.int32
else layer.qweight.data
),
requires_grad=False,
)
layer.qzeros = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qzeros.data)
if layer.qzeros.data.dtype == torch.int32
else layer.qzeros.data
),
requires_grad=False,
)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = torch.ops._C.awq_dequantize(
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
)
out = torch.matmul(reshaped_x, out)
else:
out = torch.ops._C.awq_gemm(
reshaped_x, qweight, scales, qzeros, align_type=self.align_type
)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading
AWQLinearMethod.apply = apply

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Li Wei, You Zeyu
# Email: liwei157@baidu.com, youzeyu@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parameter import Parameter
from typing import Optional
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qzeros = Parameter(
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits)
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
requires_grad=False
)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
# No need shuffle on xpu
# ops.gptq_shuffle(layer.qweight, layer.g_idx,
# self.quant_config.weight_bits)
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
N, K = packed.shape
assert num_bits == 4, "Only int4 supported now"
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
# Unpack int32 to int4 values
unpacked_gptq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Convert to KUNLUN order
GPTQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 33, 1, 34, 2, 35, 3,
36, 4, 37, 5, 38, 6, 39, 7,
40, 8, 41, 9, 42, 10, 43, 11,
44, 12, 45, 13, 46, 14, 47, 15,
48, 16, 49, 17, 50, 18, 51, 19,
52, 20, 53, 21, 54, 22, 55, 23,
56, 24, 57, 25, 58, 26, 59, 27,
60, 28, 61, 29, 62, 30, 63, 31,
]
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
return packed_kunlun
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
output = torch.ops.xspeedgate_ops.gptq_gemm(
reshaped_x,
layer.qweight,
layer.qzeros,
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
GPTQLinearMethod.apply = apply

View File

@@ -1149,3 +1149,175 @@ def fake_moe_post(
return None
moe_post.register_fake(fake_moe_post)
##################################################
# --------------- awq_dequantize -----------------
##################################################
@custom_op("_C::awq_dequantize", mutates_args=())
def awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
@impl("_C::awq_dequantize", "CUDA")
def awq_dequantize_cuda(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
out = xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
def _fake_awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
return weight
awq_dequantize.register_fake(_fake_awq_dequantize)
##################################################
# ------------------ awq_gemm -------------------
##################################################
@custom_op("_C::awq_gemm", mutates_args=())
def awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
@impl("_C::awq_gemm", "CUDA")
def awq_gemm_cuda(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
def _fake_awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
return out
awq_gemm.register_fake(_fake_awq_gemm)
##################################################
# ---------------- gptq_shuffle ------------------
##################################################
@custom_op("_C::gptq_shuffle", mutates_args=())
def gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
@impl("_C::gptq_shuffle", "CUDA")
def gptq_shuffle_cuda(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
def _fake_gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
return None
gptq_shuffle.register_fake(_fake_gptq_shuffle)