diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index d31a696..49e4e07 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -157,6 +157,15 @@ class AscendQuantConfig(QuantizationConfig): f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " "to have the same precision.") + elif "experts" in prefix: + # For the experts' prefix (e.g., "model.layers.3.mlp.experts") + # Assume all experts within the same MLP use the same quantization method + experts_quant_description = [ + self.quant_description[layer] + for layer in self.quant_description if prefix in layer + ] + is_skipped = any(quantization == "FLOAT" + for quantization in experts_quant_description) else: is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 6d914c0..0fb156a 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -52,6 +52,17 @@ def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, f"Not all shards of {prefix} are quantized with same quant type." f"Shard {proj_name} uses {shard_quant_type}, but another shard" f"use {quant_type}. Please check quantization config.") + elif "experts" in prefix: + # For the experts' prefix (e.g., "model.layers.3.mlp.experts") + # Assume all experts within the same MLP use the same quantization method + experts_quant_description = set(quant_description[layer] + for layer in quant_description + if prefix in layer) + if not len(experts_quant_description) == 1: + raise RuntimeError( + f"{prefix} has different quantization type: {experts_quant_description}." + ) + quant_type = experts_quant_description.pop() else: quant_type = quant_description[prefix + '.weight'] return quant_type