### What this PR does / why we need it?

This patch adds support for the Qwen3-MoE model in Xlite. For more details about Xlite, please refer to the following link: https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md

Qwen3-MoE TODO List:

- [ ] Qwen3-235B-A22B support
- [ ] Qwen3-MoE weights NZ support
- [ ] Qwen3-MoE data parallel support

## Qwen3-30B-A3B-Instruct-2507 910B3(A2) Online Inference Performance Comparison

- aclgraph: main (69b170b8b5)
- xlite-full: main + xlite-full
- xlite-decode-only: main + xlite-decode-only
- diff1: performance of xlite-full relative to aclgraph
- diff2: performance of xlite-decode-only relative to aclgraph

| maxconcurrency | item | TTFT Avg (ms) | TTFT P99 (ms) | TPOT Avg (ms) | TPOT P99 (ms) | QPS (req/s) | OutputSpeed (token/s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | baseline-aclgraph | 205.07 | 287.29 | 12.34 | 12.65 | 0.14 | 78.81 |
| 1 | xlite-full | 66.40 | 113.69 | 11.71 | 12.40 | 0.15 | 84.73 |
| 1 | xlite-decode-only | 221.15 | 316.40 | 12.16 | 12.91 | 0.14 | 79.70 |
| 1 | diff1 | -67.62% | -60.43% | -5.11% | -1.98% | 7.14% | 7.51% |
| 1 | diff2 | 7.84% | 10.13% | -1.46% | 2.06% | 0.00% | 1.13% |
| 16 | baseline-aclgraph | 1892.16 | 13916.86 | 22.78 | 39.28 | 1.15 | 589.89 |
| 16 | xlite-full | 1355.40 | 8907.45 | 15.96 | 25.15 | 1.65 | 850.21 |
| 16 | xlite-decode-only | 1519.42 | 8711.64 | 19.23 | 29.73 | 1.38 | 711.60 |
| 16 | diff1 | -28.37% | -36.00% | -29.94% | -35.97% | 43.48% | 44.13% |
| 16 | diff2 | -19.70% | -37.40% | -15.58% | -24.31% | 20.00% | 20.63% |
| 32 | baseline-aclgraph | 673.80 | 3914.90 | 32.20 | 37.95 | 1.80 | 928.54 |
| 32 | xlite-full | 481.65 | 2710.50 | 19.95 | 25.35 | 2.91 | 1506.67 |
| 32 | xlite-decode-only | 372.22 | 1095.25 | 25.19 | 28.47 | 2.33 | 1202.82 |
| 32 | diff1 | -28.52% | -30.76% | -38.04% | -33.20% | 61.67% | 62.26% |
| 32 | diff2 | -44.76% | -72.02% | -21.77% | -24.98% | 29.44% | 29.54% |
| 48 | baseline-aclgraph | 583.18 | 3277.65 | 41.02 | 46.05 | 2.17 | 1115.08 |
| 48 | xlite-full | 973.42 | 8237.33 | 23.29 | 30.50 | 3.71 | 1908.09 |
| 48 | xlite-decode-only | 480.79 | 2026.98 | 31.48 | 35.41 | 2.83 | 1453.75 |
| 48 | diff1 | 66.92% | 151.32% | -43.22% | -33.77% | 70.97% | 71.12% |
| 48 | diff2 | -17.56% | -38.16% | -23.26% | -23.11% | 30.41% | 30.37% |
| 64 | baseline-aclgraph | 742.74 | 5953.39 | 47.79 | 53.15 | 2.48 | 1272.37 |
| 64 | xlite-full | 545.22 | 3941.34 | 25.09 | 30.41 | 4.64 | 2376.44 |
| 64 | xlite-decode-only | 752.40 | 4534.29 | 38.67 | 43.28 | 3.06 | 1567.94 |
| 64 | diff1 | -26.59% | -33.80% | -47.50% | -42.78% | 87.10% | 86.77% |
| 64 | diff2 | 1.30% | -23.84% | -19.08% | -18.57% | 23.39% | 23.23% |
| 100 | baseline-aclgraph | 565.52 | 1716.81 | 60.89 | 68.69 | 3.08 | 1580.64 |
| 100 | xlite-full | 398.14 | 2328.88 | 30.70 | 32.45 | 6.01 | 3086.42 |
| 100 | xlite-decode-only | 712.53 | 4875.94 | 52.71 | 60.78 | 3.53 | 1813.58 |
| 100 | diff1 | -29.60% | 35.65% | -49.58% | -52.76% | 95.13% | 95.26% |
| 100 | diff2 | 26.00% | 184.01% | -13.43% | -11.52% | 14.61% | 14.74% |
| 150 | baseline-aclgraph | 842.42 | 5175.01 | 73.60 | 88.18 | 3.80 | 1952.26 |
| 150 | xlite-full | 568.52 | 4204.33 | 37.90 | 40.01 | 7.27 | 3734.72 |
| 150 | xlite-decode-only | 654.43 | 2504.06 | 67.40 | 77.00 | 4.18 | 2145.11 |
| 150 | diff1 | -32.51% | -18.76% | -48.51% | -54.63% | 91.32% | 91.30% |
| 150 | diff2 | -22.32% | -51.61% | -8.42% | -12.68% | 10.00% | 9.88% |
| 200 | baseline-aclgraph | 750.63 | 3049.91 | 88.26 | 101.95 | 4.28 | 2189.72 |
| 200 | xlite-full | 558.48 | 3791.98 | 45.54 | 49.04 | 8.17 | 4175.52 |
| 200 | xlite-decode-only | 807.09 | 4254.95 | 85.18 | 101.79 | 4.44 | 2271.52 |
| 200 | diff1 | -25.60% | 24.33% | -48.40% | -51.90% | 90.89% | 90.69% |
| 200 | diff2 | 7.52% | 39.51% | -3.49% | -0.16% | 3.74% | 3.74% |

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: changdawei1 <changdawei3@huawei.com>
Co-authored-by: LVYANGGUO <275926687@qq.com>
Co-authored-by: lulina <lina.lulina@huawei.com>
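For reference, the `xlite-full` and `xlite-decode-only` variants benchmarked above are toggled through vllm-ascend's `additional_config`. A minimal offline-inference sketch follows; apart from `full_mode`, which this patch reads via `get_ascend_config().xlite_graph_config`, the `xlite_graph_config` keys shown are assumptions, so consult the vllm-ascend docs for the authoritative schema:

```python
# Minimal sketch, not an authoritative launch recipe: every xlite_graph_config
# key other than "full_mode" is an assumption.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B-Instruct-2507",
    tensor_parallel_size=4,  # adjust to the 910B3(A2) setup under test
    additional_config={
        "xlite_graph_config": {
            "enabled": True,  # hypothetical key
            # True -> "xlite-full", False -> "xlite-decode-only"
            "full_mode": True,
        },
    },
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```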
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_ep_group,
                              get_tensor_model_parallel_world_size,
                              get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime,
                      ScoringFuncSoftmax)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                AscendMetadata)

class XliteModel:

    def initialize(
            self, runnable: nn.Module, vllm_config: VllmConfig
    ) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
        raise NotImplementedError(
            "Xlite Model initialize function not implemented.")

class LlamaXliteModel(XliteModel):

    def initialize(
            self, runnable: nn.Module, vllm_config: VllmConfig
    ) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)

        freq_cis = self._precompute_freqs_cis(config.head_dim,
                                              config.max_seq_len, dtype,
                                              config.rope_theta)

        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        hf_config = vllm_config.model_config.hf_text_config
        if hasattr(hf_config, "text_config"):
            hf_config = hf_config.text_config
        config = ModelConfig()
        config.vocab_size = hf_config.vocab_size
        config.hidden_size = hf_config.hidden_size
        config.n_layers = hf_config.num_hidden_layers
        config.n_heads = hf_config.num_attention_heads
        config.n_kv_heads = hf_config.num_key_value_heads
        if hasattr(hf_config, "head_dim"):
            config.head_dim = hf_config.head_dim
        else:
            config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        config.rope_head_dim = config.head_dim
        config.norm_eps = hf_config.rms_norm_eps
        config.rope_theta = hf_config.rope_theta
        config.softmax_scale = config.head_dim**-0.5
        config.n_dense_layers = hf_config.num_hidden_layers
        config.intermediate_size = hf_config.intermediate_size
        config.def_tp_size = get_tensor_model_parallel_world_size()
        config.def_dp_size = 1
        config.moe_ep_size = 1
        config.moe_tp_size = 1

        config.attn_type = AttnMHA
        config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
        scheduler_config = vllm_config.scheduler_config
        max_batch_size = scheduler_config.max_num_seqs
        max_seq_len = vllm_config.model_config.max_model_len
        config.max_m = scheduler_config.max_num_batched_tokens
        config.max_batch_size = max_batch_size
        config.max_seq_len = max_seq_len
        config.block_size = vllm_config.cache_config.block_size
        return config

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
                     config: ModelConfig) -> Model:
        params_dict = dict(runnable.named_parameters())

        if hasattr(runnable, "language_model"):
            layers = runnable.language_model.model.layers
            model_prefix = "language_model."
        else:
            layers = runnable.model.layers
            model_prefix = ""

        xlite_model = Model()
        xlite_model.embed = params_dict.get(model_prefix +
                                            "model.embed_tokens.weight")
        xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
        if vllm_config.model_config.hf_text_config.tie_word_embeddings:
            xlite_model.head = xlite_model.embed
        else:
            xlite_model.head = params_dict.get(model_prefix +
                                               "lm_head.weight")
        xlite_model.attn_norm = [
            layer.input_layernorm.weight for layer in layers
        ]
        xlite_model.attn_out = [
            layer.self_attn.o_proj.weight for layer in layers
        ]
        xlite_model.mha_qkv = [
            layer.self_attn.qkv_proj.weight for layer in layers
        ]
        xlite_model.mlp_norm = [
            layer.post_attention_layernorm.weight for layer in layers
        ]
        xlite_model.mlp_up_gate = [
            layer.mlp.gate_up_proj.weight for layer in layers
            if hasattr(layer.mlp, "gate_up_proj")
            and layer.mlp.gate_up_proj.weight is not None
        ]
        xlite_model.mlp_down = [
            layer.mlp.down_proj.weight for layer in layers
            if hasattr(layer.mlp, "down_proj")
            and layer.mlp.down_proj.weight is not None
        ]
        mha_qkv_bias = [
            layer.self_attn.qkv_proj.bias for layer in layers
            if hasattr(layer.self_attn.qkv_proj, "bias")
            and layer.self_attn.qkv_proj.bias is not None
        ]
        q_norm = [
            layer.self_attn.q_norm.weight for layer in layers
            if hasattr(layer.self_attn, "q_norm")
        ]
        k_norm = [
            layer.self_attn.k_norm.weight for layer in layers
            if hasattr(layer.self_attn, "k_norm")
        ]

        # Only enable qkv bias / qk norm when every layer provides them.
        if len(mha_qkv_bias) != config.n_layers:
            config.qkv_bias = False
        else:
            config.qkv_bias = True
            xlite_model.mha_qkv_bias = mha_qkv_bias

        if len(q_norm) != config.n_layers or len(k_norm) != config.n_layers:
            config.qk_norm = False
        else:
            config.qk_norm = True
            xlite_model.mha_q_norm = q_norm
            xlite_model.mha_k_norm = k_norm

        return xlite_model

    def _precompute_freqs_cis(self,
                              dim: int,
                              end: int,
                              dtype: torch.dtype,
                              theta: float = 10000.0):
        freqs = 1.0 / (theta**(torch.arange(
            0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        cos_cache = freqs.cos().to(dtype)
        sin_cache = freqs.sin().to(dtype)
        freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
        return freq_cis.to(device='npu')

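# Illustrative note on _precompute_freqs_cis above (added commentary, not part
# of the upstream logic): with head_dim=4, the cache row for position t is
#     [cos(t * f0), cos(t * f1), sin(t * f0), sin(t * f1)],
# with f_i = theta**(-2 * i / dim), i.e. the cos and sin halves are
# concatenated along the last dim rather than interleaved, which is presumably
# the non-interleaved rope layout the xlite kernels consume.
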
class QwenMoeXliteModel(LlamaXliteModel):

    def initialize(
            self, runnable: nn.Module, vllm_config: VllmConfig
    ) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
        if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
            architecture = vllm_config.model_config.architectures[0]
            raise ValueError(
                f"{architecture} does not support VLLM_ASCEND_ENABLE_NZ = 2!")
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)

        freq_cis = super()._precompute_freqs_cis(config.head_dim,
                                                 config.max_seq_len, dtype,
                                                 config.rope_theta)

        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        config = super()._build_model_config(vllm_config)
        hf_config = vllm_config.model_config.hf_text_config
        ep_group = get_ep_group()
        config.n_layers = hf_config.max_window_layers
        config.n_dense_layers = 0
        config.n_routed_experts = hf_config.num_experts
        config.n_shared_experts = 0
        config.n_act_experts = hf_config.num_experts_per_tok
        config.def_dp_size = vllm_config.parallel_config.data_parallel_size
        config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
        config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
        config.experts_weight_transpose = True
        config.moe_intermediate_size = hf_config.moe_intermediate_size
        config.norm_topk_prob = hf_config.norm_topk_prob
        config.scoring_func = ScoringFuncSoftmax
        return config

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
                     config: ModelConfig) -> Model:
        xlite_model = super()._build_model(runnable, vllm_config, config)
        layers = runnable.model.layers
        xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
        xlite_model.re_up_gate = [
            layer.mlp.experts.w13_weight[i] for layer in layers
            for i in range(layer.mlp.experts.local_num_experts)
        ]
        xlite_model.re_down = [
            layer.mlp.experts.w2_weight[i] for layer in layers
            for i in range(layer.mlp.experts.local_num_experts)
        ]

        return xlite_model

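# Worked example for the MoE parallel sizing above (added commentary): with 8
# ranks and --enable-expert-parallel, _build_model_config yields moe_ep_size=8
# and moe_tp_size=1 (experts sharded across ranks); without it, moe_ep_size=1
# and moe_tp_size=8 (each expert weight split tensor-parallel across ranks).
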
def xlite_model_init(
        runnable: nn.Module, vllm_config: VllmConfig
) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
    strategy_map = {
        "LlamaForCausalLM": LlamaXliteModel,
        "Qwen2ForCausalLM": LlamaXliteModel,
        "Qwen3ForCausalLM": LlamaXliteModel,
        "Qwen3VLForConditionalGeneration": LlamaXliteModel,
        "Qwen3MoeForCausalLM": QwenMoeXliteModel,
    }

    architecture = vllm_config.model_config.architectures[0]
    strategy_class = strategy_map.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize(runnable, vllm_config)

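# Dispatch example (added commentary): a checkpoint whose config.json declares
# architectures=["Qwen3MoeForCausalLM"] is routed to QwenMoeXliteModel; any
# architecture outside strategy_map fails fast with the ValueError above.
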
class XliteWrapper:
    """
    xlite graph wrapper
    """

    def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
        self.runnable = runnable
        self.full_mode = get_ascend_config().xlite_graph_config.full_mode

        rank = torch.distributed.get_rank()
        local_rank = get_world_group().local_rank
        self.xlite_rt = Runtime(local_rank, 0, rank,
                                get_tensor_model_parallel_world_size(),
                                vllm_config.parallel_config.data_parallel_size)

        (self.xlite_model, self.freq_cis, hidden_size,
         dtype) = xlite_model_init(runnable, vllm_config)

        rt_pool_size = self.xlite_model.get_tensor_pool_size()
        if rank == 0:
            logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
        if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
            raise ValueError(
                f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
            )

        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_states = torch.empty(max_num_tokens,
                                         hidden_size,
                                         device=f"npu:{local_rank}",
                                         dtype=dtype)

    def __getattr__(self, key: str):
        # Allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(
            f"Attribute {key} does not exist in the runnable of "
            f"xlite wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # In case we need to access the original runnable.
        return self.runnable

    def register_kv_caches(self, kv_caches: Any):
        self.kv_caches = kv_caches

    def __call__(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
                                                    list[torch.Tensor]]:
        forward_context = get_forward_context()
        attn_metadata: Any = forward_context.attn_metadata
        if attn_metadata is None:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        attn_metadata = next(iter(attn_metadata.values()), None)
        if attn_metadata is None or not isinstance(attn_metadata,
                                                   AscendMetadata):
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]

        if not with_prefill or self.full_mode:
            # TODO: When vllm_ascend enables graph mode,
            # attn_metadata.num_decodes is padded for decode requests, so we
            # derive the decode count from num_decode_tokens for now. Once MTP
            # is enabled, a single request may cover multiple tokens, and this
            # will need to be revisited.
            num_decodes = attn_metadata.num_decode_tokens
            num_prefills = attn_metadata.num_prefills
            batch = num_prefills + num_decodes
            seq_lens = attn_metadata.seq_lens[:batch]
            # actual_seq_lengths_q holds cumulative query lengths, e.g.
            # [3, 5, 9] -> per-request query_lens [3, 2, 4].
            seq_tensor = torch.cat([
                torch.tensor([0]),
                torch.tensor(attn_metadata.actual_seq_lengths_q)
            ], dim=0)
            query_lens = seq_tensor[1:] - seq_tensor[:-1]
            query_lens = query_lens[:batch]
            cached_lens = seq_lens - query_lens

            xlite_attn_metadata = ModelAttnMeta()
            xlite_attn_metadata.lens = query_lens.tolist()
            xlite_attn_metadata.cached_lens = cached_lens.tolist()
            xlite_attn_metadata.is_prefills = ([False] * num_decodes +
                                               [True] * num_prefills)
            xlite_attn_metadata.block_tables = (
                attn_metadata.block_tables.cpu().tolist())

            h = self.hidden_states[:attn_metadata.num_actual_tokens]
            stream = torch.npu.current_stream().npu_stream
            if inputs_embeds is None:
                self.xlite_model.forward(self.xlite_rt, input_ids,
                                         xlite_attn_metadata, self.kv_caches,
                                         self.freq_cis, h, stream)
            else:
                self.xlite_model.forward_with_inputs_embeds(
                    self.xlite_rt, inputs_embeds, xlite_attn_metadata,
                    self.kv_caches, self.freq_cis, h, stream)
            return h
        else:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)
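
# Illustrative usage sketch (added commentary; the surrounding runner hooks
# are assumptions, not part of this file):
#
#     model = XliteWrapper(model, vllm_config)   # wrap the loaded nn.Module
#     model.register_kv_caches(kv_caches)        # once kv caches are allocated
#     hidden_states = model(input_ids, positions)
#
# Decode-only and spec-decode batches (and every batch when full_mode is set)
# run through the xlite graph; anything else falls back to self.runnable.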