[Feat] Support native Kimi-K2-Thinking native W4A16 quantized experts weights (#4516)

### What this PR does / why we need it?

Adds W4A16 quantization method for the Kimi-K2-Thinking model and
updates relevant modules to support the new quantization method.

- Implements complete W4A16 quantization method including weight
packing/unpacking, per-group quantization parameter generation,
post-processing logic and MoE method application.
- Adds parameters `use_int4_w4a16`, `w1_offset` and `w2_offset`, adjusts
`with_quant` conditional logic to support W4A16 matrix multiplication.
- Adds `packed_modules_model_mapping` for Kimi-K2-Thinking model and
processing logic for `weight_packed` field.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: Ruri <zhouxiang100@huawei.com>
This commit is contained in:
Ruri
2025-12-10 15:58:52 +08:00
committed by GitHub
parent c1db298f43
commit ce5872705e
13 changed files with 781 additions and 13 deletions

View File

@@ -4,7 +4,8 @@ import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config)
@@ -16,8 +17,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm_ascend.quantization.quant_config import (AscendLinearMethod,
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
AscendLinearMethod,
AscendQuantConfig)
from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
@@ -142,7 +146,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
quant_method = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
ascend_quant_config = AscendQuantConfig(self.quant_description
@@ -150,6 +154,21 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
quant_method = AscendLinearMethod(ascend_quant_config, prefix,
None, layer)
return quant_method
if isinstance(layer, FusedMoE):
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
# choose quantization method
quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config)
if quant_scheme is not None:
layer.scheme = quant_scheme
ascend_quant_config = AscendQuantConfig(self.quant_description
or {})
quant_method = AscendFusedMoEMethod(
ascend_quant_config, prefix,
ascend_quant_config.packed_modules_mapping, layer)
return quant_method
return None
def get_scheme(self,
@@ -215,6 +234,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return AscendW8A8DynamicLinearMethod()
if weight_quant is not None:
if self._is_w4a16(weight_quant):
return AscendW4A16FusedMoEMethod()
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
@@ -246,6 +269,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_token and is_symmetric and is_dynamic
def _is_w4a16(self, weight_quant: QuantizationArgs) -> bool:
is_4_bits = weight_quant.num_bits == 4
return is_4_bits
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)

View File

@@ -65,6 +65,9 @@ class AscendQuantConfig(QuantizationConfig):
if "shared_head" in k:
new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k]
if "weight_packed" in k:
new_k = k.replace("weight_packed", "weight")
extra_quant_dict[new_k] = self.quant_description[k]
self.quant_description.update(extra_quant_dict)
def __repr__(self) -> str:
@@ -200,7 +203,8 @@ packed_modules_model_mapping = {
"kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -439,7 +443,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
]
] + ["weight_scale", "weight_offset"] if hasattr(
self.quant_method,
"group_size") and self.quant_method.group_size > 0 else []
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)

View File

@@ -8,6 +8,7 @@ from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -16,6 +17,9 @@ from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
AscendW8A8PDMixLinearMethod)
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A16": {
"moe": AscendW4A16FusedMoEMethod,
},
"W4A8_DYNAMIC": {
"linear": AscendW4A8DynamicLinearMethod,
"moe": AscendW4A8DynamicFusedMoEMethod,

View File

@@ -0,0 +1,284 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
def unpack_from_int32(
weight: torch.Tensor,
shape: torch.Size,
num_bits: int,
packed_dim: int = 1,
) -> torch.Tensor:
"""
Unpacks quantized weights from int32 format back to original bits.
:param weight: The packed int32 tensor containing quantized weights
:param shape: Original shape to restore, defaults to None
:param num_bits: The number of bits used for quantization (<= 8)
:param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
:return: Unpacked tensor with int8 dtype after applying offset correction
"""
assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
pack_factor = 32 // num_bits
mask = (1 << num_bits) - 1
if packed_dim == 1:
unpacked_weight = torch.zeros(
(weight.shape[0], weight.shape[1] * pack_factor),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[:, i::pack_factor] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[1])
unpacked_weight = unpacked_weight[:, :original_row_size]
else:
unpacked_weight = torch.zeros(
(weight.shape[0] * pack_factor, weight.shape[1]),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[i::pack_factor, :] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[0])
unpacked_weight = unpacked_weight[:original_row_size, :]
offset = pow(2, num_bits) // 2
unpacked_weight = (unpacked_weight - offset).to(torch.int8)
return unpacked_weight
def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
"""
Packs quantized weights into int32 format for storage.
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
:return: Packed tensor with int32 dtype optimized for storage
"""
assert weight.dim(
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
assert weight.dtype in [
torch.int8, torch.int32
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
if weight.dtype == torch.int32:
assert weight.shape[
-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
-1)
else:
assert weight.shape[
-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
packed_weight = weight.view(torch.int32).contiguous()
return packed_weight
class AscendW4A16FusedMoEMethod:
"""FusedMoe method for Ascend W4A16.
"""
def __init__(self) -> None:
self.transpose_weight = True
self.num_bits = 4 # dtype = torch.int4
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
def get_weight(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
param_dict = {}
param_dict["w13_weight_packed"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.pack_factor,
dtype=torch.int32)
param_dict["w2_weight_packed"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.pack_factor,
dtype=torch.int32)
return param_dict
def get_dynamic_quant_param(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
param_dict["w13_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w13_weight_offset"] = torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_offset"] = torch.zeros(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight:
w13_shape = layer.w13_weight_packed.data.shape
w2_shape = layer.w2_weight_packed.data.shape
unpacked_w13_weight = (unpack_from_int32(
layer.w13_weight_packed.data.flatten(0, 1),
torch.Size([
w13_shape[0] * w13_shape[1],
w13_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w13_shape[0], w13_shape[1],
-1).transpose(1, 2).contiguous().int())
unpacked_w2_weight = (unpack_from_int32(
layer.w2_weight_packed.data.flatten(0, 1),
torch.Size([
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w2_shape[0], w2_shape[1],
-1).transpose(1, 2).contiguous().int())
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
1, 2).contiguous()