Files
xc-llm-ascend/vllm_ascend/quantization/modelslim_config.py
Feng-xiaosuo abe72d7cb9 Refactor quantization layer name mapping to leverage vLLM built-in mappers (#7050)
…the quantization layer name

### What this PR does / why we need it?
This PR modifies the loading logic for layer name prefixes in quantized
models. The goal is to reduce or eliminate the need for point-to-point
(hardcoded) modifications by leveraging the built-in mapper mechanism
already provided in vLLM's model code. For models that do not yet have a
corresponding mapper, the original point-to-point modification approach
has been retained to ensure backward compatibility.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
The changes were validated using an offline deployment script to launch
and verify multiple multimodal models. Testing confirmed that the
updated loading logic correctly handles layer name prefixes across
different model architectures, with no regression in model
initialization or inference behavior.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Matrix_K <zhangke144@huawei.com>
Signed-off-by: Feng-xiaosuo <tengchang1@huawei.com>
Co-authored-by: Matrix_K <zhangke144@huawei.com>
2026-03-12 15:48:14 +08:00

609 lines
23 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""ModelSlim quantization configuration and model mappings for Ascend.
This module provides the AscendModelSlimConfig class for parsing quantization
configs generated by the ModelSlim tool, along with model-specific mappings.
"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
from .methods import get_scheme_class
logger = init_logger(__name__)
# Prefix translation table, keyed by model_type.
# Each value maps a vLLM layer-name prefix to the HF-style prefix that is used
# for the keys of quant_model_description.json. The insertion order of every
# inner dict is significant to consumers that apply prefixes in order, so the
# fragments below are spliced in at fixed positions.
_THINKER_LM_PREFIXES: dict[str, str] = {
    "language_model.lm_head.": "thinker.lm_head.",
    "language_model.model.": "thinker.model.",
}
_GLM4V_PREFIXES: dict[str, str] = {
    "visual.": "model.visual.",
    "language_model.lm_head.": "lm_head.",
    "language_model.model.": "model.language_model.",
}
QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = {
    # Omni models: language-model prefixes first, then the visual tower.
    "qwen3_omni_moe": {**_THINKER_LM_PREFIXES, "visual.": "thinker.visual."},
    "qwen2_5_omni": {**_THINKER_LM_PREFIXES, "visual.": "thinker.visual."},
    # Text-only variant: the generic "language_model." prefix comes first.
    "qwen2_5_omni_text": {"language_model.": "thinker.", **_THINKER_LM_PREFIXES},
    # dict(...) gives each model type its own copy of the shared fragment.
    "glm4v_moe": dict(_GLM4V_PREFIXES),
    "glm4v_moe_text": dict(_GLM4V_PREFIXES),
}
# Shard-name fragments shared by several model types. They are kept as tuples
# and expanded with list(...) so that every table entry below owns a fresh
# list object, exactly as a hand-written literal would.
_QKV_SHARDS = ("q_proj", "k_proj", "v_proj")
_ATTN_QKV_SHARDS = ("attn_q_proj", "attn_k_proj", "attn_v_proj")
_GATE_UP_SHARDS = ("gate_proj", "up_proj")
_EXPERT_SHARDS = ("experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj")
_FUSED_QKV_A_SHARDS = ("q_a_proj", "kv_a_proj_with_mqa")
_IN_PROJ_QKVZ_SHARDS = ("in_proj_qkv", "in_proj_z")
_IN_PROJ_BA_SHARDS = ("in_proj_b", "in_proj_a")

# Packed-module table.
# key: model_type
# value: dict of fused module name -> list of original module names
packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
    "qwen3_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "qwen3_5": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "in_proj_qkvz": list(_IN_PROJ_QKVZ_SHARDS),
        "in_proj_ba": list(_IN_PROJ_BA_SHARDS),
    },
    "qwen3_5_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "in_proj_qkvz": list(_IN_PROJ_QKVZ_SHARDS),
        "in_proj_ba": list(_IN_PROJ_BA_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "deepseek_v2": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "deepseek_v3": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "pangu_ultra_moe": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "kimi_k2": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "deepseek_v32": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "glm_moe_dsa": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    # NOTE 1: The MTP layer of a quantized deepseek model is not quantized on the NPU.
    # NOTE 2: The description file generated by the current msmodelslim tool does not
    # contain MTP layer info. Please add it manually and set the value to FLOAT.
    "deepseek_mtp": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "pangu_ultra_moe_mtp": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "qwen3_next": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
        "experts": list(_EXPERT_SHARDS),
    },
    "qwen2_5_vl": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
    },
    "qwen3_vl_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "glm4_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "glm4_moe_lite": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    "glm4v_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "glm4v_moe_text": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "longcat_flash": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_FUSED_QKV_A_SHARDS),
    },
    # minimax_m2 names its expert weights w1/w2/w3 rather than *_proj.
    "minimax_m2": {
        "qkv_proj": list(_QKV_SHARDS),
        "experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"],
    },
    "qwen3_omni_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "attn_qkv_proj": list(_ATTN_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "qwen2_5_omni": {
        "qkv_proj": list(_QKV_SHARDS),
        "attn_qkv_proj": list(_ATTN_QKV_SHARDS),
        "qkv": ["q", "k", "v"],
        "gate_up_proj": list(_GATE_UP_SHARDS),
    },
}
def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]:
    """Look up the packed (fused) module table for one model type.

    Args:
        model_type: The model type string (e.g., "deepseek_v3").

    Returns:
        The entry of ``packed_modules_model_mapping`` for ``model_type``,
        mapping each fused module name to its component module names.
        A new empty dict is returned when ``model_type`` has no entry.
    """
    try:
        return packed_modules_model_mapping[model_type]
    except KeyError:
        # Unknown model types simply have nothing packed.
        return {}
def get_prefix_mapping(model_type: str) -> dict[str, str]:
    """Look up the layer-name prefix table for one model type.

    Args:
        model_type: The model type string (e.g., "qwen3_vl_moe").

    Returns:
        The entry of ``QUANT_MODEL_PREFIX_MAPPINGS`` for ``model_type``
        (original prefix -> replacement prefix). A new empty dict is
        returned when ``model_type`` has no entry.
    """
    if model_type in QUANT_MODEL_PREFIX_MAPPINGS:
        return QUANT_MODEL_PREFIX_MAPPINGS[model_type]
    # No manual prefix translation registered for this model type.
    return {}
def get_linear_quant_type(
quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any]
) -> str | None:
"""Determine the quantization type for a linear layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
proj_name = prefix.split(".")[-1]
if proj_name in packed_modules_mapping:
quant_type = None
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) for shard_proj_name in packed_modules_mapping[proj_name]
]
for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + ".weight"]
if quant_type is None:
quant_type = shard_quant_type
elif shard_quant_type != quant_type:
raise ValueError(
f"Not all shards of {prefix} are quantized with same quant type."
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
f"use {quant_type}. Please check quantization config."
)
else:
quant_type = quant_description[prefix + ".weight"]
return quant_type
def get_quant_type_for_layer(
quant_description: dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: dict[str, Any] | None = None,
) -> str | None:
"""Determine the quantization type for a layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
layer_type: The type of layer ("linear", "moe", "attention").
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if layer_type == "attention" and "fa_quant_type" in quant_description:
return quant_description["fa_quant_type"]
# Linear / MoE
return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)
def create_scheme_for_layer(
    quant_description: dict[str, Any],
    prefix: str,
    layer_type: str,
    packed_modules_mapping: dict[str, Any] | None = None,
):
    """Instantiate the quantization scheme that matches one layer.

    Args:
        quant_description: The quantization description dictionary.
        prefix: The layer prefix.
        layer_type: The type of layer ("linear", "moe", "attention").
        packed_modules_mapping: Mapping for packed/fused modules.

    Returns:
        A new instance of the scheme class that ``get_scheme_class`` returns
        for the resolved quantization type and ``layer_type``.

    Raises:
        ValueError: If no quantization type could be resolved for the layer.
        NotImplementedError: If no scheme class is available for the resolved
            quantization type and layer type.
    """
    logger.info_once("Using the vLLM Ascend modelslim Quantization now!")

    resolved_type = get_quant_type_for_layer(quant_description, prefix, layer_type, packed_modules_mapping)
    if resolved_type is None:
        raise ValueError(f"Could not determine quantization type for layer {prefix}.")

    # Use registry to get scheme class
    scheme_factory = get_scheme_class(resolved_type, layer_type)
    if scheme_factory is None:
        raise NotImplementedError(f"Currently, vLLM Ascend doesn't support {resolved_type} for {layer_type}.")
    return scheme_factory()
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendModelSlimConfig(QuantizationConfig):
    """Config class for Ascend ModelSlim quantization.

    This class is a general class that parses quantization configs
    that are supported on Ascend hardware, specifically for models
    quantized using the ModelSlim tool.
    """

    def __init__(self, quant_config: dict[str, Any]):
        """Store the quantization description and add alias keys to it.

        Args:
            quant_config: The parsed quantization description. The same dict
                object is kept as ``self.quant_description`` and is updated
                in place with the alias keys built below.
        """
        super().__init__()
        self.quant_description = quant_config
        # TODO(whx): remove this adaptation after adding "shared_head"
        # to prefix of DeepSeekShareHead in vLLM.
        extra_quant_dict = {}
        for k in self.quant_description:
            if "shared_head" in k:
                # Alias the key without the ".shared_head." path component.
                new_k = k.replace(".shared_head.", ".")
                extra_quant_dict[new_k] = self.quant_description[k]
            if "weight_packed" in k:
                # Alias "weight_packed" keys under the plain "weight" name.
                new_k = k.replace("weight_packed", "weight")
                extra_quant_dict[new_k] = self.quant_description[k]
        # Aliases are collected first so the dict is not resized while iterating.
        self.quant_description.update(extra_quant_dict)
        # Initialize attributes for type checking
        self.model_type: str | None = None
        self.hf_to_vllm_mapper: WeightsMapper | None = None
        self.vllm_to_hf_mapper: WeightsMapper | None = None

    def __repr__(self) -> str:
        """Return the base-class representation prefixed with the class name."""
        return "AscendModelSlimConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        """Return the name this quantization method is registered under."""
        return ASCEND_QUANTIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes reported as supported."""
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Not available on Ascend hardware; always raises NotImplementedError."""
        raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.')

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names that hold the quantization description."""
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
        """Build a config instance from a parsed quantization description."""
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None:
        """Select the Ascend method for configs that name no quant method.

        Args:
            hf_quant_cfg: The quantization config mapping, or ``None``.
            user_quant: Not used by this implementation.

        Returns:
            ``ASCEND_QUANTIZATION_METHOD`` when ``hf_quant_cfg`` is given, has
            no truthy "quant_method" entry, and ``torch.npu.is_available()``
            is true; otherwise ``None``.
        """
        if hf_quant_cfg is not None:
            quant_method = hf_quant_cfg.get("quant_method", None)
            if not quant_method and torch.npu.is_available():
                return ASCEND_QUANTIZATION_METHOD
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Apply the vLLM model-specific mapper to this quantization config.

        This method is called by vLLM to apply the model-specific weight mapper
        to the quantization configuration. It creates a reverse mapper to convert
        vLLM prefixes back to HF format for looking up keys in quant_config.json.

        Args:
            hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM
                that contains model-specific prefix mappings (HF to vLLM).
        """
        # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper
        if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
            # Same mapper instance, no need to recreate
            return
        # Store the original mapper
        self.hf_to_vllm_mapper = hf_to_vllm_mapper
        # Check if manual mapping exists for this model type
        # Manual mapping takes priority and is used exclusively to avoid conflicts
        # NOTE(review): self.model_type starts as None and is only assigned in
        # quant_prefix_mapper, so this branch is skipped when this method runs
        # before the first quant_prefix_mapper call - confirm the call order.
        if hasattr(self, "model_type") and self.model_type in QUANT_MODEL_PREFIX_MAPPINGS:
            manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[self.model_type]
            # Manual mapping is already in vLLM -> HF direction, use directly
            self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=manual_mapping)
            logger.debug(f"Using manual mapping for {self.model_type}: {manual_mapping}")
            return
        # No manual mapping, use hf_to_vllm_mapper and reverse it
        # Try different ways to get the mapping based on WeightsMapper implementation
        mapping_attrs = ["orig_to_new_prefix"]
        orig_to_new_prefix = {}
        for attr_name in mapping_attrs:
            if hasattr(hf_to_vllm_mapper, attr_name):
                orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name)
                break
        # Create reverse mapping (vLLM -> HF), skipping empty values
        vllm_to_hf_mapping = {}
        for orig_prefix, new_prefix in orig_to_new_prefix.items():
            # Skip empty values to avoid invalid keys in reverse mapping
            if new_prefix:
                # If several HF prefixes share one vLLM prefix, the last one wins.
                vllm_to_hf_mapping[new_prefix] = orig_prefix
        # Create and store the reverse WeightsMapper instance
        if vllm_to_hf_mapping:
            self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=vllm_to_hf_mapping)
            logger.debug(f"Created reverse mapping from hf_to_vllm_mapper: {vllm_to_hf_mapping}")
        else:
            logger.info("No valid reverse mapping found for WeightsMapper.")

    def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
        """Translate a vLLM layer prefix into the form used by the description keys.

        Args:
            model_type: The model type string; it is also stored on the
                instance as ``self.model_type``.
            prefix: The vLLM layer prefix.

        Returns:
            The prefix mapped through ``self.vllm_to_hf_mapper`` when one is
            set; otherwise mapped through the manual table entry for
            ``model_type`` when one exists; otherwise ``prefix`` unchanged.
        """
        # Store model_type for backward compatibility mappings
        self.model_type = model_type
        # Use the reverse mapper (vLLM to HF) if available
        if hasattr(self, "vllm_to_hf_mapper") and self.vllm_to_hf_mapper:
            return self.vllm_to_hf_mapper._map_name(prefix)
        # Fall back to manual mapping for backward compatibility (simplified)
        # This is only used if apply_vllm_mapper wasn't called or failed
        prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
        if prefix_mapping:
            # Manual mapping is already in vLLM -> HF direction, use directly
            mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping)
            return mapper._map_name(prefix)
        return prefix

    def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
        """Choose the quantization method object for one layer.

        Args:
            layer: The layer being built.
            prefix: The vLLM name prefix of the layer.

        Returns:
            A method adapter for linear, attention (only when the description
            has a non-None "fa_quant_type"), fused-MoE and vocab-embedding
            layers; the unquantized variant when the layer is marked as
            skipped; ``None`` for any other layer.
        """
        from .method_adapters import (
            AscendEmbeddingMethod,
            AscendFusedMoEMethod,
            AscendKVCacheMethod,
            AscendLinearMethod,
        )

        vllm_config = get_current_vllm_config()
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type in ["minimax", "minimax_m2"]:
            # Adapt to Minimax architecture: update layer names to MoE convention
            prefix = prefix.replace("mlp", "block_sparse_moe")
            # Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts')
            parts = prefix.split(".")
            if "experts" in parts and len(parts) > 2:
                exp_idx = parts.index("experts")
                if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
                    parts = parts[: exp_idx + 1]
                    prefix = ".".join(parts)
        if model_type in packed_modules_model_mapping:
            self.packed_modules_mapping = packed_modules_model_mapping[model_type]
        # From here on, prefix is in the naming used by the description keys.
        prefix = self.quant_prefix_mapper(model_type, prefix)
        from vllm.model_executor.layers.attention import Attention

        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                # Delayed import to avoid circular import
                from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod

                return AscendUnquantizedLinearMethod()
            scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
            return AscendLinearMethod(scheme)
        elif (
            isinstance(layer, Attention)
            and "fa_quant_type" in self.quant_description
            and self.quant_description["fa_quant_type"] is not None
        ):
            scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
            return AscendKVCacheMethod(scheme)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                # Delayed import to avoid circular import
                from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod

                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
            scheme = create_scheme_for_layer(self.quant_description, prefix, "moe", self.packed_modules_mapping)
            return AscendFusedMoEMethod(scheme, layer.moe_config)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
            return AscendEmbeddingMethod(scheme)
        return None

    @staticmethod
    def _shard_prefixes(prefix: str, shard_proj_names: list[str]) -> list[str]:
        """Return ``prefix`` with its last dot-separated component swapped for each shard name."""
        # rpartition keeps everything before the last "." untouched, so a fused
        # name that also occurs earlier in the path is not rewritten.
        head, sep, _ = prefix.rpartition(".")
        return [head + sep + shard_proj_name for shard_proj_name in shard_proj_names]

    def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[str]] = MappingProxyType({})):
        """Tell whether a layer is left unquantized (marked "FLOAT").

        Args:
            prefix: The layer prefix, in the naming of the description keys.
            fused_mapping: Mapping of fused module name to its shard names.

        Returns:
            For a fused module, whether its shards are marked "FLOAT". For any
            other layer, whether some description key that starts with
            ``prefix`` and ends with ".weight" is marked "FLOAT".

        Raises:
            ValueError: If only some shards of a fused module are "FLOAT".
            KeyError: If a shard's ".weight" key is missing from the description.
        """
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = self._shard_prefixes(prefix, fused_mapping[proj_name])
            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix + ".weight"] == "FLOAT"
                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "to have the same precision."
                    )
        else:
            # NOTE(review): startswith(prefix) is not anchored at a "." boundary,
            # so a prefix such as "...mlp.gate" also matches "...mlp.gate_proj.weight".
            # Confirm whether the match should be on prefix + ".".
            is_skipped = any(
                key.startswith(prefix) and key.endswith(".weight") and value == "FLOAT"
                for key, value in self.quant_description.items()
            )
        # Only None when a fused module lists no shards at all.
        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> list[str]:
        """Return an empty list."""
        return []