[3/N][Refactor][Quantization] remove packed_modules_mapping from models (#3021)

### What this PR does / why we need it?

Some custom models in vllm-ascend define `packed_modules_mapping`, which
prevents their model classes from staying in sync with the vLLM community
implementations. This PR moves these custom `packed_modules_mapping`
definitions into the quantization utils.py. After this PR, some of the
custom model classes can be removed, as sketched below.
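
For illustration, a minimal before/after sketch of the shape of this change (the subclass name below is hypothetical; the mapping contents mirror the diff):

```python
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM

# Before: each custom vllm-ascend model carried its own mapping, forcing a
# fork of the community model class just to attach this attribute.
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):  # hypothetical name
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

# After: the mapping lives in a model_type-keyed dict in the quantization
# utils, and AscendQuantConfig.get_quant_method() assigns it at runtime, so
# the unmodified community model class can be used directly.
```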

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactor with no user-facing change.

### How was this patch tested?

Tested by CI.

- vLLM version: v0.10.2
- vLLM main: 5089fd749c

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Authored by 22dimensions on 2025-09-19 20:50:14 +08:00, committed via GitHub
Parent: 4ba56716f9 · Commit: 0942d9aaab
8 changed files with 76 additions and 80 deletions

@@ -19,6 +19,7 @@ from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
@@ -89,6 +90,11 @@ class AscendQuantConfig(QuantizationConfig):
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        vllm_config = get_current_vllm_config()
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type in packed_modules_model_mapping:
            self.packed_modules_mapping = packed_modules_model_mapping[
                model_type]
        from vllm.attention.layer import Attention
        if prefix.startswith("language_model"):
            prefix = prefix.split('.', 1)[-1]
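
As a rough illustration of what the mapping encodes, here is a hypothetical helper (not part of this diff) that expands a fused module name into the checkpoint shard names it packs:

```python
def shard_names_for(prefix: str, fused_name: str,
                    mapping: dict[str, list[str]]) -> list[str]:
    """Expand a fused name like 'qkv_proj' into the original per-projection
    checkpoint names (e.g. q_proj/k_proj/v_proj) under the given prefix."""
    return [
        f"{prefix}.{name}" for name in mapping.get(fused_name, [fused_name])
    ]

# With the qwen3_moe entry defined below:
# shard_names_for("model.layers.0.self_attn", "qkv_proj",
#                 packed_modules_model_mapping["qwen3_moe"])
# -> ['model.layers.0.self_attn.q_proj',
#     'model.layers.0.self_attn.k_proj',
#     'model.layers.0.self_attn.v_proj']
```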
@@ -153,6 +159,61 @@ class AscendQuantConfig(QuantizationConfig):
        return []

packed_modules_model_mapping = {
    "qwen3_moe": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    },
    "deepseek_v2": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "deepseek_v3": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    # NOTE 1. For quantized DeepSeek models, the MTP layer itself is not
    # quantized on the NPU;
    # NOTE 2. The description file generated by the current msmodelslim tool
    # does not contain MTP layer info. Please add it manually and set the
    # value to FLOAT.
    "deepseek_mtp": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "qwen3_next": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
    },
    "qwen2_5_vl": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
}
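
One convention worth noting (my reading of the table above, not stated in the diff): the `experts` entries list only expert 0's projections, which stand in for the FusedMoE module that packs the weights of all experts. A quick lookup sketch:

```python
# Hypothetical lookup, mirroring what get_quant_method() does above.
model_type = "deepseek_v2"  # in practice: hf_config.model_type
mapping = packed_modules_model_mapping.get(model_type, {})
print(mapping["experts"])
# ['experts.0.gate_proj', 'experts.0.up_proj', 'experts.0.down_proj']
```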
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.