[3/N][Refactor][Quantization]remove packed_modules_mapping from models (#3021)

### What this PR does / why we need it?

Some custom models in vllm-ascend define packed_modules_mapping, which
prevent keeping same model class with vllm community. So move these
custom packed_modules_mapping to quant utils.py. After this pr, some
custom models can be removed.

### Does this PR introduce _any_ user-facing change?

tested by CI

### How was this patch tested?

tested by CI

- vLLM version: v0.10.2
- vLLM main:
5089fd749c

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
22dimensions
2025-09-19 20:50:14 +08:00
committed by GitHub
parent 4ba56716f9
commit 0942d9aaab
8 changed files with 76 additions and 80 deletions

View File

@@ -15,41 +15,11 @@
import math
import unittest
import pytest
import torch
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
class TestCustomQwen3MoeForCausalLM:
def test_class_inheritance(self):
assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM)
@pytest.mark.parametrize("key, expected", [
("qkv_proj", ["q_proj", "k_proj", "v_proj"]),
("gate_up_proj", ["gate_proj", "up_proj"]),
("experts",
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]),
])
def test_packed_modules_mapping(self, key, expected):
assert CustomQwen3MoeForCausalLM.packed_modules_mapping[
key] == expected
def test_packed_modules_mapping_structure(self):
expected_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"experts": [
"experts.0.gate_proj", "experts.0.up_proj",
"experts.0.down_proj"
]
}
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
class DummyRMSNorm:
def __init__(self, dim: int, eps: float = 1e-6):

View File

@@ -73,9 +73,12 @@ class TestAscendQuantConfig(TestBase):
self.assertIsNone(result)
def test_get_quant_method_for_linear(self):
mock_config = MagicMock()
mock_config.model_config.hf_config.model_type = None
linear_layer = MagicMock(spec=LinearBase)
# Test skipped layer
with patch.object(self.ascend_config,
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch.object(self.ascend_config, \
'is_layer_skipped_ascend',
return_value=True):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
@@ -83,6 +86,7 @@ class TestAscendQuantConfig(TestBase):
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
@@ -93,14 +97,18 @@ class TestAscendQuantConfig(TestBase):
def test_get_quant_method_for_attention(self):
attention_layer = MagicMock(spec=Attention)
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
mock_config = MagicMock()
mock_config.model_config.hf_config.model_type = None
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with fa_quant_type
method = self.ascend_config.get_quant_method(
attention_layer, ".attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with kv_quant_type
modified_config = {"kv_quant_type": "C8"}
@@ -113,9 +121,12 @@ class TestAscendQuantConfig(TestBase):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
mock_config = MagicMock()
mock_config.model_config.hf_config.model_type = None
# Test skipped layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")
@@ -123,6 +134,7 @@ class TestAscendQuantConfig(TestBase):
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")

View File

@@ -180,14 +180,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
class CustomDeepSeekMTP(DeepSeekMTP):
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)

View File

@@ -320,12 +320,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
# add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)

View File

@@ -491,17 +491,6 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
Qwen2_5_VLForConditionalGeneration):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

View File

@@ -318,19 +318,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)

View File

@@ -1166,15 +1166,6 @@ class Qwen3NextModel(nn.Module):
class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
MixtureOfExperts, IsHybrid):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config

View File

@@ -19,6 +19,7 @@ from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
@@ -89,6 +90,11 @@ class AscendQuantConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
vllm_config = get_current_vllm_config()
model_type = vllm_config.model_config.hf_config.model_type
if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[
model_type]
from vllm.attention.layer import Attention
if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1]
@@ -153,6 +159,61 @@ class AscendQuantConfig(QuantizationConfig):
return []
packed_modules_model_mapping = {
"qwen3_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"deepseek_v2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"deepseek_v3": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
"deepseek_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"qwen3_next": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
},
"qwen2_5_vl": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
}
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.