[model] Support PanguUltraMoE (#4615)

### What this PR does / why we need it?
Add support for the PanguUltraMoE model, including the `pangu_ultra_moe_mtp` draft model used for speculative decoding.

### Test results
#### Start serving with a W8A8-quantized model and ACL graph:
Master node:
```
vllm serve $LOCAL_CKPT_DIR \
        --host 0.0.0.0 \
        --port 8000 \
        --data-parallel-size 2 \
        --data-parallel-size-local 1 \
        --data-parallel-address $MASTER_NODE_IP \
        --data-parallel-rpc-port 13389 \
        --tensor-parallel-size 16 \
        --seed 1024 \
        --enable-expert-parallel \
        --served-model-name $NAME \
        --max-model-len 4096 \
        --max-num-batched-tokens 256 \
        --max-num-seqs 18 \
        --trust-remote-code \
        --gpu-memory-utilization 0.90 \
        --quantization ascend \
        --additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true},"torchair_graph_config":{"enabled":false}}' \
        --speculative_config '{"method": "pangu_ultra_moe_mtp", "num_speculative_tokens": 1}' \
```
Other nodes:
```
vllm serve $LOCAL_CKPT_DIR \
        --host 0.0.0.0 \
        --port 8000 \
        --headless \
        --data-parallel-size 2 \
        --data-parallel-size-local 1 \
        --data-parallel-start-rank 1 \
        --data-parallel-address $MASTER_NODE_IP \
        --data-parallel-rpc-port 13389 \
        --tensor-parallel-size 16 \
        --seed 1024 \
        --enable-expert-parallel \
        --served-model-name $NAME \
        --max-model-len 4096 \
        --max-num-batched-tokens 256 \
        --max-num-seqs 18 \
        --trust-remote-code \
        --gpu-memory-utilization 0.90 \
        --quantization ascend \
        --additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true},"torchair_graph_config":{"enabled":false}}' \
        --speculative_config '{"method": "pangu_ultra_moe_mtp", "num_speculative_tokens": 1}' \
```
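Before the full two-node data-parallel run, a single-node offline smoke test can exercise the same engine options through vLLM's `LLM` API. This is a minimal sketch, not part of this PR: the checkpoint path stands in for `$LOCAL_CKPT_DIR`, and the data-parallel flags are dropped since everything runs on one node.
```
# Single-node offline smoke test (sketch; assumes vllm-ascend is installed
# so that quantization="ascend" is registered, and 16 NPUs are available
# for tensor_parallel_size=16).
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/local_ckpt",  # stand-in for $LOCAL_CKPT_DIR
    tensor_parallel_size=16,
    enable_expert_parallel=True,
    max_model_len=4096,
    max_num_batched_tokens=256,
    max_num_seqs=18,
    trust_remote_code=True,
    gpu_memory_utilization=0.90,
    quantization="ascend",
    additional_config={"torchair_graph_config": {"enabled": False}},
    speculative_config={"method": "pangu_ultra_moe_mtp",
                        "num_speculative_tokens": 1},
)

# Raw generate (no chat template) is enough for a quick sanity check.
params = SamplingParams(max_tokens=64, temperature=0.6, top_p=0.95, top_k=50)
out = llm.generate(["你是谁?"], params)  # "Who are you?"
print(out[0].outputs[0].text)
```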
Request & Response:

- Request
```
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [
      {"role": "system", "content": ""},
      {"role": "user", "content": "你是谁?"}
    ],
        "max_tokens": "64",
        "top_p": "0.95",
        "top_k": "50",
        "temperature": "0.6",
        "add_special_tokens" : true
    }'
```
- Response
```
[unused16] 好的,用户问我是谁,我需要按照之前的设定来回答。首先,我的角色是盘古,由华为开发,属于推理模型。要强调我的主要功能是解答问题和提供信息支持,特别是通过逻辑推理和数据分析处理复杂任务。需要保持回答简洁,用中文,并且符合用户的
```
(English translation: "OK, the user asks who I am, and I should answer according to the earlier setup. First, my role is Pangu, developed by Huawei, a reasoning model. I should stress that my main function is answering questions and providing information support, especially handling complex tasks through logical reasoning and data analysis. The answer should stay concise, in Chinese, and match the user's ..." The output is cut off by the 64-token limit; the leading `[unused16]` is emitted verbatim by the model.)
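For scripted checks, the same request can go through the OpenAI Python client. A sketch, assuming the server started above: `top_k` and `add_special_tokens` are vLLM sampling extensions and travel in `extra_body`, and the model name is whatever `--served-model-name` (`$NAME`) was set to.
```
# Sketch of the same chat request via the OpenAI-compatible API.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="pangu-ultra-moe",  # placeholder for the --served-model-name value
    messages=[
        {"role": "system", "content": ""},
        {"role": "user", "content": "你是谁?"},  # "Who are you?"
    ],
    max_tokens=64,
    temperature=0.6,
    top_p=0.95,
    # vLLM-specific sampling extensions go through extra_body rather than
    # named client arguments.
    extra_body={"top_k": 50, "add_special_tokens": True},
)
print(resp.choices[0].message.content)
```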


- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

Signed-off-by: lijifu <lijifu4@huawei.com>
Co-authored-by: lijifu <lijifu4@huawei.com>
Commit 724d04391e (parent f0060fc822), committed by GitHub
Author: JeffLee1874
Date: 2025-12-17 16:15:29 +08:00
2 changed files with 14 additions and 0 deletions

Add `pangu_ultra_moe` and `pangu_ultra_moe_mtp` entries to `packed_modules_model_mapping`:
```
@@ -221,6 +221,12 @@ packed_modules_model_mapping = {
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
         "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
     },
+    "pangu_ultra_moe": {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
+    },
     "kimi_k2": {
         "gate_up_proj": ["gate_proj", "up_proj"],
         "experts":
@@ -241,6 +247,12 @@ packed_modules_model_mapping = {
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
     },
+    "pangu_ultra_moe_mtp": {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
+    },
     "qwen3_next": {
         "qkv_proj": [
             "q_proj",
```

Register the PanguUltraMoE MTP draft model in `_MTP_MODELS`:
```
@@ -44,6 +44,8 @@ PADDING_SLOT_ID = -1
 _MTP_MODELS = {
     "DeepseekV3ForCausalLM":
     ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+    "PanguUltraMoEForCausalLM":
+    ("vllm.model_executor.models.openpangu_mtp", "OpenPanguMTP"),
     "DeepseekV32ForCausalLM":
     ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
     "Qwen3NextForCausalLM":
```