[1/N][Feat] Xlite Qwen3 MoE Support (#5951)

### What this PR does / why we need it?
This patch adds support for the Qwen3-MoE model in Xlite. For more
details about Xlite, please refer to the following
link:https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md.

Qwen3-MoE TODO List:
- [ ] Qwen3-235B-A22B support
- [ ] Qwen3-MoE weights NZ support
- [ ] Qwen3-MoE data parallel support

## Qwen3-30B-A3B-Instruct-2507 910B3(A2) Online Inference Performance
Comparison
- aclgraph: main(69b170b8b5)
- xlite-full: main + xlite-full
- xlite-decode-only: main + xlite-decode-only
- diff1: Performance comparison between xlite-full and aclgraph
- diff2: Performance comparison between xlite-decode-only and aclgraph

| maxconcurrency | item | TTFT(ms) | | TPOT(ms) | | QPS (req/s) |
OutputSpeed (token/s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
|  |  | Avg | P99 | Avg | P99 |  |  |
| 1 | baseline-aclgraph | 205.07 | 287.29 | 12.34 | 12.65 | 0.14 | 78.81
|
| 1 | xlite-full | 66.40 | 113.69 | 11.71 | 12.40 | 0.15 | 84.73 |
| 1 | xlite-decode-only | 221.15 | 316.40 | 12.16 | 12.91 | 0.14 | 79.70
|
| 1 | diff1 | -67.62% | -60.43% | -5.11% | -1.98% | 7.14% | 7.51% |
| 1 | diff2 | 7.84% | 10.13% | -1.46% | 2.06% | 0.00% | 1.13% |
|  |  |  |  |  |  |  |  |
| 16 | baseline-aclgraph | 1892.16 | 13916.86 | 22.78 | 39.28 | 1.15 |
589.89 |
| 16 | xlite-full | 1355.40 | 8907.45 | 15.96 | 25.15 | 1.65 | 850.21 |
| 16 | xlite-decode-only | 1519.42 | 8711.64 | 19.23 | 29.73 | 1.38 |
711.60 |
| 16 | diff1 | -28.37% | -36.00% | -29.94% | -35.97% | 43.48% | 44.13% |
| 16 | diff2 | -19.70% | -37.40% | -15.58% | -24.31% | 20.00% | 20.63% |
|  |  |  |  |  |  |  |  |
| 32 | baseline-aclgraph | 673.80 | 3914.90 | 32.20 | 37.95 | 1.80 |
928.54 |
| 32 | xlite-full | 481.65 | 2710.50 | 19.95 | 25.35 | 2.91 | 1506.67 |
| 32 | xlite-decode-only | 372.22 | 1095.25 | 25.19 | 28.47 | 2.33 |
1202.82 |
| 32 | diff1 | -28.52% | -30.76% | -38.04% | -33.20% | 61.67% | 62.26% |
| 32 | diff2 | -44.76% | -72.02% | -21.77% | -24.98% | 29.44% | 29.54% |
|  |  |  |  |  |  |  |  |
| 48 | baseline-aclgraph | 583.18 | 3277.65 | 41.02 | 46.05 | 2.17 |
1115.08 |
| 48 | xlite-full | 973.42 | 8237.33 | 23.29 | 30.50 | 3.71 | 1908.09 |
| 48 | xlite-decode-only | 480.79 | 2026.98 | 31.48 | 35.41 | 2.83 |
1453.75 |
| 48 | diff1 | 66.92% | 151.32% | -43.22% | -33.77% | 70.97% | 71.12% |
| 48 | diff2 | -17.56% | -38.16% | -23.26% | -23.11% | 30.41% | 30.37% |
|  |  |  |  |  |  |  |  |
| 64 | baseline-aclgraph | 742.74 | 5953.39 | 47.79 | 53.15 | 2.48 |
1272.37 |
| 64 | xlite-full | 545.22 | 3941.34 | 25.09 | 30.41 | 4.64 | 2376.44 |
| 64 | xlite-decode-only | 752.40 | 4534.29 | 38.67 | 43.28 | 3.06 |
1567.94 |
| 64 | diff1 | -26.59% | -33.80% | -47.50% | -42.78% | 87.10% | 86.77% |
| 64 | diff2 | 1.30% | -23.84% | -19.08% | -18.57% | 23.39% | 23.23% |
|  |  |  |  |  |  |  |  |
| 100 | baseline-aclgraph | 565.52 | 1716.81 | 60.89 | 68.69 | 3.08 |
1580.64 |
| 100 | xlite-full | 398.14 | 2328.88 | 30.70 | 32.45 | 6.01 | 3086.42 |
| 100 | xlite-decode-only | 712.53 | 4875.94 | 52.71 | 60.78 | 3.53 |
1813.58 |
| 100 | diff1 | -29.60% | 35.65% | -49.58% | -52.76% | 95.13% | 95.26% |
| 100 | diff2 | 26.00% | 184.01% | -13.43% | -11.52% | 14.61% | 14.74% |
|  |  |  |  |  |  |  |  |
| 150 | baseline-aclgraph | 842.42 | 5175.01 | 73.60 | 88.18 | 3.80 |
1952.26 |
| 150 | xlite-full | 568.52 | 4204.33 | 37.90 | 40.01 | 7.27 | 3734.72 |
| 150 | xlite-decode-only | 654.43 | 2504.06 | 67.40 | 77.00 | 4.18 |
2145.11 |
| 150 | diff1 | -32.51% | -18.76% | -48.51% | -54.63% | 91.32% | 91.30%
|
| 150 | diff2 | -22.32% | -51.61% | -8.42% | -12.68% | 10.00% | 9.88% |
|  |  |  |  |  |  |  |  |
| 200 | baseline-aclgraph | 750.63 | 3049.91 | 88.26 | 101.95 | 4.28 |
2189.72 |
| 200 | xlite-full | 558.48 | 3791.98 | 45.54 | 49.04 | 8.17 | 4175.52 |
| 200 | xlite-decode-only | 807.09 | 4254.95 | 85.18 | 101.79 | 4.44 |
2271.52 |
| 200 | diff1 | -25.60% | 24.33% | -48.40% | -51.90% | 90.89% | 90.69% |
| 200 | diff2 | 7.52% | 39.51% | -3.49% | -0.16% | 3.74% | 3.74% |
|  |  |  |  |  |  |  |  |

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: changdawei1 <changdawei3@huawei.com>
Co-authored-by: LVYANGGUO <275926687@qq.com>
Co-authored-by: lulina <lina.lulina@huawei.com>
This commit is contained in:
Magnus
2026-01-21 09:26:03 +08:00
committed by GitHub
parent 1ab6cd4935
commit 5b129cf0a1
5 changed files with 147 additions and 70 deletions

View File

@@ -19,12 +19,14 @@ from typing import Any, Callable, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
from vllm.distributed import (get_ep_group,
get_tensor_model_parallel_world_size,
get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime
from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime,
ScoringFuncSoftmax)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
@@ -47,67 +49,8 @@ class LlamaXliteModel(XliteModel):
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
dtype = vllm_config.model_config.dtype
params_dict = dict(runnable.named_parameters())
if hasattr(runnable, "language_model"):
layers = runnable.language_model.model.layers
model_prefix = "language_model."
else:
layers = runnable.model.layers
model_prefix = ""
config = self._build_model_config(vllm_config)
xlite_model = Model()
xlite_model.embed = params_dict.get(model_prefix +
"model.embed_tokens.weight")
xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
if vllm_config.model_config.hf_text_config.tie_word_embeddings:
xlite_model.head = xlite_model.embed
else:
xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
xlite_model.attn_norm = [
layer.input_layernorm.weight for layer in layers
]
xlite_model.attn_out = [
layer.self_attn.o_proj.weight for layer in layers
]
xlite_model.mha_qkv = [
layer.self_attn.qkv_proj.weight for layer in layers
]
xlite_model.mlp_norm = [
layer.post_attention_layernorm.weight for layer in layers
]
xlite_model.mlp_up_gate = [
layer.mlp.gate_up_proj.weight for layer in layers
]
xlite_model.mlp_down = [layer.mlp.down_proj.weight for layer in layers]
mha_qkv_bias = [
layer.self_attn.qkv_proj.bias for layer in layers
if hasattr(layer.self_attn.qkv_proj, "bias")
and layer.self_attn.qkv_proj.bias is not None
]
q_norm = [
layer.self_attn.q_norm.weight for layer in layers
if hasattr(layer.self_attn, "q_norm")
]
k_norm = [
layer.self_attn.k_norm.weight for layer in layers
if hasattr(layer.self_attn, "k_norm")
]
if len(mha_qkv_bias) != config.n_layers:
config.qkv_bias = False
else:
config.qkv_bias = True
xlite_model.mha_qkv_bias = mha_qkv_bias
if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
config.qk_norm = False
else:
config.qk_norm = True
xlite_model.mha_q_norm = q_norm
xlite_model.mha_k_norm = k_norm
xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank()
xlite_model.init(config, rank)
@@ -153,6 +96,76 @@ class LlamaXliteModel(XliteModel):
config.block_size = vllm_config.cache_config.block_size
return config
def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
config: ModelConfig) -> Model:
params_dict = dict(runnable.named_parameters())
if hasattr(runnable, "language_model"):
layers = runnable.language_model.model.layers
model_prefix = "language_model."
else:
layers = runnable.model.layers
model_prefix = ""
xlite_model = Model()
xlite_model.embed = params_dict.get(model_prefix +
"model.embed_tokens.weight")
xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
if vllm_config.model_config.hf_text_config.tie_word_embeddings:
xlite_model.head = xlite_model.embed
else:
xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
xlite_model.attn_norm = [
layer.input_layernorm.weight for layer in layers
]
xlite_model.attn_out = [
layer.self_attn.o_proj.weight for layer in layers
]
xlite_model.mha_qkv = [
layer.self_attn.qkv_proj.weight for layer in layers
]
xlite_model.mlp_norm = [
layer.post_attention_layernorm.weight for layer in layers
]
xlite_model.mlp_up_gate = [
layer.mlp.gate_up_proj.weight for layer in layers
if hasattr(layer.mlp, "gate_up_proj")
and layer.mlp.gate_up_proj.weight is not None
]
xlite_model.mlp_down = [
layer.mlp.down_proj.weight for layer in layers
if hasattr(layer.mlp, "down_proj")
and layer.mlp.down_proj.weight is not None
]
mha_qkv_bias = [
layer.self_attn.qkv_proj.bias for layer in layers
if hasattr(layer.self_attn.qkv_proj, "bias")
and layer.self_attn.qkv_proj.bias is not None
]
q_norm = [
layer.self_attn.q_norm.weight for layer in layers
if hasattr(layer.self_attn, "q_norm")
]
k_norm = [
layer.self_attn.k_norm.weight for layer in layers
if hasattr(layer.self_attn, "k_norm")
]
if len(mha_qkv_bias) != config.n_layers:
config.qkv_bias = False
else:
config.qkv_bias = True
xlite_model.mha_qkv_bias = mha_qkv_bias
if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
config.qk_norm = False
else:
config.qk_norm = True
xlite_model.mha_q_norm = q_norm
xlite_model.mha_k_norm = k_norm
return xlite_model
def _precompute_freqs_cis(self,
dim: int,
end: int,
@@ -168,6 +181,62 @@ class LlamaXliteModel(XliteModel):
return freq_cis.to(device='npu')
class QwenMoeXliteModel(LlamaXliteModel):
def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
architecture = vllm_config.model_config.architectures[0]
raise ValueError(
f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
dtype = vllm_config.model_config.dtype
config = self._build_model_config(vllm_config)
xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank()
xlite_model.init(config, rank)
freq_cis = super()._precompute_freqs_cis(config.head_dim,
config.max_seq_len, dtype,
config.rope_theta)
return (xlite_model, freq_cis, config.hidden_size, dtype)
def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
config = super()._build_model_config(vllm_config)
hf_config = vllm_config.model_config.hf_text_config
ep_group = get_ep_group()
config.n_layers = hf_config.max_window_layers
config.n_dense_layers = 0
config.n_routed_experts = hf_config.num_experts
config.n_shared_experts = 0
config.n_act_experts = hf_config.num_experts_per_tok
config.def_dp_size = vllm_config.parallel_config.data_parallel_size
config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
config.experts_weight_transpose = True
config.moe_intermediate_size = hf_config.moe_intermediate_size
config.norm_topk_prob = hf_config.norm_topk_prob
config.scoring_func = ScoringFuncSoftmax
return config
def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
config: ModelConfig) -> Model:
xlite_model = super()._build_model(runnable, vllm_config, config)
layers = runnable.model.layers
xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
xlite_model.re_up_gate = [
layer.mlp.experts.w13_weight[i] for layer in layers
for i in range(layer.mlp.experts.local_num_experts)
]
xlite_model.re_down = [
layer.mlp.experts.w2_weight[i] for layer in layers
for i in range(layer.mlp.experts.local_num_experts)
]
return xlite_model
def xlite_model_init(
runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
@@ -176,6 +245,7 @@ def xlite_model_init(
"Qwen2ForCausalLM": LlamaXliteModel,
"Qwen3ForCausalLM": LlamaXliteModel,
"Qwen3VLForConditionalGeneration": LlamaXliteModel,
"Qwen3MoeForCausalLM": QwenMoeXliteModel,
}
architecture = vllm_config.model_config.architectures[0]
@@ -197,7 +267,8 @@ class XliteWrapper:
rank = torch.distributed.get_rank()
local_rank = get_world_group().local_rank
self.xlite_rt = Runtime(local_rank, 0, rank,
get_tensor_model_parallel_world_size())
get_tensor_model_parallel_world_size(),
vllm_config.parallel_config.data_parallel_size)
(self.xlite_model, self.freq_cis, hidden_size,
dtype) = xlite_model_init(runnable, vllm_config)