### What this PR does / why we need it?
This patch adds data-parallel (DP) support for the Qwen3-MoE model in Xlite. For more details about Xlite, please refer to the [Xlite README](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md). The server config below runs tensor parallelism 2 × data parallelism 4 (8 NPUs in total, with all DP ranks local to one node) and expert parallelism enabled.
Online server config:
```shell
port=$1
log=$2
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
export HCCL_BUFFSIZE=512
export HCCL_OP_EXPANSION_MODE="AIV"
export OMP_PROC_BIND=false
export VLLM_ASCEND_ENABLE_NZ=0
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
sysctl kernel.sched_migration_cost_ns=50000
ip=127.0.0.1
python -m vllm.entrypoints.openai.api_server \
--model /mnt/nvme1n1/wy/models/Qwen3-30B-A3B \
--tensor-parallel-size 2 \
--enable-expert-parallel \
--data-parallel-size 4 \
--gpu-memory-utilization 0.9 \
--max-num-batched-tokens 32768 \
--data-parallel-size-local 4 \
--max-num-seqs 200 \
--block-size 128 \
--max-model-len 6656 \
--trust-remote-code \
--disable-log-requests \
--served-model-name qwen \
--no-enable-prefix-caching \
--additional-config '{"xlite_graph_config": {"enabled": true, "full_mode": true}, "enable_cpu_binding": true}' \
--compilation-config '{"cudagraph_capture_sizes":[1, 16, 32, 48, 64, 100, 150, 200], "cudagraph_mode": "FULL_DECODE_ONLY"}' \
--async-scheduling \
--host ${ip} \
--port ${port} > ${log} 2>&1 &
```
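Once the server is up, a quick smoke test can be run against the OpenAI-compatible endpoint. This is a minimal sketch; `${port}` and the served model name `qwen` come from the config above:
```shell
# Smoke test: assumes the server above is listening on 127.0.0.1:${port}
# and serving the model under the name "qwen".
curl -s http://127.0.0.1:${port}/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "qwen", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 32}'
```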
Test config:
```shell
vllm bench serve \
--max-concurrency ${maxconcurrency} \
--num-prompts ${num_prompts} \
--host ${HOST} \
--port ${PORT} \
--model ${MODEL_NAME} \
--dataset-name random \
--backend openai-chat \
--random-input-len 512 \
--random-output-len 512 \
--random-range-ratio 0.2 \
--temperature 0.6 \
--metric-percentiles "50,90,99" \
--tokenizer ${TOKENIZER_PATH} \
--endpoint /v1/chat/completions \
--ignore-eos
```
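The placeholders above are expected to be set beforehand. Illustrative values matching the server config (assumptions only; adjust to your environment):
```shell
# Illustrative values only; adjust to your setup.
HOST=127.0.0.1
PORT=8000                     # the ${port} passed to the server script
MODEL_NAME=qwen               # matches --served-model-name
TOKENIZER_PATH=/mnt/nvme1n1/wy/models/Qwen3-30B-A3B
maxconcurrency=200            # e.g. match --max-num-seqs
num_prompts=1000
```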
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: c86cdcbcd2
Signed-off-by: uuzWY <Ethan.wangyuan@huawei.com>
Co-authored-by: uuzWY <Ethan.wangyuan@huawei.com>
The patch adds a new `XliteModelRunner` (53 lines, Python), adapted from vLLM's `gpu_model_runner.py`:
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file

import torch.nn as nn

from vllm.config import CUDAGraphMode
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class XliteModelRunner(NPUModelRunner):

    def get_model(self) -> nn.Module:
        # The loaded model is wrapped by XliteWrapper; callers expect the
        # underlying nn.Module.
        return self.model.unwrap()

    def load_model(self) -> None:
        super().load_model()
        from vllm_ascend.xlite.xlite import XliteWrapper

        self.model = XliteWrapper(self.model, self.vllm_config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        super().initialize_kv_cache(kv_cache_config)
        self.model.register_kv_caches(self.kv_caches)

    def _should_build_dummy_attn_metadata(
        self,
        force_attention: bool = False,
        is_profile: bool = False,
        cudagraph_runtime_mode: CUDAGraphMode | None = None,
    ) -> bool:
        """
        Override to build attention metadata during dummy_run when xlite
        is enabled. For xlite, we need to build metadata during DP
        dummy_run to ensure all ranks have consistent metadata, even when
        some ranks have no requests.
        """
        base_condition = super()._should_build_dummy_attn_metadata(
            force_attention, is_profile, cudagraph_runtime_mode)
        xlite_condition = (self.ascend_config.xlite_graph_config.enabled
                           and not is_profile)
        return base_condition or xlite_condition
```
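For context, a hypothetical sketch of the `XliteWrapper` surface the runner relies on, inferred only from the calls above (construction with the model and `vllm_config`, `unwrap`, and `register_kv_caches`); the real class lives in `vllm_ascend.xlite.xlite` and is not part of this excerpt:
```python
# Hypothetical sketch only: inferred from how XliteModelRunner uses the
# wrapper, not from the actual vllm_ascend.xlite.xlite implementation.
import torch.nn as nn


class XliteWrapperSketch:
    """Wraps the eager model so execution can be routed through Xlite graphs."""

    def __init__(self, model: nn.Module, vllm_config) -> None:
        self._model = model              # original eager-mode model
        self._vllm_config = vllm_config
        self._kv_caches: list = []

    def unwrap(self) -> nn.Module:
        # XliteModelRunner.get_model() returns this underlying module.
        return self._model

    def register_kv_caches(self, kv_caches: list) -> None:
        # Called after KV-cache allocation so the Xlite graph runtime can
        # capture/replay against the real cache tensors.
        self._kv_caches = kv_caches
```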