Files
xc-llm-ascend/vllm_ascend/xlite/xlite.py
王远 82fdd40d49 [Feat]Xlite Qwen3 MoE Support Data Parallel (#6715)
### What this PR does / why we need it?
This patch adds data-parallel support for Qwen3-MoE in Xlite. For more
details about Xlite, see
[https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md).

Online server config:
```shell
port=$1
log=$2
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
export HCCL_BUFFSIZE=512
export HCCL_OP_EXPANSION_MODE="AIV"
export OMP_PROC_BIND=false
export VLLM_ASCEND_ENABLE_NZ=0
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
sysctl -w kernel.sched_migration_cost_ns=50000
ip=127.0.0.1
python -m vllm.entrypoints.openai.api_server \
        --model /mnt/nvme1n1/wy/models/Qwen3-30B-A3B  \
        --tensor-parallel-size 2 \
        --enable-expert-parallel \
        --data-parallel-size 4 \
        --gpu-memory-utilization 0.9 \
        --max-num-batched-tokens 32768 \
        --data-parallel-size-local 4 \
        --max-num-seqs=200 \
        --block-size 128 \
        --max-model-len 6656 \
        --trust-remote-code \
        --disable-log-requests \
        --served-model-name qwen \
        --no-enable-prefix-caching \
        --additional-config '{"xlite_graph_config": {"enabled": true, "full_mode": true}, "enable_cpu_binding": true}' \
        --compilation-config '{"cudagraph_capture_sizes":[1, 16, 32, 48, 64, 100, 150, 200], "cudagraph_mode": "FULL_DECODE_ONLY"}' \
        --async-scheduling \
        --host ${ip} \
        --port ${port} > ${log} 2>&1 &
``` 
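
For a quick sanity check after the server comes up, a minimal chat-completions request can be sent. This is an illustrative sketch, not part of the PR: it assumes `8000` was passed as the port argument and uses the `qwen` name set via `--served-model-name`.

```python
# Minimal smoke test for the server launched above (sketch; port 8000 and the
# served model name "qwen" are assumptions taken from the launch script).
import json
import urllib.request

payload = {
    "model": "qwen",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 16,
}
req = urllib.request.Request(
    "http://127.0.0.1:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])
```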
Test config:
```shell
vllm bench serve \
    --max-concurrency ${maxconcurrency} \
    --num-prompts ${num_prompts} \
    --host ${HOST} \
    --port ${PORT} \
    --model ${MODEL_NAME} \
    --dataset-name random \
    --backend openai-chat \
    --random-input-len 512 \
    --random-output-len 512  \
    --random-range-ratio 0.2 \
    --temperature 0.6 \
    --metric-percentiles "50,90,99" \
    --tokenizer ${TOKENIZER_PATH} \
    --endpoint /v1/chat/completions \
    --ignore-eos
``` 

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.16.0
- vLLM main: c86cdcbcd2

Signed-off-by: uuzWY <Ethan.wangyuan@huawei.com>
Co-authored-by: uuzWY <Ethan.wangyuan@huawei.com>
2026-03-09 17:53:35 +08:00


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size, get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import (  # type: ignore[attr-defined]
    AttnMHA,
    Model,
    ModelAttnMeta,
    ModelConfig,
    Runtime,
    ScoringFuncSoftmax,
)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata


class XliteModel:

    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, torch.Tensor, int, torch.dtype]:
        raise NotImplementedError("Xlite Model initialize function not implemented.")


class LlamaXliteModel(XliteModel):

    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, torch.Tensor, int, torch.dtype]:
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)
        freq_cis = self._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)
        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        hf_config = vllm_config.model_config.hf_text_config
        if hasattr(hf_config, "text_config"):
            hf_config = hf_config.text_config
        config = ModelConfig()
        config.vocab_size = hf_config.vocab_size
        config.hidden_size = hf_config.hidden_size
        config.n_layers = hf_config.num_hidden_layers
        config.n_heads = hf_config.num_attention_heads
        config.n_kv_heads = hf_config.num_key_value_heads
        if hasattr(hf_config, "head_dim"):
            config.head_dim = hf_config.head_dim
        else:
            config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        config.rope_head_dim = config.head_dim
        config.norm_eps = hf_config.rms_norm_eps
        config.rope_theta = hf_config.rope_theta
        config.softmax_scale = config.head_dim**-0.5
        config.n_dense_layers = hf_config.num_hidden_layers
        config.intermediate_size = hf_config.intermediate_size
        config.def_tp_size = get_tensor_model_parallel_world_size()
        config.def_dp_size = 1
        config.moe_ep_size = 1
        config.moe_tp_size = 1
        config.attn_type = AttnMHA
        config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
        scheduler_config = vllm_config.scheduler_config
        max_batch_size = scheduler_config.max_num_seqs
        max_seq_len = vllm_config.model_config.max_model_len
        config.max_m = scheduler_config.max_num_batched_tokens
        config.max_batch_size = max_batch_size
        config.max_seq_len = max_seq_len
        config.block_size = vllm_config.cache_config.block_size
        return config
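
    # Sketch (added commentary, not in the original file): the head_dim branch
    # above follows the common HF fallback. For a hypothetical config with
    # hidden_size=4096, num_attention_heads=32, and no explicit head_dim:
    #
    #     head_dim = 4096 // 32  # == 128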

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
        params_dict = dict(runnable.named_parameters())
        if hasattr(runnable, "language_model"):
            layers = runnable.language_model.model.layers
            model_prefix = "language_model."
        else:
            layers = runnable.model.layers
            model_prefix = ""
        xlite_model = Model()
        xlite_model.embed = params_dict.get(model_prefix + "model.embed_tokens.weight")
        xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
        if vllm_config.model_config.hf_text_config.tie_word_embeddings:
            xlite_model.head = xlite_model.embed
        else:
            xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
        xlite_model.attn_norm = [layer.input_layernorm.weight for layer in layers]
        xlite_model.attn_out = [layer.self_attn.o_proj.weight for layer in layers]
        xlite_model.mha_qkv = [layer.self_attn.qkv_proj.weight for layer in layers]
        xlite_model.mlp_norm = [layer.post_attention_layernorm.weight for layer in layers]
        xlite_model.mlp_up_gate = [
            layer.mlp.gate_up_proj.weight
            for layer in layers
            if hasattr(layer.mlp, "gate_up_proj") and layer.mlp.gate_up_proj.weight is not None
        ]
        xlite_model.mlp_down = [
            layer.mlp.down_proj.weight
            for layer in layers
            if hasattr(layer.mlp, "down_proj") and layer.mlp.down_proj.weight is not None
        ]
        mha_qkv_bias = [
            layer.self_attn.qkv_proj.bias
            for layer in layers
            if hasattr(layer.self_attn.qkv_proj, "bias") and layer.self_attn.qkv_proj.bias is not None
        ]
        q_norm = [layer.self_attn.q_norm.weight for layer in layers if hasattr(layer.self_attn, "q_norm")]
        k_norm = [layer.self_attn.k_norm.weight for layer in layers if hasattr(layer.self_attn, "k_norm")]
        if len(mha_qkv_bias) != config.n_layers:
            config.qkv_bias = False
        else:
            config.qkv_bias = True
            xlite_model.mha_qkv_bias = mha_qkv_bias
        if len(q_norm) != config.n_layers or len(k_norm) != config.n_layers:
            config.qk_norm = False
        else:
            config.qk_norm = True
            xlite_model.mha_q_norm = q_norm
            xlite_model.mha_k_norm = k_norm
        return xlite_model
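
    # Note (added commentary, not in the original file): qkv_bias and qk_norm
    # are enabled only when *every* layer carries the corresponding tensors,
    # i.e. when the filtered list length equals config.n_layers; a partially
    # biased model falls back to the bias-free path.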

    def _precompute_freqs_cis(self, dim: int, end: int, dtype: torch.dtype, theta: float = 10000.0):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cpu")[: (dim // 2)] / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        cos_cache = freqs.cos().to(dtype)
        sin_cache = freqs.sin().to(dtype)
        freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
        return freq_cis.to(device="npu")
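

# Illustrative shape check for _precompute_freqs_cis (a sketch added for
# exposition, not part of the original file; the function name and defaults
# are hypothetical). cos and sin each contribute head_dim // 2 columns, so the
# concatenated cache is (max_seq_len, head_dim); the NPU transfer is omitted.
def _example_freq_cis_shape(dim: int = 8, end: int = 4, theta: float = 10000.0) -> torch.Size:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # (dim // 2,)
    table = torch.outer(torch.arange(end, dtype=torch.float32), freqs)  # (end, dim // 2)
    return torch.cat((table.cos(), table.sin()), dim=-1).shape  # torch.Size([end, dim])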


class QwenMoeXliteModel(LlamaXliteModel):

    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, torch.Tensor, int, torch.dtype]:
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)
        freq_cis = super()._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)
        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        config = super()._build_model_config(vllm_config)
        hf_config = vllm_config.model_config.hf_text_config
        ep_group = get_ep_group()
        config.n_dense_layers = 0
        config.n_routed_experts = hf_config.num_experts
        config.n_shared_experts = 0
        config.n_act_experts = hf_config.num_experts_per_tok
        config.def_dp_size = vllm_config.parallel_config.data_parallel_size
        config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
        config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
        config.experts_weight_transpose = True  # type: ignore
        config.moe_intermediate_size = hf_config.moe_intermediate_size
        config.norm_topk_prob = hf_config.norm_topk_prob  # type: ignore
        config.scoring_func = ScoringFuncSoftmax  # type: ignore
        return config
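
    # Added commentary (not in the original file): with expert parallelism
    # enabled, experts are sharded over the whole EP group, so
    # moe_ep_size = ep_group.world_size and moe_tp_size = 1; with it disabled
    # the roles swap. With the launch command from this PR (TP=2 x DP=4,
    # --enable-expert-parallel) the EP group would presumably span all 8 ranks.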

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
        xlite_model = super()._build_model(runnable, vllm_config, config)
        layers = runnable.model.layers
        xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
        xlite_model.re_up_gate = [
            layer.mlp.experts.w13_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
        ]
        xlite_model.re_down = [
            layer.mlp.experts.w2_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
        ]
        return xlite_model
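

# Sketch of the expert-weight flattening used in QwenMoeXliteModel._build_model
# (illustrative, not part of the original file; all names are hypothetical):
# the nested comprehension yields a flat, layer-major list, so the weight of
# expert e in layer l sits at index l * n_local + e.
def _example_expert_flattening(n_layers: int = 2, n_local: int = 3) -> list[tuple[int, int]]:
    # Stand-in for per-layer expert weights: (layer, expert) pairs.
    per_layer = [[(layer, expert) for expert in range(n_local)] for layer in range(n_layers)]
    flat = [per_layer[layer][expert] for layer in range(n_layers) for expert in range(n_local)]
    assert flat[1 * n_local + 2] == (1, 2)  # expert 2 of layer 1
    return flat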


def xlite_model_init(runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, torch.Tensor, int, torch.dtype]:
    strategy_map = {
        "LlamaForCausalLM": LlamaXliteModel,
        "Qwen2ForCausalLM": LlamaXliteModel,
        "Qwen3ForCausalLM": LlamaXliteModel,
        "Qwen3VLForConditionalGeneration": LlamaXliteModel,
        "Qwen3MoeForCausalLM": QwenMoeXliteModel,
    }
    architecture = vllm_config.model_config.architectures[0]
    strategy_class = strategy_map.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize(runnable, vllm_config)
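

# Note (added commentary, not in the original file): supporting another dense
# decoder architecture amounts to registering it in strategy_map above, e.g. a
# hypothetical strategy_map["MyLlamaForCausalLM"] = LlamaXliteModel, provided
# its module layout matches the attribute paths used in _build_model.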


class XliteWrapper:
    """
    xlite graph wrapper
    """

    def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
        self.runnable = runnable
        self.full_mode = get_ascend_config().xlite_graph_config.full_mode
        rank = torch.distributed.get_rank()
        local_rank = get_world_group().local_rank
        self.data_parallel_size = vllm_config.parallel_config.data_parallel_size
        self.xlite_rt = Runtime(local_rank, 0, rank, get_tensor_model_parallel_world_size(), self.data_parallel_size)
        (self.xlite_model, self.freq_cis, hidden_size, dtype) = xlite_model_init(runnable, vllm_config)
        rt_pool_size = self.xlite_model.get_tensor_pool_size()
        if rank == 0:
            logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
        if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
            raise ValueError(f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB")
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_states = torch.empty(max_num_tokens, hidden_size, device=f"npu:{local_rank}", dtype=dtype)
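
    # Added commentary (not in the original file): self.hidden_states is a
    # persistent output buffer sized for the worst case
    # (max_num_batched_tokens x hidden_size); each graph launch writes into a
    # slice of it, avoiding per-step allocations.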

    def __getattr__(self, key: str):
        # Allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable of xlite wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # In case we need to access the original runnable.
        return self.runnable

    def register_kv_caches(self, kv_caches: Any):
        self.kv_caches = kv_caches

    def __call__(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        forward_context = get_forward_context()
        attn_metadata: Any = forward_context.attn_metadata
        if attn_metadata is None:
            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
        attn_metadata = next(iter(attn_metadata.values()), None)
        if attn_metadata is None or not isinstance(attn_metadata, AscendMetadata):
            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.SpecDecoding,
        ]
        # Full: graph for both prefill and decode.
        # Decode-only: runnable for prefill, graph for decode.
        if not self.full_mode and self.data_parallel_size > 1:
            # Under DP, a step counts as decode-only when it carries at most
            # one token per request.
            num_tokens = forward_context.batch_descriptor.num_tokens
            num_reqs = forward_context.batch_descriptor.num_reqs
            use_xlite_graph = num_reqs is not None and num_tokens <= num_reqs
        else:
            use_xlite_graph = not with_prefill or self.full_mode
        if use_xlite_graph:
            # TODO: When vllm_ascend enables graph mode, attn_metadata.num_decodes
            # is padded for decode requests, so we derive the count from
            # num_decode_tokens for now. Once MTP is enabled, a single request
            # may cover multiple tokens, which will need a different fix.
            num_decodes = attn_metadata.num_decode_tokens
            num_prefills = attn_metadata.num_prefills
            batch = num_prefills + num_decodes
            seq_lens = attn_metadata.seq_lens[:batch]
            seq_tensor = torch.cat([torch.tensor([0]), torch.tensor(attn_metadata.actual_seq_lengths_q)], dim=0)
            query_lens = seq_tensor[1:] - seq_tensor[:-1]
            query_lens = query_lens[:batch]
            cached_lens = seq_lens - query_lens
            xlite_attn_metadata = ModelAttnMeta()
            xlite_attn_metadata.lens = query_lens.tolist()
            xlite_attn_metadata.cached_lens = cached_lens.tolist()
            xlite_attn_metadata.is_prefills = [False] * num_decodes + [True] * num_prefills
            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu().tolist()
            # Compatibility between DP and non-DP scenarios.
            num_tokens = forward_context.batch_descriptor.num_tokens
            num_actual_tokens = attn_metadata.num_actual_tokens
            h = self.hidden_states[:num_tokens]
            stream = torch.npu.current_stream().npu_stream
            if inputs_embeds is None:
                self.xlite_model.forward(
                    self.xlite_rt, input_ids, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
                )
            else:
                self.xlite_model.forward_with_inputs_embeds(
                    self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
                )
            return h[:num_actual_tokens]
        else:
            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
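

# Usage sketch (illustrative, not part of the original file): the wrapper is a
# drop-in replacement for the model's forward callable. A hypothetical model
# runner would do roughly:
#
#     model = XliteWrapper(model, vllm_config)  # wrap once at load time
#     model.register_kv_caches(kv_caches)       # hand over the NPU KV caches
#     hidden = model(input_ids, positions)      # dispatches to graph or eager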