[v0.18.0][Refactor] Use forward mapping instead of reverse mapping in AscendModelSlimConfig (#7596) (#7716)

### What this PR does / why we need it?

This PR refactors the `AscendModelSlimConfig` class to use **forward
mapping** instead of reverse mapping for quantization config key
transformation.

**Changes:**
1. Modified `apply_vllm_mapper()` to apply `hf_to_vllm_mapper.apply_dict()` directly, transforming `quant_description` keys from HF format to vLLM format (see the sketch after this list)
2. Simplified `quant_prefix_mapper()` to return the prefix unchanged, since keys are already in vLLM format
3. Removed the `QUANT_MODEL_PREFIX_MAPPINGS` dictionary (~50 lines), which is no longer needed
4. Removed the `get_prefix_mapping()` function, which is no longer needed
5. Removed the `vllm_to_hf_mapper` attribute, which is no longer needed
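
A minimal sketch of what the forward mapping does (the prefixes and quantization values below are illustrative, and the `WeightsMapper` import path may differ between vLLM versions):

```python
from vllm.model_executor.models.utils import WeightsMapper

# Hypothetical HF -> vLLM prefix mapping, as a model's hf_to_vllm_mapper might define it.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.language_model.": "language_model.model.",
        "model.visual.": "visual.",
    }
)

# Keys as ModelSlim writes them in quant_model_description.json (HF naming).
quant_description = {
    "model.language_model.layers.0.self_attn.q_proj.weight": "W8A8",
    "model.visual.blocks.0.attn.qkv.weight": "FLOAT",
}

# Forward mapping: rewrite the keys into vLLM naming once, up front.
quant_description = hf_to_vllm_mapper.apply_dict(quant_description)
# -> {"language_model.model.layers.0.self_attn.q_proj.weight": "W8A8",
#     "visual.blocks.0.attn.qkv.weight": "FLOAT"}
```

After this transformation, layer prefixes built by vLLM can be looked up in `quant_description` directly, which is why `quant_prefix_mapper()` can simply return the prefix.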

**Why this change is needed:**

The previous implementation used a reverse mapping (vLLM → HF), which had
several issues:
- Some mapping entries are never exercised in the forward direction but
would fire incorrectly once flipped into a reverse mapping
- Entries with empty-string values cannot be reversed and had to be
skipped
- It required maintaining a separate `QUANT_MODEL_PREFIX_MAPPINGS` dict
that duplicated information already available in vLLM's model-specific
`WeightsMapper`

The new approach:
- Uses the forward mapping (HF → vLLM) directly from vLLM's
`WeightsMapper`
- Eliminates the need for duplicate mapping definitions
- Avoids issues with reverse mapping (unused keys, empty values)
- Aligns with how `compressed_tensors_config.py` handles the same
scenario
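
As a concrete illustration of the reverse-mapping pitfalls (the mapping entries here are made up for demonstration):

```python
# Forward (HF -> vLLM) prefix mapping; the empty-string value means
# "strip this prefix" when applied in the forward direction.
orig_to_new_prefix = {
    "model.": "",
    "thinker.lm_head.": "lm_head.",
}

# Naively flipping it to get a reverse (vLLM -> HF) mapping:
reversed_mapping = {v: k for k, v in orig_to_new_prefix.items()}
# -> {"": "model.", "lm_head.": "thinker.lm_head."}
# The empty-string key now matches every name, so it must be skipped, and
# entries that never fire in the forward direction can rewrite names they
# should not once reversed. Applying the forward mapping to the
# quant_description keys once avoids both problems.
```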

- vLLM version: v0.18.0
- vLLM main:
ed359c497a

Signed-off-by: Matrix_K <zhangke144@huawei.com>
Signed-off-by: Feng-xiaosuo <tengchang1@huawei.com>
Co-authored-by: Matrix_K <zhangke144@huawei.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
Commit 60e88d9541 (parent 7cca7e6990), authored by Feng-xiaosuo on 2026-03-27 18:25:42 +08:00, committed by GitHub.
```diff
@@ -47,57 +47,6 @@ from .methods import get_scheme_class
 # The config filename that ModelSlim generates after quantizing a model.
 MODELSLIM_CONFIG_FILENAME = "quant_model_description.json"
-# key: model_type
-# value: vLLM prefix -> HF prefix mapping (used to convert vLLM layer names to HF format
-# for looking up keys in quant_model_description.json)
-QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = {
-    "qwen3_vl_moe": {
-        "visual.": "model.visual.",
-        "language_model.lm_head.": "lm_head.",
-        "language_model.model.": "model.language_model.",
-    },
-    "qwen3_vl": {
-        "visual.": "model.visual.",
-        "language_model.lm_head.": "lm_head.",
-        "language_model.model.": "model.language_model.",
-    },
-    "kimi_k25": {
-        "mm_projector.linear_1": "mm_projector.proj.0",
-        "mm_projector.linear_2": "mm_projector.proj.2",
-    },
-    "qwen3_omni_moe": {
-        "language_model.lm_head.": "thinker.lm_head.",
-        "language_model.model.": "thinker.model.",
-        "visual.": "thinker.visual.",
-    },
-    "qwen2_5_omni": {
-        "language_model.lm_head.": "thinker.lm_head.",
-        "language_model.model.": "thinker.model.",
-        "visual.": "thinker.visual.",
-    },
-    "qwen2_5_omni_text": {
-        "language_model.": "thinker.",
-        "language_model.lm_head.": "thinker.lm_head.",
-        "language_model.model.": "thinker.model.",
-    },
-    "glm4v_moe": {
-        "visual.": "model.visual.",
-        "language_model.lm_head.": "lm_head.",
-        "language_model.model.": "model.language_model.",
-    },
-    "glm4v_moe_text": {
-        "visual.": "model.visual.",
-        "language_model.lm_head.": "lm_head.",
-        "language_model.model.": "model.language_model.",
-    },
-    "kimi_k2": {
-        "language_model.layers.": "language_model.model.layers.",
-        # mm projector
-        "mm_projector.proj.0": "mm_projector.linear_1",
-        "mm_projector.proj.2": "mm_projector.linear_2",
-    },
-}
 # key: model_type
 # value: dict of fused module name -> list of original module names
 packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
```
```diff
@@ -309,19 +258,6 @@ def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]:
     return packed_modules_model_mapping.get(model_type, {})
-def get_prefix_mapping(model_type: str) -> dict[str, str]:
-    """Get prefix mapping for a model type.
-    Args:
-        model_type: The model type string (e.g., "qwen3_vl_moe").
-    Returns:
-        Dictionary mapping original prefixes to new prefixes.
-        Returns empty dict if model_type is not found.
-    """
-    return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {})

 def get_linear_quant_type(
     quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any]
 ) -> str | None:
```
```diff
@@ -428,21 +364,10 @@ class AscendModelSlimConfig(QuantizationConfig):
     def __init__(self, quant_config: dict[str, Any] | None = None):
         super().__init__()
         self.quant_description = quant_config if quant_config is not None else {}
-        # TODO(whx): remove this adaptation after adding "shared_head"
-        # to prefix of DeepSeekShareHead in vLLM.
-        extra_quant_dict = {}
-        for k in self.quant_description:
-            if "shared_head" in k:
-                new_k = k.replace(".shared_head.", ".")
-                extra_quant_dict[new_k] = self.quant_description[k]
-            if "weight_packed" in k:
-                new_k = k.replace("weight_packed", "weight")
-                extra_quant_dict[new_k] = self.quant_description[k]
-        self.quant_description.update(extra_quant_dict)
-        # Initialize attributes for type checking
+        self._apply_extra_quant_adaptations()
         self.model_type: str | None = None
         self.hf_to_vllm_mapper: WeightsMapper | None = None
-        self.vllm_to_hf_mapper: WeightsMapper | None = None
+        self._mapper_applied = False
         self._add_kvcache_quant_metadata()

     def __repr__(self) -> str:
```
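
Note that this hunk only shows the call site; the body of `_apply_extra_quant_adaptations()` is not part of the diff. Judging from the inline code it replaces, it presumably moves the `shared_head` / `weight_packed` key adaptations into a helper roughly like the following sketch (not the actual implementation):

```python
def _apply_extra_quant_adaptations(self) -> None:
    # TODO(whx): remove this adaptation after adding "shared_head"
    # to prefix of DeepSeekShareHead in vLLM.
    extra_quant_dict = {}
    for k in self.quant_description:
        # Expose DeepSeek shared_head keys without the ".shared_head." segment.
        if "shared_head" in k:
            extra_quant_dict[k.replace(".shared_head.", ".")] = self.quant_description[k]
        # Expose packed-weight keys under the plain "weight" name as well.
        if "weight_packed" in k:
            extra_quant_dict[k.replace("weight_packed", "weight")] = self.quant_description[k]
    self.quant_description.update(extra_quant_dict)
```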
```diff
@@ -481,73 +406,31 @@ class AscendModelSlimConfig(QuantizationConfig):
             return ASCEND_QUANTIZATION_METHOD
         return None

-    # TODO: Modify the key values in self.quant_description instead of flipping the hf_to_vllm_mapper
     def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
         """Apply the vLLM model-specific mapper to this quantization config.
         This method is called by vLLM to apply the model-specific weight mapper
-        to the quantization configuration. It creates a reverse mapper to convert
-        vLLM prefixes back to HF format for looking up keys in quant_config.json.
+        to the quantization configuration. It directly uses the forward mapping
+        (HF -> vLLM) to transform keys in quant_description from HF format to
+        vLLM format.
         Args:
             hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM
                 that contains model-specific prefix mappings (HF to vLLM).
         """
-        # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper
-        if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
-            # Same mapper instance, no need to recreate
+        if self._mapper_applied and self.hf_to_vllm_mapper is hf_to_vllm_mapper:
             return
-        # Store the original mapper
         self.hf_to_vllm_mapper = hf_to_vllm_mapper
-        # Try different ways to get the mapping based on WeightsMapper implementation
-        mapping_attrs = ["orig_to_new_prefix"]
-        orig_to_new_prefix = {}
-        for attr_name in mapping_attrs:
-            if hasattr(hf_to_vllm_mapper, attr_name):
-                orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name)
-                break
-        # Create reverse mapping (vLLM -> HF), skipping empty values
-        vllm_to_hf_mapping = {}
-        for orig_prefix, new_prefix in orig_to_new_prefix.items():
-            # Skip empty values to avoid invalid keys in reverse mapping
-            if new_prefix:
-                vllm_to_hf_mapping[new_prefix] = orig_prefix
-        # Create and store the reverse WeightsMapper instance
-        if vllm_to_hf_mapping:
-            self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=vllm_to_hf_mapping)
-            logger.debug(f"Created reverse mapping from hf_to_vllm_mapper: {vllm_to_hf_mapping}")
-        else:
-            logger.info("No valid reverse mapping found for WeightsMapper.")
+        self._mapper_applied = True
+        if self.quant_description:
+            self.quant_description = hf_to_vllm_mapper.apply_dict(self.quant_description)
+            self._add_kvcache_quant_metadata()
+            logger.info("Applied hf_to_vllm_mapper to quant_description keys")

     def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
-        # Store model_type for reference
         self.model_type = model_type
-        # Check if manual mapping exists for this model type
-        # Manual mapping takes priority and is used exclusively to avoid conflicts
-        if model_type in QUANT_MODEL_PREFIX_MAPPINGS:
-            manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[model_type]
-            # Manual mapping is already in vLLM -> HF direction, use directly
-            mapper = WeightsMapper(orig_to_new_prefix=manual_mapping)
-            return mapper._map_name(prefix)
-        # Use the reverse mapper (vLLM to HF) if available
-        if hasattr(self, "vllm_to_hf_mapper") and self.vllm_to_hf_mapper:
-            return self.vllm_to_hf_mapper._map_name(prefix)
-        # Fall back to manual mapping for backward compatibility (simplified)
-        # This is only used if apply_vllm_mapper wasn't called or failed
-        prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
-        if prefix_mapping:
-            # Manual mapping is already in vLLM -> HF direction, use directly
-            mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping)
-            return mapper._map_name(prefix)
         return prefix

     def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
```