[Feature] Support AWQ MoE W4A16 Quantization (#142)
Signed-off-by: tangshiwen <tangshiwen@baidu.com>
Co-authored-by: Li Wei <liwei.109@outlook.com>
@@ -1,6 +1,6 @@
 #
 # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
-# Author: Li Wei, Pan Xiakai, You Zeyu
+# Author: Li Wei, Pan Xiakai, You Zeyu, Tang Shiwen
 # Email: liwei157@baidu.com
 # This file is a part of the vllm-kunlun project.
 #
@@ -16,13 +16,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
+
 import torch
 
 from vllm.logger import init_logger
-from typing import Optional
-from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.awq import (
+    AWQLinearMethod,
+    AWQConfig,
+    is_layer_skipped_awq,
+)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 
 logger = init_logger(__name__)
 
 
 class KunlunAWQLinearMethod(AWQLinearMethod):
+
+    def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
+        """Convert AWQ-packed int4 weights to Kunlun XPU format.
+        Input: packed[N, K], dtype=int32, stored in AWQ order.
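Reviewer note: the layout being undone here is upstream AWQ's nibble interleave, which packs eight 4-bit values into each int32 in the order [0, 2, 4, 6, 1, 3, 5, 7]. A minimal sketch of the matching unpack step, assuming that order (the constant and function names are illustrative, not part of this patch):

import torch

# Nibble slot i of each int32 holds logical element AWQ_PACK_ORDER[i];
# indexing with the inverse permutation restores logical order.
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_UNPACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # inverse of AWQ_PACK_ORDER

def unpack_awq_int4(packed: torch.Tensor) -> torch.Tensor:
    """[N, K] int32 -> [N, K * 8] int4 values held as integers in 0..15."""
    shifts = torch.arange(0, 32, 4, device=packed.device)  # one shift per nibble
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF       # [N, K, 8], slot order
    nibbles = nibbles[..., AWQ_UNPACK_ORDER]               # undo the interleave
    return nibbles.reshape(packed.shape[0], -1)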
@@ -64,7 +76,9 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
             58, 26, 62, 30, 59, 27, 63, 31
         ]
         unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
-        unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST]  # [N, K//8, 64]
+        unpacked_kunlun = unpacked_awq[
+            ..., AWQ_TO_KUNLUN_ORDER_FAST
+        ]  # [N, K//8, 64]
 
         # Pack to int32
         unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
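The "Pack to int32" step that follows is the standard shift-and-OR nibble packing over the last axis. A self-contained sketch of that operation for reference (not the patch's exact code, which works on the [N, K // 8, 8, 8] view built above):

import torch

def pack_int4_to_int32(nibbles: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight 0..15 values (last dim) into one int32 each."""
    assert nibbles.shape[-1] == 8
    packed = torch.zeros(nibbles.shape[:-1], dtype=torch.int32,
                         device=nibbles.device)
    for i in range(8):
        # Nibble i occupies bits 4*i .. 4*i + 3; setting bit 31 wraps the
        # signed int32 negative, which is fine for a pure bit container.
        packed |= (nibbles[..., i].to(torch.int32) & 0xF) << (4 * i)
    return packed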
@@ -76,7 +90,6 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
 
         return packed_kunlun
 
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         logger.warning_once(f"Repacking INT4 for XPU ...")
         layer.qweight = torch.nn.Parameter(
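Since repack_int4_for_kunlun only permutes nibbles within each row, a cheap invariant to test when porting it is that both layouts hold the same per-row multiset of 4-bit values. A hypothetical sanity check along those lines (not part of this patch; method is any constructed KunlunAWQLinearMethod instance):

import torch

def check_repack_preserves_nibbles(method, N: int = 4, K: int = 16) -> None:
    """K counts int32 words per row and must be divisible by 8."""
    packed_awq = torch.randint(-2**31, 2**31 - 1, (N, K), dtype=torch.int32)
    packed_kunlun = method.repack_int4_for_kunlun(packed_awq)
    shifts = torch.arange(0, 32, 4)

    def nibbles(t: torch.Tensor) -> torch.Tensor:
        # [N, words] int32 -> [N, words * 8] values in 0..15
        return ((t.unsqueeze(-1) >> shifts) & 0xF).reshape(t.shape[0], -1)

    # A pure reordering must leave each row's sorted nibbles unchanged.
    assert torch.equal(nibbles(packed_awq).sort(-1).values,
                       nibbles(packed_kunlun).sort(-1).values)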
@@ -97,9 +110,11 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
         )
         layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
 
-
     def apply(
-        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         qweight = layer.qweight
         scales = layer.scales
@@ -125,11 +140,42 @@ class KunlunAWQLinearMethod(AWQLinearMethod):
         return out.reshape(out_shape)
 
 
+class KunlunAWQConfig(AWQConfig):
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:  # type: ignore
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return KunlunAWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            logger.warning_once(
+                f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                "Falling back to Moe WNA16 kernels."
+            )
+            config = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
+
+        return None
+
+
 # monkey patch
 from vllm.model_executor.layers.quantization import awq
 
 awq.AWQLinearMethod = KunlunAWQLinearMethod
+awq.AWQConfig = KunlunAWQConfig
 print(
     "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQLinearMethod \
 --> vllm_kunlun.ops.quantization.awq.KunlunAWQLinearMethod"
 )
+print(
+    "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQConfig \
+--> vllm_kunlun.ops.quantization.awq.KunlunAWQConfig"
+)
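End to end, the intent is that an AWQ W4A16 checkpoint with MoE layers now loads on Kunlun XPU: dense layers go through KunlunAWQLinearMethod, while FusedMoE layers fall back to the MoE WNA16 kernels via the config translation above. A minimal usage sketch, assuming the patch module must be imported before engine construction (the module path is taken from the log messages above; the checkpoint path is a placeholder, and in the real plugin the patch may already be applied automatically):

# Importing the module applies the monkey patch as a side effect.
import vllm_kunlun.ops.quantization.awq  # noqa: F401

from vllm import LLM

llm = LLM(model="/path/to/awq-moe-checkpoint", quantization="awq")
out = llm.generate(["What does W4A16 mean?"])
print(out[0].outputs[0].text)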