[Enhancement] Add padding for ACL Graph (#803)
### What this PR does / why we need it? Add padding for ACL Graph and refactor graph batch size adjustments to utils.py --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -104,6 +104,7 @@ class AscendAttentionState(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
|
num_actual_tokens: int # Number of tokens excluding padding.
|
||||||
# (batch_size, max_blocks_per_seq).
|
# (batch_size, max_blocks_per_seq).
|
||||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||||
block_tables: torch.Tensor
|
block_tables: torch.Tensor
|
||||||
@@ -125,7 +126,6 @@ class AscendMetadata:
|
|||||||
is_only_prefill: bool = False
|
is_only_prefill: bool = False
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -149,7 +149,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
attn_mask = self.runner.attn_mask
|
attn_mask = self.runner.attn_mask
|
||||||
attn_state = self.runner.attn_state
|
attn_state = self.runner.attn_state
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(block_tables=block_table,
|
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
||||||
|
block_tables=block_table,
|
||||||
query_lens=query_lens,
|
query_lens=query_lens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
@@ -234,9 +235,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output=output,
|
output=output,
|
||||||
layer_name=layer.layer_name)
|
layer_name=layer.layer_name)
|
||||||
else:
|
else:
|
||||||
num_tokens = query.shape[0]
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
return output.view(num_tokens, self.hidden_size)
|
return output.view(num_tokens, self.hidden_size)
|
||||||
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
attn_type = self.attn_type
|
attn_type = self.attn_type
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
@@ -255,11 +256,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
if self.key_cache is None:
|
if self.key_cache is None:
|
||||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
torch_npu._npu_reshape_and_cache(key=key,
|
torch_npu._npu_reshape_and_cache(
|
||||||
value=value,
|
key=key[:num_actual_tokens],
|
||||||
key_cache=self.key_cache,
|
value=value[:num_actual_tokens],
|
||||||
value_cache=self.value_cache,
|
key_cache=self.key_cache,
|
||||||
slot_indices=slots)
|
value_cache=self.value_cache,
|
||||||
|
slot_indices=slots)
|
||||||
|
|
||||||
if hasattr(layer, 'quant_method'):
|
if hasattr(layer, 'quant_method'):
|
||||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from vllm.logger import logger
|
|||||||
from vllm.platforms import Platform, PlatformEnum
|
from vllm.platforms import Platform, PlatformEnum
|
||||||
from vllm.utils import supports_dynamo
|
from vllm.utils import supports_dynamo
|
||||||
|
|
||||||
|
from vllm_ascend.utils import update_aclgraph_sizes
|
||||||
|
|
||||||
CUSTOM_OP_ENABLED = False
|
CUSTOM_OP_ENABLED = False
|
||||||
try:
|
try:
|
||||||
# register custom ops into torch_library here
|
# register custom ops into torch_library here
|
||||||
@@ -144,6 +146,7 @@ class NPUPlatform(Platform):
|
|||||||
compilation_config.use_inductor = False
|
compilation_config.use_inductor = False
|
||||||
compilation_config.splitting_ops.extend(
|
compilation_config.splitting_ops.extend(
|
||||||
["vllm.unified_ascend_attention_with_output"])
|
["vllm.unified_ascend_attention_with_output"])
|
||||||
|
update_aclgraph_sizes(vllm_config)
|
||||||
|
|
||||||
if vllm_config.additional_config is not None:
|
if vllm_config.additional_config is not None:
|
||||||
enable_graph_mode = vllm_config.additional_config.get(
|
enable_graph_mode = vllm_config.additional_config.get(
|
||||||
|
|||||||
@@ -16,12 +16,28 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
import vllm_ascend.envs as envs
|
import vllm_ascend.envs as envs
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
else:
|
||||||
|
VllmConfig = None
|
||||||
|
|
||||||
|
# NOTE: Currently, we can only capture 1920 graphs at most,
|
||||||
|
# due to the limitation of ACL graph. This number is bounded by
|
||||||
|
# the number of streams, which is 2048, we save 128 streams
|
||||||
|
# as a buffer.
|
||||||
|
# Maximum number of graphs that can be captured by ACL Graph
|
||||||
|
MAX_CAPTURE_SIZE = 1920
|
||||||
|
|
||||||
|
|
||||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||||
import importlib
|
import importlib
|
||||||
@@ -99,3 +115,55 @@ def vllm_version_is(target_vllm_version: str):
|
|||||||
"is installed probably. Set the environment variable VLLM_VERSION "
|
"is installed probably. Set the environment variable VLLM_VERSION "
|
||||||
"to control it by hand. And please make sure the vaule follows the "
|
"to control it by hand. And please make sure the vaule follows the "
|
||||||
"format of x.y.z.")
|
"format of x.y.z.")
|
||||||
|
|
||||||
|
|
||||||
|
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||||
|
"""Update ACL graph capture sizes based on hardware limitations"""
|
||||||
|
# Store original configuration and temporarily clear it
|
||||||
|
compilation_config = vllm_config.compilation_config
|
||||||
|
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||||
|
compilation_config.cudagraph_capture_sizes, None
|
||||||
|
|
||||||
|
# Calculate parallel configuration factor (increases with DP or TP)
|
||||||
|
# TODO(Yizhou): This is a temporary solution, need to be improved
|
||||||
|
# in the future, taking into account the other parallel configurations.
|
||||||
|
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
parallel_factor = 1 + sum(size > 1 for size in [
|
||||||
|
parallel_config.data_parallel_size,
|
||||||
|
parallel_config.tensor_parallel_size
|
||||||
|
])
|
||||||
|
|
||||||
|
# Calculate maximum supported batch sizes considering model architecture
|
||||||
|
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
||||||
|
(num_hidden_layers + 1) / parallel_factor)
|
||||||
|
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
|
||||||
|
max_num_batch_sizes)
|
||||||
|
|
||||||
|
# If original sizes exceed maximum, sample a representative subset
|
||||||
|
if max_num_batch_sizes < len(original_sizes):
|
||||||
|
# Sample uniformly from original sizes
|
||||||
|
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
|
||||||
|
indices = [round(i * step) for i in range(max_num_batch_sizes)]
|
||||||
|
|
||||||
|
# Ensure first and last elements are preserved
|
||||||
|
indices[0], indices[-1] = 0, len(original_sizes) - 1
|
||||||
|
|
||||||
|
sampled_sizes = [original_sizes[i] for i in indices]
|
||||||
|
compilation_config.init_with_cudagraph_sizes(sampled_sizes)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
|
||||||
|
vllm_config.model_config.architectures[0],
|
||||||
|
num_hidden_layers,
|
||||||
|
len(original_sizes),
|
||||||
|
len(compilation_config.
|
||||||
|
cudagraph_capture_sizes # type: ignore[arg-type]
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
# No adjustment needed
|
||||||
|
compilation_config.cudagraph_capture_sizes = original_sizes
|
||||||
|
logger.info(
|
||||||
|
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
||||||
|
vllm_config.model_config.architectures[0], num_hidden_layers,
|
||||||
|
len(original_sizes))
|
||||||
|
|||||||
@@ -18,7 +18,6 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
@@ -293,9 +292,9 @@ class NPUModelRunner:
|
|||||||
device="cpu")
|
device="cpu")
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
self.attn_state = None
|
self.attn_state = None
|
||||||
self.use_npu_graph = (self.vllm_config.compilation_config.level
|
self.use_aclgraph = (self.vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
== CompilationLevel.PIECEWISE
|
||||||
and not self.model_config.enforce_eager)
|
and not self.model_config.enforce_eager)
|
||||||
self.aclgraph_batch_sizes = list(
|
self.aclgraph_batch_sizes = list(
|
||||||
reversed(
|
reversed(
|
||||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||||
@@ -508,6 +507,13 @@ class NPUModelRunner:
|
|||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
assert num_reqs > 0
|
assert num_reqs > 0
|
||||||
|
if (self.use_aclgraph and
|
||||||
|
total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
|
||||||
|
# Add padding to the batch size.
|
||||||
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||||
|
total_num_scheduled_tokens)
|
||||||
|
else:
|
||||||
|
num_input_tokens = total_num_scheduled_tokens
|
||||||
|
|
||||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||||
self.input_batch, scheduler_output)
|
self.input_batch, scheduler_output)
|
||||||
@@ -546,7 +552,7 @@ class NPUModelRunner:
|
|||||||
|
|
||||||
self.positions[:total_num_scheduled_tokens].copy_(
|
self.positions[:total_num_scheduled_tokens].copy_(
|
||||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
positions = self.positions[:total_num_scheduled_tokens]
|
positions = self.positions[:num_input_tokens]
|
||||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||||
|
|
||||||
self.seq_lens_np[:num_reqs] = (
|
self.seq_lens_np[:num_reqs] = (
|
||||||
@@ -605,7 +611,7 @@ class NPUModelRunner:
|
|||||||
# Copy the tensors to the NPU.
|
# Copy the tensors to the NPU.
|
||||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
|
|
||||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
padding = torch.zeros(graph_pad_size,
|
padding = torch.zeros(graph_pad_size,
|
||||||
@@ -615,7 +621,9 @@ class NPUModelRunner:
|
|||||||
positions = torch.cat([positions, padding])
|
positions = torch.cat([positions, padding])
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata,
|
||||||
|
self.vllm_config,
|
||||||
|
num_tokens=num_input_tokens):
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
if self.enable_torchair_graph_mode:
|
if self.enable_torchair_graph_mode:
|
||||||
model_kwargs["kv_caches"] = self.kv_caches
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
@@ -1062,7 +1070,7 @@ class NPUModelRunner:
|
|||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
if not self.use_npu_graph:
|
if not self.use_aclgraph:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping NPU graph capture. Please add "
|
"Skipping NPU graph capture. Please add "
|
||||||
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
|
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
|
||||||
@@ -1070,9 +1078,6 @@ class NPUModelRunner:
|
|||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||||
# Since vllm aclgraph_batch_sizes is too large,
|
|
||||||
# we need to adjust its length to proper size.
|
|
||||||
self.verify_adjust_aclgraph_batch_sizes()
|
|
||||||
|
|
||||||
# Trigger ACL graph capture for specific shapes.
|
# Trigger ACL graph capture for specific shapes.
|
||||||
# Capture the large shapes first so that the smaller shapes
|
# Capture the large shapes first so that the smaller shapes
|
||||||
@@ -1091,63 +1096,3 @@ class NPUModelRunner:
|
|||||||
# This usually takes 5~20 seconds.
|
# This usually takes 5~20 seconds.
|
||||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||||
elapsed_time, npu_graph_size / (1 << 30))
|
elapsed_time, npu_graph_size / (1 << 30))
|
||||||
|
|
||||||
def verify_adjust_aclgraph_batch_sizes(self) -> None:
|
|
||||||
# Now, vllm-ascend support max capture size is 1920
|
|
||||||
max_capture_size = 1920
|
|
||||||
original_aclgraph_batch_sizes = self.aclgraph_batch_sizes
|
|
||||||
num_hidden_layers = self.vllm_config.model_config.hf_config.num_hidden_layers
|
|
||||||
max_support_len_aclgraph = self.get_max_support_len(
|
|
||||||
max_capture_size, num_hidden_layers)
|
|
||||||
|
|
||||||
if max_support_len_aclgraph < len(original_aclgraph_batch_sizes):
|
|
||||||
self.aclgraph_batch_sizes = self.sample_from_list(
|
|
||||||
max_support_len_aclgraph)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Model:%s-num_hidden_layers:%d will adjust aclgraph_batch_sizes, pre-adjust-len: %s, post-adjust-len: %s",
|
|
||||||
self.vllm_config.model_config.architectures[0],
|
|
||||||
num_hidden_layers, len(original_aclgraph_batch_sizes),
|
|
||||||
len(self.aclgraph_batch_sizes))
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Model:%s-num_hidden_layers:%d no need adjust aclgraph_batch_sizes, list_len: %s",
|
|
||||||
self.vllm_config.model_config.architectures[0],
|
|
||||||
num_hidden_layers, len(original_aclgraph_batch_sizes))
|
|
||||||
|
|
||||||
def get_max_support_len(self, max_capture_size, num_hidden_layers) -> int:
|
|
||||||
parallel_type_cnt = 0
|
|
||||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
|
||||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
if dp_size > 1:
|
|
||||||
parallel_type_cnt += 1
|
|
||||||
if tp_size > 1:
|
|
||||||
parallel_type_cnt += 1
|
|
||||||
max_support_len_aclgraph = math.floor(max_capture_size /
|
|
||||||
(num_hidden_layers + 1) /
|
|
||||||
(parallel_type_cnt + 1))
|
|
||||||
logger.info(
|
|
||||||
"max_capture_size:%s, dp_size:%s, tp_size:%s, parallel_type_cnt:%s, max_support_len_aclgraph: %s:",
|
|
||||||
max_capture_size,
|
|
||||||
dp_size,
|
|
||||||
tp_size,
|
|
||||||
parallel_type_cnt,
|
|
||||||
max_support_len_aclgraph,
|
|
||||||
)
|
|
||||||
|
|
||||||
return max_support_len_aclgraph
|
|
||||||
|
|
||||||
def sample_from_list(self, sample_len) -> list[int]:
|
|
||||||
# we use this function to sample a new list from old list by given length, and maintain uniformity, for example:
|
|
||||||
# original: [1 8 16 24 32 40 48 56 64]
|
|
||||||
# --> sample length = 3: [1 32 64]
|
|
||||||
# --> sample length = 5: [1 16 32 48 64]
|
|
||||||
original_len = len(self.aclgraph_batch_sizes)
|
|
||||||
step = (original_len - 1) / (sample_len - 1)
|
|
||||||
indices = [round(i * step) for i in range(sample_len)]
|
|
||||||
# Align first and last element of the original list and sub-list
|
|
||||||
indices[0] = 0
|
|
||||||
indices[-1] = original_len - 1
|
|
||||||
# Sample new list
|
|
||||||
new_list = [self.aclgraph_batch_sizes[i] for i in indices]
|
|
||||||
return new_list
|
|
||||||
|
|||||||
Reference in New Issue
Block a user