Files

213 lines
8.7 KiB
Python
Raw Permalink Normal View History

2026-04-24 09:50:34 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from vllm.config.vllm import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__vllm__VllmConfig___set_cudagraph_sizes(self):
"""
vLLM defines the default candidate list of batch sizes for CUDA graph
capture as:
```python
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
will be the final sizes to capture cudagraph (in ascending order).
These sizes are used to capture and reuse CUDA graphs for
performance-critical paths (e.g., decoding). Capturing enables
significantly faster kernel dispatch by avoiding Python overhead. The
list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
most GPUs), which controls the total allowed number of tokens in a
batch. Since each sequence may have a variable number of tokens, the
maximum usable batch size will depend on actual sequence lengths.
Example:
With `max_num_batched_tokens = 8192`, and typical sequences
averaging ~32 tokens, most practical batch sizes fall below 256.
However, the system will still allow capture sizes up to 512 if
shape and memory permit.
Note:
If users explicitly specify cudagraph capture sizes in the
compilation config, those will override this default logic.
At runtime:
- If batch size <= one of the `cudagraph_capture_sizes`, the closest
padded CUDA graph will be used.
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
not be used.
"""
if hasattr(self.compilation_config, "_has_set_capture_list"):
# avoid set capture list twice while init
return
if (
self.model_config is not None
and not self.model_config.enforce_eager
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# determine the initial max_cudagraph_capture_size
max_cudagraph_capture_size = (
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
assert max_cudagraph_capture_size >= 1, (
"Maximum cudagraph size should be greater than or equal to 1 "
"when using cuda graph."
)
# determine the cudagraph_capture_sizes
if self.compilation_config.cudagraph_capture_sizes is not None:
assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
"cudagraph_capture_sizes should contain at least one element "
"when using cuda graph."
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8)
)
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16)
)
'''
=============================
Modify by vllm_mlu
=============================
@brief:
1) check batch_size_capture_list when enable mtp because bs * (K + 1)
may greater than max_num_batched_tokens
2) capture MLUGraph by given batch list
'''
mlu_graph_capture_list = os.getenv("MLU_GRAPH_CAPTURE_LIST", None)
if mlu_graph_capture_list:
if "-" in mlu_graph_capture_list:
batch_info = mlu_graph_capture_list.split("-")
assert len(batch_info) == 3, \
f"Got invalid graph_capture_list={mlu_graph_capture_list}, " + \
f"but expected format 'min_bs-max_bs(may not include)-step'."
start, end, step = mlu_graph_capture_list.split("-")
cudagraph_capture_sizes = [1, 2, 4] + [
i for i in range(int(start), int(end), int(step))
]
cudagraph_capture_sizes = sorted(list(set(cudagraph_capture_sizes)))
else:
cudagraph_capture_sizes = [int(x) for x in mlu_graph_capture_list.split(",")]
if (self.speculative_config is not None
and self.speculative_config.num_speculative_tokens > 0
):
K = self.speculative_config.num_speculative_tokens
cudagraph_capture_sizes = [x * (1 + K) for x in cudagraph_capture_sizes]
cudagraph_capture_sizes = [
size for size in cudagraph_capture_sizes
if size <= self.scheduler_config.max_num_batched_tokens
]
'''
==================
End of MLU Hijack
==================
'''
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
)
# user-specific compilation_config.max_cudagraph_capture_size get
# truncated to valid_max_size when they are inconsistent.
valid_max_size = (
cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
)
if (
self.compilation_config.max_cudagraph_capture_size is not None
and self.compilation_config.max_cudagraph_capture_size != valid_max_size
):
# raise error only when both two flags are user-specified
# and they are inconsistent with each other
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={self.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
)
# always set the final max_cudagraph_capture_size
self.compilation_config.max_cudagraph_capture_size = valid_max_size
if self.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes
) < len(self.compilation_config.cudagraph_capture_sizes):
# If users have specified capture sizes, we only need to
# compare the lens before and after modification since the modified
# list is only the subset of the original list.
logger.warning(
(
"cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"
),
self.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
# always write back the final sizes
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
# no cudagraph in use
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
setattr(self.compilation_config, "_has_set_capture_list", True)
MluHijackObject.apply_hijack(
VllmConfig,
VllmConfig._set_cudagraph_sizes,
vllm__config__vllm__VllmConfig___set_cudagraph_sizes,
)