[Graph][Bugfix] Set default cudagraph max capture size via platform defaults (#7572)
### What this PR does / why we need it?
This PR lets the NPU platform provide its own default
`max_cudagraph_capture_size` via
`NPUPlatform.apply_config_platform_defaults()`.
Previously, when cudagraph sizing was left unset, Ascend inherited
vLLM's upstream default heuristic in `_set_cudagraph_sizes()`, which
uses `max_num_seqs * decode_query_len * 2`. This PR changes Ascend's
default to `min(max_num_seqs * decode_query_len, 512)` while keeping the
rest of vLLM's cudagraph sizing logic unchanged.
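For concreteness, a minimal standalone sketch of the new default computation (the function name below is hypothetical; the real logic lives in `NPUPlatform._get_default_max_cudagraph_capture_size`, shown in the diff):

```python
def ascend_default_max_capture_size(max_num_seqs: int,
                                    num_speculative_tokens: int = 0) -> int:
    # One decode token per sequence, plus any speculative tokens.
    decode_query_len = 1 + num_speculative_tokens
    # Ascend default: max_num_seqs * decode_query_len, capped at 512.
    # Upstream vLLM would additionally multiply by 2 before capping.
    return min(max_num_seqs * decode_query_len, 512)

# The values exercised by the unit tests in this PR:
assert ascend_default_max_capture_size(40, 3) == 160   # 40 * 4
assert ascend_default_max_capture_size(200, 3) == 512  # 800, capped at 512
```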
### Does this PR introduce _any_ user-facing change?
Yes, but only for Ascend when users do not explicitly configure
cudagraph sizing.
If `max_cudagraph_capture_size` and `cudagraph_capture_sizes` are both
unset, the default is now `max_num_seqs * decode_query_len` (capped at
`512`) instead of the upstream `* 2` heuristic. Explicit user settings
are unchanged.
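Explicit configuration keeps taking precedence. A sketch, assuming the two `CompilationConfig` fields named above can be set at construction time:

```python
from vllm.config import CompilationConfig

# Pin the capture upper bound explicitly; the platform default is skipped.
compilation_config = CompilationConfig(max_cudagraph_capture_size=456)

# Or provide the full capture-size list, which likewise disables the
# platform-injected default max.
compilation_config = CompilationConfig(cudagraph_capture_sizes=[1, 2, 4])
```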
### How was this patch tested?
Add unit tests to cover:
- default max injection via `apply_config_platform_defaults()`
- explicit `max_cudagraph_capture_size` is preserved
- explicit `cudagraph_capture_sizes` are preserved
- Ascend default max no longer uses the upstream `* 2`
- late `_set_cudagraph_sizes()` recomputation reuses the current max input
- vLLM version: v0.18.0
- vLLM main: ed359c497a
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
@@ -9,7 +9,11 @@ from vllm.v1.attention.selector import AttentionSelectorConfig # type: ignore
 from tests.ut.base import TestBase
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType
+from vllm_ascend.utils import (
+    ASCEND_QUANTIZATION_METHOD,
+    COMPRESSED_TENSORS_METHOD,
+    AscendDeviceType,
+)
 
 
 class TestNPUPlatform(TestBase):
@@ -21,7 +25,9 @@ class TestNPUPlatform(TestBase):
         mock_vllm_config.parallel_config = MagicMock()
         mock_vllm_config.cache_config = MagicMock()
         mock_vllm_config.scheduler_config = MagicMock()
+        mock_vllm_config.scheduler_config.max_num_seqs = None
         mock_vllm_config.speculative_config = None
         mock_vllm_config.additional_config = {}
         mock_vllm_config.compilation_config.pass_config.enable_sp = False
+        mock_vllm_config.compilation_config.cudagraph_mode = None
         return mock_vllm_config
@@ -30,7 +36,13 @@ class TestNPUPlatform(TestBase):
     def mock_vllm_ascend_config():
         mock_ascend_config = MagicMock()
         mock_ascend_config.xlite_graph_config.enabled = False
         mock_ascend_config.xlite_graph_config.full_mode = False
         mock_ascend_config.ascend_compilation_config.enable_npugraph_ex = False
         mock_ascend_config.ascend_fusion_config = None
         mock_ascend_config.recompute_scheduler_enable = False
         mock_ascend_config.SLO_limits_for_dynamic_batch = -1
         mock_ascend_config.enable_shared_expert_dp = False
         mock_ascend_config.update_compile_ranges_split_points = MagicMock()
         return mock_ascend_config
 
     def setUp(self):
@@ -99,6 +111,105 @@ class TestNPUPlatform(TestBase):
         mock_adapt_patch.assert_called_once_with(is_global_patch=True)
         self.assertEqual(len(mock_action.choices), 2)
 
+    def test_apply_config_platform_defaults_sets_ascend_default_max(self):
+        test_cases = [
+            (40, 3, 160),
+            (200, 3, 512),
+        ]
+
+        for max_num_seqs, num_speculative_tokens, expected_max in test_cases:
+            with self.subTest(
+                max_num_seqs=max_num_seqs,
+                num_speculative_tokens=num_speculative_tokens,
+                expected_max=expected_max,
+            ):
+                vllm_config = TestNPUPlatform.mock_vllm_config()
+                vllm_config.scheduler_config.max_num_seqs = max_num_seqs
+                vllm_config.speculative_config = MagicMock(
+                    num_speculative_tokens=num_speculative_tokens
+                )
+                vllm_config.compilation_config.max_cudagraph_capture_size = None
+                vllm_config.compilation_config.cudagraph_capture_sizes = None
+
+                self.platform.apply_config_platform_defaults(vllm_config)
+
+                self.assertEqual(
+                    vllm_config.compilation_config.max_cudagraph_capture_size,
+                    expected_max,
+                )
+
+    def test_apply_config_platform_defaults_respects_explicit_max(self):
+        vllm_config = TestNPUPlatform.mock_vllm_config()
+        vllm_config.compilation_config.max_cudagraph_capture_size = 456
+        vllm_config.compilation_config.cudagraph_capture_sizes = None
+
+        self.platform.apply_config_platform_defaults(vllm_config)
+
+        self.assertEqual(vllm_config.compilation_config.max_cudagraph_capture_size, 456)
+
+    def test_apply_config_platform_defaults_respects_explicit_sizes(self):
+        vllm_config = TestNPUPlatform.mock_vllm_config()
+        vllm_config.compilation_config.max_cudagraph_capture_size = None
+        vllm_config.compilation_config.cudagraph_capture_sizes = [1, 2, 4]
+
+        self.platform.apply_config_platform_defaults(vllm_config)
+
+        self.assertIsNone(vllm_config.compilation_config.max_cudagraph_capture_size)
+        self.assertEqual(vllm_config.compilation_config.cudagraph_capture_sizes, [1, 2, 4])
+
+    def test_apply_config_platform_defaults_skips_when_scheduler_max_num_seqs_is_missing(self):
+        vllm_config = TestNPUPlatform.mock_vllm_config()
+        vllm_config.compilation_config.max_cudagraph_capture_size = None
+        vllm_config.compilation_config.cudagraph_capture_sizes = None
+
+        self.platform.apply_config_platform_defaults(vllm_config)
+
+        self.assertIsNone(vllm_config.compilation_config.max_cudagraph_capture_size)
+
+    @patch("vllm_ascend.platform.refresh_block_size")
+    @patch("vllm_ascend.platform.get_ascend_device_type", return_value=AscendDeviceType.A3)
+    @patch("vllm_ascend.platform.enable_sp", return_value=False)
+    @patch("vllm_ascend.ascend_config.init_ascend_config")
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
+    def test_check_and_update_config_preserves_platform_default_max_input(
+        self,
+        mock_auto_detect,
+        mock_init_ascend,
+        _mock_enable_sp,
+        _mock_device_type,
+        _mock_refresh_block_size,
+    ):
+        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
+        vllm_config = TestNPUPlatform.mock_vllm_config()
+        vllm_config.scheduler_config.max_num_seqs = 77
+        vllm_config.compilation_config.max_cudagraph_capture_size = None
+        vllm_config.compilation_config.cudagraph_capture_sizes = None
+        vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+        vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
+        vllm_config.compilation_config.custom_ops = []
+        vllm_config.model_config.enforce_eager = False
+        vllm_config.model_config.enable_sleep_mode = True
+        vllm_config.model_config.is_encoder_decoder = False
+        vllm_config.parallel_config.decode_context_parallel_size = 1
+        vllm_config.parallel_config.prefill_context_parallel_size = 1
+        vllm_config.parallel_config.tensor_parallel_size = 1
+        vllm_config.parallel_config.worker_cls = "manual"
+        vllm_config.parallel_config.cp_kv_cache_interleave_size = 1
+        vllm_config.cache_config.block_size = 1
+
+        self.platform.apply_config_platform_defaults(vllm_config)
+
+        observed_inputs: list[int | None] = []
+        vllm_config._set_cudagraph_sizes = MagicMock(
+            side_effect=lambda: observed_inputs.append(
+                vllm_config.compilation_config.max_cudagraph_capture_size
+            )
+        )
+
+        self.platform.check_and_update_config(vllm_config)
+
+        self.assertEqual(observed_inputs, [77])
+
     def test_get_device_capability(self):
         self.assertIsNone(self.platform.get_device_capability(device_id=0))
@@ -152,6 +152,49 @@ class NPUPlatform(Platform):
 
         config_deprecated_logging()
 
+    @classmethod
+    def _get_default_max_cudagraph_capture_size(cls, vllm_config: VllmConfig) -> int | None:
+        """Mirror the default-max branch in vLLM's `_set_cudagraph_sizes()`.
+
+        This helper corresponds to the upstream block under
+        "determine the initial max_cudagraph_capture_size" when
+        `compilation_config.max_cudagraph_capture_size is None`.
+
+        Ascend injects this default earlier via `apply_config_platform_defaults()`
+        so the rest of `_set_cudagraph_sizes()` can keep using upstream logic for
+        size-list generation, token-cap clipping, SP filtering, and later
+        post-processing. The only intentional difference from upstream is removing
+        the CUDA-oriented trailing `* 2`: Ascend wants the default capture upper
+        bound to track `max_num_seqs * decode_query_len`, capped at 512.
+
+        Returning `None` means the platform should not inject a default. This
+        covers the cases where the user has already provided either
+        `max_cudagraph_capture_size` or `cudagraph_capture_sizes`.
+        """
+        compilation_config = vllm_config.compilation_config
+        if compilation_config.max_cudagraph_capture_size is not None:
+            return None
+        if compilation_config.cudagraph_capture_sizes is not None:
+            return None
+
+        scheduler_config = getattr(vllm_config, "scheduler_config", None)
+        max_num_seqs = getattr(scheduler_config, "max_num_seqs", None)
+        if max_num_seqs is None:
+            return None
+
+        decode_query_len = 1
+        speculative_config = getattr(vllm_config, "speculative_config", None)
+        if speculative_config and speculative_config.num_speculative_tokens:
+            decode_query_len += speculative_config.num_speculative_tokens
+
+        return min(max_num_seqs * decode_query_len, 512)
+
+    @classmethod
+    def apply_config_platform_defaults(cls, vllm_config: VllmConfig) -> None:
+        default_max_cg_capture_size = cls._get_default_max_cudagraph_capture_size(vllm_config)
+        if default_max_cg_capture_size is not None:
+            vllm_config.compilation_config.max_cudagraph_capture_size = default_max_cg_capture_size
+
     @classmethod
     def get_device_capability(cls, device_id: int = 0):
         return None
@@ -273,7 +316,10 @@ class NPUPlatform(Platform):
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-        # set cudaprah sizes before extending `compilation_config.splitting_ops`
+        # Recompute cudagraph sizes after Ascend-specific compatibility updates.
+        # The platform default max is injected earlier via
+        # `apply_config_platform_defaults`, so this late pass should only honor
+        # the current max / size inputs after the mode adjustments above.
         vllm_config._set_cudagraph_sizes()
         # TODO delete graph size update here when compilation_config.pass_config.enable_sp
         # is supported by vllm-ascend.