diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index e5361dd1..555b05cd 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -50,7 +50,8 @@ MODELSLIM_CONFIG_FILENAME = "quant_model_description.json" logger = init_logger(__name__) # key: model_type -# value: orig_to_new_prefix +# value: vLLM prefix -> HF prefix mapping (used to convert vLLM layer names to HF format +# for looking up keys in quant_model_description.json) QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { "qwen3_vl_moe": { "visual.": "model.visual.", @@ -78,6 +79,8 @@ QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { }, "qwen2_5_omni_text": { "language_model.": "thinker.", + "language_model.lm_head.": "thinker.lm_head.", + "language_model.model.": "thinker.model.", }, "glm4v_moe": { "visual.": "model.visual.", @@ -89,15 +92,11 @@ QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { "language_model.lm_head.": "lm_head.", "language_model.model.": "model.language_model.", }, - "qwen3_5": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, - "qwen3_5_moe": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", + "kimi_k2": { + "language_model.layers.": "language_model.model.layers.", + # mm projector + "mm_projector.proj.0": "mm_projector.linear_1", + "mm_projector.proj.2": "mm_projector.linear_2", }, } @@ -440,6 +439,10 @@ class AscendModelSlimConfig(QuantizationConfig): new_k = k.replace("weight_packed", "weight") extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) + # Initialize attributes for type checking + self.model_type: str | None = None + self.hf_to_vllm_mapper: WeightsMapper | None = None + self.vllm_to_hf_mapper: WeightsMapper | None = None self._add_kvcache_quant_metadata() def __repr__(self) -> str: @@ -478,12 +481,73 @@ class AscendModelSlimConfig(QuantizationConfig): return ASCEND_QUANTIZATION_METHOD return None + # TODO: Modify the key values in self.quant_description instead of flipping the hf_to_vllm_mapper + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + """Apply the vLLM model-specific mapper to this quantization config. + + This method is called by vLLM to apply the model-specific weight mapper + to the quantization configuration. It creates a reverse mapper to convert + vLLM prefixes back to HF format for looking up keys in quant_config.json. + + Args: + hf_to_vllm_mapper: The WeightsMapper instance provided by vLLM + that contains model-specific prefix mappings (HF to vLLM). + """ + # Check if we already have a valid vllm_to_hf_mapper for this hf_to_vllm_mapper + if hasattr(self, "hf_to_vllm_mapper") and self.hf_to_vllm_mapper is hf_to_vllm_mapper: + # Same mapper instance, no need to recreate + return + + # Store the original mapper + self.hf_to_vllm_mapper = hf_to_vllm_mapper + + # Try different ways to get the mapping based on WeightsMapper implementation + mapping_attrs = ["orig_to_new_prefix"] + orig_to_new_prefix = {} + + for attr_name in mapping_attrs: + if hasattr(hf_to_vllm_mapper, attr_name): + orig_to_new_prefix = getattr(hf_to_vllm_mapper, attr_name) + break + + # Create reverse mapping (vLLM -> HF), skipping empty values + vllm_to_hf_mapping = {} + for orig_prefix, new_prefix in orig_to_new_prefix.items(): + # Skip empty values to avoid invalid keys in reverse mapping + if new_prefix: + vllm_to_hf_mapping[new_prefix] = orig_prefix + + # Create and store the reverse WeightsMapper instance + if vllm_to_hf_mapping: + self.vllm_to_hf_mapper = WeightsMapper(orig_to_new_prefix=vllm_to_hf_mapping) + logger.debug(f"Created reverse mapping from hf_to_vllm_mapper: {vllm_to_hf_mapping}") + else: + logger.info("No valid reverse mapping found for WeightsMapper.") + def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: - # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented + # Store model_type for reference + self.model_type = model_type + + # Check if manual mapping exists for this model type + # Manual mapping takes priority and is used exclusively to avoid conflicts + if model_type in QUANT_MODEL_PREFIX_MAPPINGS: + manual_mapping = QUANT_MODEL_PREFIX_MAPPINGS[model_type] + # Manual mapping is already in vLLM -> HF direction, use directly + mapper = WeightsMapper(orig_to_new_prefix=manual_mapping) + return mapper._map_name(prefix) + + # Use the reverse mapper (vLLM to HF) if available + if hasattr(self, "vllm_to_hf_mapper") and self.vllm_to_hf_mapper: + return self.vllm_to_hf_mapper._map_name(prefix) + + # Fall back to manual mapping for backward compatibility (simplified) + # This is only used if apply_vllm_mapper wasn't called or failed prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) if prefix_mapping: - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) - return hf_to_vllm_mapper._map_name(prefix) + # Manual mapping is already in vLLM -> HF direction, use directly + mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) + return mapper._map_name(prefix) + return prefix def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -512,9 +576,6 @@ class AscendModelSlimConfig(QuantizationConfig): self.packed_modules_mapping = packed_modules_model_mapping[model_type] prefix = self.quant_prefix_mapper(model_type, prefix) - if model_type != "kimi_k2": - if prefix.startswith("language_model"): - prefix = prefix.split(".", 1)[-1] if isinstance(layer, LinearBase): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import