[Enhancement] Add padding for ACL Graph (#803)

### What this PR does / why we need it?
Pad batches up to the nearest captured batch size so ACL Graph replay can be used, and refactor the graph batch-size adjustment helpers into utils.py.
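
The padding rule in one line: if a batch's token count is at or below the largest captured batch size, round it up to the nearest captured size so a pre-captured ACL graph can be replayed; otherwise run eagerly with no padding. A minimal sketch of that rounding, assuming an ascending capture-size list (`pad_to_captured_size` is illustrative, not the real API; the runner calls vLLM's `VllmConfig.pad_for_cudagraph`):

```python
# Illustrative sketch of the padding rule; not vLLM's implementation.
def pad_to_captured_size(num_tokens: int, capture_sizes: list[int]) -> int:
    """Round num_tokens up to the smallest captured size that fits it."""
    for size in sorted(capture_sizes):
        if size >= num_tokens:
            return size
    return num_tokens  # beyond the largest captured size: no padding

assert pad_to_captured_size(10, [1, 8, 16, 32, 64]) == 16
assert pad_to_captured_size(100, [1, 8, 16, 32, 64]) == 100
```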

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: yiz-liu
Date: 2025-05-12 20:26:22 +08:00 (committed by GitHub)
Parent: efabd722eb
Commit: 701b0fd95e
4 changed files with 97 additions and 79 deletions


```diff
@@ -18,7 +18,6 @@
 #
 import gc
-import math
 import os
 import time
 import weakref
@@ -293,9 +292,9 @@ class NPUModelRunner:
                                              device="cpu")
         self.attn_mask = None
         self.attn_state = None
-        self.use_npu_graph = (self.vllm_config.compilation_config.level
-                              == CompilationLevel.PIECEWISE
-                              and not self.model_config.enforce_eager)
+        self.use_aclgraph = (self.vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not self.model_config.enforce_eager)
         self.aclgraph_batch_sizes = list(
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -508,6 +507,13 @@ class NPUModelRunner:
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+        if (self.use_aclgraph and
+                total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
+            # Add padding to the batch size.
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_scheduled_tokens)
+        else:
+            num_input_tokens = total_num_scheduled_tokens
         modified_batch = self.attn_metadata_builder.reorder_batch(
             self.input_batch, scheduler_output)
@@ -546,7 +552,7 @@ class NPUModelRunner:
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        positions = self.positions[:total_num_scheduled_tokens]
+        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
         self.seq_lens_np[:num_reqs] = (
@@ -605,7 +611,7 @@ class NPUModelRunner:
         # Copy the tensors to the NPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        input_ids = self.input_ids[:total_num_scheduled_tokens]
+        input_ids = self.input_ids[:num_input_tokens]
         if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             padding = torch.zeros(graph_pad_size,
@@ -615,7 +621,9 @@ class NPUModelRunner:
             positions = torch.cat([positions, padding])
         # Run forward pass
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             model_kwargs = {}
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
```
```diff
@@ -1062,7 +1070,7 @@ class NPUModelRunner:
         return kv_cache_spec

     def capture_model(self) -> None:
-        if not self.use_npu_graph:
+        if not self.use_aclgraph:
             logger.warning(
                 "Skipping NPU graph capture. Please add "
                 "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
@@ -1070,9 +1078,6 @@ class NPUModelRunner:
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # Since vllm's aclgraph_batch_sizes can be too large,
-        # adjust its length to a proper size.
-        self.verify_adjust_aclgraph_batch_sizes()
         # Trigger ACL graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
@@ -1091,63 +1096,3 @@ class NPUModelRunner:
         # This usually takes 5~20 seconds.
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, npu_graph_size / (1 << 30))
-
-    def verify_adjust_aclgraph_batch_sizes(self) -> None:
-        # The max capture size vllm-ascend currently supports is 1920.
-        max_capture_size = 1920
-        original_aclgraph_batch_sizes = self.aclgraph_batch_sizes
-        num_hidden_layers = self.vllm_config.model_config.hf_config.num_hidden_layers
-        max_support_len_aclgraph = self.get_max_support_len(
-            max_capture_size, num_hidden_layers)
-
-        if max_support_len_aclgraph < len(original_aclgraph_batch_sizes):
-            self.aclgraph_batch_sizes = self.sample_from_list(
-                max_support_len_aclgraph)
-            logger.info(
-                "Model:%s-num_hidden_layers:%d will adjust aclgraph_batch_sizes, pre-adjust-len: %s, post-adjust-len: %s",
-                self.vllm_config.model_config.architectures[0],
-                num_hidden_layers, len(original_aclgraph_batch_sizes),
-                len(self.aclgraph_batch_sizes))
-        else:
-            logger.info(
-                "Model:%s-num_hidden_layers:%d does not need to adjust aclgraph_batch_sizes, list_len: %s",
-                self.vllm_config.model_config.architectures[0],
-                num_hidden_layers, len(original_aclgraph_batch_sizes))
-
-    def get_max_support_len(self, max_capture_size, num_hidden_layers) -> int:
-        parallel_type_cnt = 0
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if dp_size > 1:
-            parallel_type_cnt += 1
-        if tp_size > 1:
-            parallel_type_cnt += 1
-        max_support_len_aclgraph = math.floor(max_capture_size /
-                                              (num_hidden_layers + 1) /
-                                              (parallel_type_cnt + 1))
-        logger.info(
-            "max_capture_size:%s, dp_size:%s, tp_size:%s, parallel_type_cnt:%s, max_support_len_aclgraph:%s",
-            max_capture_size,
-            dp_size,
-            tp_size,
-            parallel_type_cnt,
-            max_support_len_aclgraph,
-        )
-        return max_support_len_aclgraph
```
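
A worked example of the formula above, under assumed values (a 28-layer model with both data and tensor parallelism enabled, so `parallel_type_cnt` is 2):

```python
import math

max_capture_size = 1920   # vllm-ascend's current ceiling, per the code above
num_hidden_layers = 28    # assumed model depth
parallel_type_cnt = 2     # assumed: dp_size > 1 and tp_size > 1

max_support_len = math.floor(max_capture_size /
                             (num_hidden_layers + 1) /
                             (parallel_type_cnt + 1))
print(max_support_len)  # 22 -> aclgraph_batch_sizes is trimmed to 22 entries
```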
```diff
-    def sample_from_list(self, sample_len) -> list[int]:
-        # Sample a new list of the given length from the old list while
-        # maintaining uniformity, for example:
-        # original: [1 8 16 24 32 40 48 56 64]
-        #   --> sample length = 3: [1 32 64]
-        #   --> sample length = 5: [1 16 32 48 64]
-        original_len = len(self.aclgraph_batch_sizes)
-        step = (original_len - 1) / (sample_len - 1)
-        indices = [round(i * step) for i in range(sample_len)]
-        # Align the first and last elements of the original list and sub-list
-        indices[0] = 0
-        indices[-1] = original_len - 1
-        # Sample the new list
-        new_list = [self.aclgraph_batch_sizes[i] for i in indices]
-        return new_list
```
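
A standalone check of the sampling behavior documented in the comments above (the same arithmetic, pulled out of the class for illustration):

```python
def sample_evenly(sizes: list[int], sample_len: int) -> list[int]:
    # Same logic as sample_from_list, free of the class for illustration.
    step = (len(sizes) - 1) / (sample_len - 1)
    indices = [round(i * step) for i in range(sample_len)]
    indices[0], indices[-1] = 0, len(sizes) - 1  # pin both endpoints
    return [sizes[i] for i in indices]

sizes = [1, 8, 16, 24, 32, 40, 48, 56, 64]
assert sample_evenly(sizes, 3) == [1, 32, 64]
assert sample_evenly(sizes, 5) == [1, 16, 32, 48, 64]
```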