From adadd50613747e8bb9a7d6a961b26fe080f41f77 Mon Sep 17 00:00:00 2001
From: zouyida2052
Date: Wed, 29 Oct 2025 23:50:13 +0800
Subject: [PATCH] bugfix for mtp fullgraph (#3845)

### What this PR does / why we need it?
Fix ACL graph capture for MTP (multi-token prediction) in full graph mode.
`update_aclgraph_sizes` is now also called when full-graph capture is enabled,
and when `num_speculative_tokens > 1` the capture sizes are adjusted so that
each size is a multiple of the uniform decode query length
(`num_speculative_tokens + 1`) and the largest size can hold a full
speculative decode batch; otherwise the draft model would run in eager mode.
The decode-size filtering previously done in `NPUModelRunner` moves into
`_update_spec_aclgraph_sizes`, and compilation cases are now sorted in
ascending order before capture.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

Signed-off-by: zouyida2052
---
 vllm_ascend/platform.py               |  2 +
 vllm_ascend/utils.py                  | 65 +++++++++++++++++++--------
 vllm_ascend/worker/model_runner_v1.py | 11 +----
 3 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d7550bf1..0c49d956 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -306,6 +306,7 @@ class NPUPlatform(Platform):
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
+            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -343,6 +344,7 @@ class NPUPlatform(Platform):
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
+            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index fffb325b..d591f2b9 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -349,6 +349,12 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
+    from vllm.config.compilation import CUDAGraphMode
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        if vllm_config.speculative_config is not None and \
+                vllm_config.speculative_config.num_speculative_tokens > 1:
+            _update_spec_aclgraph_sizes(vllm_config)
+        return
     # NOTE: Currently, we can only capture 1800 graphs at most,
     # due to the limitation of ACL graph. This number is bounded by
     # the number of streams, which is 2048, we save 248 streams
@@ -459,28 +465,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0],
         num_hidden_layers, len(original_sizes))
 
+    if vllm_config.speculative_config is not None and \
+            vllm_config.speculative_config.num_speculative_tokens > 1:
+        _update_spec_aclgraph_sizes(vllm_config)
+
+
+def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
     # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
     # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-        original_sizes, compilation_config.cudagraph_capture_sizes = \
-            compilation_config.cudagraph_capture_sizes, None
-        assert len(original_sizes) > 0
-        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
-            enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
-            if vllm_version_is("0.11.0"):
-                compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
-            else:
-                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-            logger.info(
-                "Adjusted ACL graphs: %s → %s for speculative decoding",
-                original_sizes, enlarged_sizes)
+    from vllm.config.compilation import CUDAGraphMode
+    compilation_config = vllm_config.compilation_config
+    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+    uniform_decode_query_len = num_speculative_tokens + 1
+    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+    max_num_tokens = max_num_seqs * uniform_decode_query_len
+    original_sizes, compilation_config.cudagraph_capture_sizes = \
+        compilation_config.cudagraph_capture_sizes, None
+
+    assert len(original_sizes) > 0
+
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
+            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+            if size >= uniform_decode_query_len and size *
+            uniform_decode_query_len <= max_num_tokens
+        ]
+        if vllm_version_is("0.11.0"):
+            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
         else:
-            compilation_config.cudagraph_capture_sizes = original_sizes
+            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    elif original_sizes[0] < max_num_tokens:
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+        ]
+        if vllm_version_is("0.11.0"):
+            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+        else:
+            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    else:
+        compilation_config.cudagraph_capture_sizes = original_sizes
 
 
 # TODO(wxy): Move to ops module
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5e13da53..efdd7c10 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -3885,7 +3885,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
 
-            compilation_cases = list(reversed(self.aclgraph_batch_sizes))
+            compilation_cases = sorted(self.aclgraph_batch_sizes)
 
             try:
                 self._capture_aclgraphs(
@@ -3914,14 +3914,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 aclgraph_mode.separate_routine():
-            max_num_tokens = self.scheduler_config.max_num_seqs * \
-                self.uniform_decode_query_len
-            decode_cudagraph_batch_sizes = [
-                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
-                and x >= self.uniform_decode_query_len
-            ]
-            compilation_cases_decode = list(
-                reversed(decode_cudagraph_batch_sizes))
+            compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
             self._capture_aclgraphs(
                 compilation_cases=compilation_cases_decode,
                 aclgraph_runtime_mode=CUDAGraphMode.FULL,
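
For reviewers, a minimal standalone sketch of the sizing rule that `_update_spec_aclgraph_sizes` implements. The function name and values below are hypothetical stand-ins: the real code reads its inputs from `VllmConfig`, updates it in place, and checks `CUDAGraphMode.FULL_DECODE_ONLY` on the compilation config. The rule itself mirrors the patch: with MTP, each decode step verifies `num_speculative_tokens + 1` tokens per sequence, so capture sizes must be whole multiples of that length and the largest size must cover `max_num_seqs` full steps, or the draft model runs in eager mode.

```python
# Sketch of the capture-size adjustment rule from _update_spec_aclgraph_sizes.
# adjust_capture_sizes and its arguments are illustrative, not the real API.

def adjust_capture_sizes(original_sizes: list[int],
                         num_speculative_tokens: int,
                         max_num_seqs: int,
                         full_decode_only: bool) -> list[int]:
    # With MTP, each decode step handles num_speculative_tokens + 1 tokens
    # per sequence, so graph sizes are counted in these "uniform" units.
    uniform_decode_query_len = num_speculative_tokens + 1
    max_num_tokens = max_num_seqs * uniform_decode_query_len

    if full_decode_only and not all(
            size % uniform_decode_query_len == 0 for size in original_sizes):
        # FULL_DECODE_ONLY: every captured size must be a whole number of
        # decode steps; scale each size and drop anything out of range.
        return [
            size * uniform_decode_query_len for size in original_sizes
            if size >= uniform_decode_query_len
            and size * uniform_decode_query_len <= max_num_tokens
        ]
    if original_sizes[0] < max_num_tokens:
        # Largest captured size (sizes are listed largest-first) cannot hold
        # a full speculative decode batch, so the draft model would fall back
        # to eager mode; enlarge every size instead.
        return [size * uniform_decode_query_len for size in original_sizes]
    # Sizes already fit: keep them unchanged.
    return original_sizes


# Hypothetical example: 2 speculative tokens -> 3-token decode steps,
# at most 8 sequences -> 24 tokens per full decode batch.
print(adjust_capture_sizes([8, 4, 2, 1], 2, 8, full_decode_only=True))
# -> [24, 12]: sizes 2 and 1 are below one 3-token step and are dropped.
```

The same reasoning explains the `model_runner_v1.py` change: the in-range filtering previously done in the runner now happens once in `_update_spec_aclgraph_sizes`, so the runner simply captures `self.aclgraph_batch_sizes` in ascending (`sorted`) order.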