[WIP]Add Func: aclgraph_batch_size auto-adjust to different model (#771)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
This PR add new function of : aclgraph_batch_size can dynamic adjust to
different model; before this PR, the aclgraph_batch_sizes given from
vllm to vllm-ascend always too large, and that may result in ERROR while
running on different, with the information: "The resources are
insufficient".
Now, with this PR, the code can dynamic adjust aclgraph_batch_sizes
depend on the model hidden_layer_nums and parallel config, for example:
a. for Qwen2.5-7B, the aclgraph_batch_size length is 33 total;
b. for Qwen2.5-72B, the aclgraph_batch_size length is 11 total;

Signed-off-by: chris668899 <15105191595@126.com>
This commit is contained in:
chris668899
2025-05-08 16:23:33 +08:00
committed by GitHub
parent 2e3520e285
commit 6c020883a8
2 changed files with 127 additions and 3 deletions

View File

@@ -0,0 +1,60 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
import torch
from vllm import LLM, SamplingParams
# TODO: revert me when cuda hard code is fixed in 'VllmBackend'
torch.cuda.CUDAGraph = torch.npu.NPUGraph
MODELS = [
"Qwen/Qwen2.5-0.5B-Instruct",
]
TENSOR_PARALLELS = [2]
prompts = [
"Hello, my name is",
"The future of AI is",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("temperature", [0.0])
@pytest.mark.parametrize("ignore_eos", [True])
def test_models(model: str, tp_size: int, max_tokens: int, temperature: int,
ignore_eos: bool) -> None:
# Create an LLM.
llm = LLM(
model=model,
tensor_parallel_size=tp_size,
)
# Prepare sampling_parames
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
ignore_eos=ignore_eos,
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
outputs = llm.generate(prompts, sampling_params)
torch.npu.synchronize()
# The output length should be equal to prompts length.
assert len(outputs) == len(prompts)

View File

@@ -18,6 +18,7 @@
#
import gc
import math
import os
import time
import weakref
@@ -273,7 +274,7 @@ class NPUModelRunner:
self.use_npu_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.npugraph_batch_sizes = list(
self.aclgraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -950,12 +951,15 @@ class NPUModelRunner:
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
# Since vllm aclgraph_batch_sizes is too large,
# we need to adjust its length to proper size.
self.verify_adjust_aclgraph_batch_sizes()
# Trigger NPU graph capture for specific shapes.
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
for num_tokens in reversed(self.npugraph_batch_sizes):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
@@ -968,3 +972,63 @@ class NPUModelRunner:
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
def verify_adjust_aclgraph_batch_sizes(self) -> None:
# Now, vllm-ascend support max capture size is 1920
max_capture_size = 1920
original_aclgraph_batch_sizes = self.aclgraph_batch_sizes
num_hidden_layers = self.vllm_config.model_config.hf_config.num_hidden_layers
max_support_len_aclgraph = self.get_max_support_len(
max_capture_size, num_hidden_layers)
if max_support_len_aclgraph < len(original_aclgraph_batch_sizes):
self.aclgraph_batch_sizes = self.sample_from_list(
max_support_len_aclgraph)
logger.info(
"Model:%s-num_hidden_layers:%d will adjust aclgraph_batch_sizes, pre-adjust-len: %s, post-adjust-len: %s",
self.vllm_config.model_config.architectures[0],
num_hidden_layers, len(original_aclgraph_batch_sizes),
len(self.aclgraph_batch_sizes))
else:
logger.info(
"Model:%s-num_hidden_layers:%d no need adjust aclgraph_batch_sizes, list_len: %s",
self.vllm_config.model_config.architectures[0],
num_hidden_layers, len(original_aclgraph_batch_sizes))
def get_max_support_len(self, max_capture_size, num_hidden_layers) -> int:
parallel_type_cnt = 0
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if dp_size > 1:
parallel_type_cnt += 1
if tp_size > 1:
parallel_type_cnt += 1
max_support_len_aclgraph = math.floor(max_capture_size /
(num_hidden_layers + 1) /
(parallel_type_cnt + 1))
logger.info(
"max_capture_size:%s, dp_size:%s, tp_size:%s, parallel_type_cnt:%s, max_support_len_aclgraph: %s:",
max_capture_size,
dp_size,
tp_size,
parallel_type_cnt,
max_support_len_aclgraph,
)
return max_support_len_aclgraph
def sample_from_list(self, sample_len) -> list[int]:
# we use this function to sample a new list from old list by given length, and maintain uniformity, for example:
# original: [1 8 16 24 32 40 48 56 64]
# --> sample length = 3: [1 32 64]
# --> sample length = 5: [1 16 32 48 64]
original_len = len(self.aclgraph_batch_sizes)
step = (original_len - 1) / (sample_len - 1)
indices = [round(i * step) for i in range(sample_len)]
# Align first and last element of the original list and sub-list
indices[0] = 0
indices[-1] = original_len - 1
# Sample new list
new_list = [self.aclgraph_batch_sizes[i] for i in indices]
return new_list