[Feature] Support AWQ MoE W4A16 Quantization (#142)

Signed-off-by: tangshiwen <tangshiwen@baidu.com>
Co-authored-by: Li Wei <liwei.109@outlook.com>
Shiwen Tang
2026-01-26 18:56:05 +08:00
committed by GitHub
parent 2a998286c0
commit 0711c1abfa
7 changed files with 639 additions and 126 deletions


@@ -1,6 +1,6 @@
 #
 # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
-# Author: Li Wei, Pan Xiakai, You Zeyu
+# Author: Li Wei, Pan Xiakai, You Zeyu, Tang Shiwen
 # Email: liwei157@baidu.com
 # This file is a part of the vllm-kunlun project.
 #
@@ -16,13 +16,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Union
 import torch
 from vllm.logger import init_logger
-from typing import Optional
-from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.awq import (
+    AWQLinearMethod,
+    AWQConfig,
+    is_layer_skipped_awq,
+)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

 logger = init_logger(__name__)

 class KunlunAWQLinearMethod(AWQLinearMethod):
     def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
         """Convert AWQ-packed int4 weights to Kunlun XPU format.

         Input: packed[N, K], dtype=int32, stored in AWQ order
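For context on the layout the docstring describes: each int32 in an AWQ qweight holds eight 4-bit values. A minimal sketch of the unpacking step the repack starts from, assuming plain low-nibble-first extraction (the AWQ-specific nibble ordering is exactly what the permutation table in the next hunk corrects for):

    import torch

    def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
        # [N, K] int32 -> [N, K, 8]: extract the eight 4-bit fields of each word.
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=packed.device)
        return (packed.unsqueeze(-1) >> shifts) & 0xF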
@@ -64,7 +76,9 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
             58, 26, 62, 30, 59, 27, 63, 31
         ]
         unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
-        unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST]  # [N, K//8, 64]
+        unpacked_kunlun = unpacked_awq[
+            ..., AWQ_TO_KUNLUN_ORDER_FAST
+        ]  # [N, K//8, 64]

         # Pack to int32
         unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
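The hunk above is the heart of the repack: every eight consecutive int32 words form a 64-nibble tile, each tile is reindexed through AWQ_TO_KUNLUN_ORDER_FAST, and the result is packed back into int32. A hedged, standalone sketch of that round trip (the 64-entry table itself lives in the file):

    import torch

    def repack_tiles(packed: torch.Tensor, order: list) -> torch.Tensor:
        # packed: [N, K] int32 with K % 8 == 0; order: a 64-entry permutation.
        N, K = packed.shape
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=packed.device)
        nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF    # [N, K, 8]
        tiles = nibbles.reshape(N, K // 8, 64)[..., order]  # permute each tile
        tiles = tiles.reshape(N, K // 8, 8, 8)
        # Re-pack eight nibbles per int32; disjoint bit fields sum to a bitwise OR.
        return (tiles << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)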
@@ -76,7 +90,6 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
         return packed_kunlun

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-
         logger.warning_once(f"Repacking INT4 for XPU ...")
         layer.qweight = torch.nn.Parameter(
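process_weights_after_loading runs once, right after checkpoint load, so the repack cost is paid a single time rather than on every forward pass. The pattern used above, sketched with an assumed repack callable:

    import torch

    def replace_with_repacked(layer: torch.nn.Module, repack) -> None:
        # Quantized weights are frozen buffers, hence requires_grad=False.
        layer.qweight = torch.nn.Parameter(
            repack(layer.qweight.data), requires_grad=False
        )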
@@ -97,9 +110,11 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
         )
         layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

     def apply(
-        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         qweight = layer.qweight
         scales = layer.scales
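W4A16 means 4-bit weights with 16-bit activations: apply() must dequantize each weight group with its scale and zero point, w = (q - z) * s, before (or fused with) the half-precision GEMM. A minimal non-fused reference with assumed already-unpacked tensors (the Kunlun kernel fuses all of this on the XPU):

    import torch

    def w4a16_reference(x, q, scales, zeros, group_size=128):
        # q: [K, N] int4 values in [0, 15]; scales, zeros: [K // group_size, N].
        g = torch.arange(q.shape[0], device=q.device) // group_size
        w = (q.float() - zeros[g].float()) * scales[g].float()  # group-wise dequant
        return x @ w.to(x.dtype)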
@@ -125,11 +140,42 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
         return out.reshape(out_shape)

+
+class KunlunAWQConfig(AWQConfig):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:  # type: ignore
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return KunlunAWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            logger.warning_once(
+                f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                "Falling back to Moe WNA16 kernels."
+            )
+            config = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
+        return None
+
+
 # monkey patch
 from vllm.model_executor.layers.quantization import awq

 awq.AWQLinearMethod = KunlunAWQLinearMethod
+awq.AWQConfig = KunlunAWQConfig

 print(
     "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQLinearMethod \
 --> vllm_kunlun.ops.quantization.awq.KunlunAWQLinearMethod"
 )
+print(
+    "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQConfig \
+--> vllm_kunlun.ops.quantization.awq.KunlunAWQConfig"
+)
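Since the patch runs as an import-time side effect, it takes effect as soon as this module is imported anywhere on the vllm-kunlun startup path, e.g.:

    import vllm_kunlun.ops.quantization.awq  # noqa: F401 (patch applied on import)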