[Enhancement] Add padding for ACL Graph (#803)
### What this PR does / why we need it? Add padding for ACL Graph and refactor graph batch size adjustments to utils.py --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -16,12 +16,28 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
# NOTE: Currently, we can only capture 1920 graphs at most,
|
||||
# due to the limitation of ACL graph. This number is bounded by
|
||||
# the number of streams, which is 2048, we save 128 streams
|
||||
# as a buffer.
|
||||
# Maximum number of graphs that can be captured by ACL Graph
|
||||
MAX_CAPTURE_SIZE = 1920
|
||||
|
||||
|
||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
import importlib
|
||||
@@ -99,3 +115,55 @@ def vllm_version_is(target_vllm_version: str):
|
||||
"is installed probably. Set the environment variable VLLM_VERSION "
|
||||
"to control it by hand. And please make sure the vaule follows the "
|
||||
"format of x.y.z.")
|
||||
|
||||
|
||||
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
"""Update ACL graph capture sizes based on hardware limitations"""
|
||||
# Store original configuration and temporarily clear it
|
||||
compilation_config = vllm_config.compilation_config
|
||||
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||
compilation_config.cudagraph_capture_sizes, None
|
||||
|
||||
# Calculate parallel configuration factor (increases with DP or TP)
|
||||
# TODO(Yizhou): This is a temporary solution, need to be improved
|
||||
# in the future, taking into account the other parallel configurations.
|
||||
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
parallel_config = vllm_config.parallel_config
|
||||
parallel_factor = 1 + sum(size > 1 for size in [
|
||||
parallel_config.data_parallel_size,
|
||||
parallel_config.tensor_parallel_size
|
||||
])
|
||||
|
||||
# Calculate maximum supported batch sizes considering model architecture
|
||||
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
||||
(num_hidden_layers + 1) / parallel_factor)
|
||||
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
|
||||
max_num_batch_sizes)
|
||||
|
||||
# If original sizes exceed maximum, sample a representative subset
|
||||
if max_num_batch_sizes < len(original_sizes):
|
||||
# Sample uniformly from original sizes
|
||||
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
|
||||
indices = [round(i * step) for i in range(max_num_batch_sizes)]
|
||||
|
||||
# Ensure first and last elements are preserved
|
||||
indices[0], indices[-1] = 0, len(original_sizes) - 1
|
||||
|
||||
sampled_sizes = [original_sizes[i] for i in indices]
|
||||
compilation_config.init_with_cudagraph_sizes(sampled_sizes)
|
||||
|
||||
logger.info(
|
||||
"Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
|
||||
vllm_config.model_config.architectures[0],
|
||||
num_hidden_layers,
|
||||
len(original_sizes),
|
||||
len(compilation_config.
|
||||
cudagraph_capture_sizes # type: ignore[arg-type]
|
||||
))
|
||||
else:
|
||||
# No adjustment needed
|
||||
compilation_config.cudagraph_capture_sizes = original_sizes
|
||||
logger.info(
|
||||
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
||||
vllm_config.model_config.architectures[0], num_hidden_layers,
|
||||
len(original_sizes))
|
||||
|
||||
Reference in New Issue
Block a user