From 9c6d0b422c6d24e3d7e895d7a5cf2d5cc5c313d1 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:19:11 +0800 Subject: [PATCH] [v0.11.0-dev][misc]change default capture size for Qwen3-MoE when using full dp (#4205) ### What this PR does / why we need it? This is the dev version of #4199. Currently, the default `cudagraph_capture_size` in vLLM is `[1, 2, 4, 8, 16, 24, ..., max_capture_size]`. However, this is not always the best choice in different situations. This PR changes the default setting when running Qwen3-MoE with a full-DP setup (`dp_size > 1` && `tp_size == 1`), which is usually applied in Large-Scale EP. old: `[1, 2, 4, 8, 16, 24, ..., max_capture_size]` new: `[1, 2, 5, 10, 15, 20, 24, ..., max_capture_size]` This is mainly because the performance of the `_npu_paged_attention` op degrades dramatically with the old settings. We hope to provide better default performance when users do not set a specific `cudagraph_capture_size`. ### Does this PR introduce _any_ user-facing change? The default `cudagraph_capture_size` is modified in the above cases. However, if `cudagraph_capture_size` has already been set by users, this PR won't have any influence on it. ### How was this patch tested? 
- vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: Angazenn --- vllm_ascend/platform.py | 7 +++++- vllm_ascend/utils.py | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 9da4aa6..6bcad2c 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -32,7 +32,8 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p, - update_aclgraph_sizes) + update_aclgraph_sizes, + update_default_aclgraph_sizes) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -182,6 +183,10 @@ class NPUPlatform(Platform): # set cudaprah sizes before extending `compilation_config.splitting_ops` vllm_config._set_cudagraph_sizes() + # There are cases where default cudagraph_capture_sizes are not friendly + # to ascend ops && hardwares. We update these sizes here to improve + # default performance. + update_default_aclgraph_sizes(vllm_config) # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism # is supported by vllm-ascend. if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \ diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 785cf30..8f21ef7 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -319,6 +319,53 @@ def get_max_hidden_layers(hf_config) -> int: return max(layer_counts) +def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool: + """ + Check whether it is vLLM default capture sizes. 
+ """ + + cuda_graph_sizes = vllm_config.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + default_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + + if sorted(default_size_capture_list, reverse=True) == \ + vllm_config.compilation_config.cudagraph_capture_sizes: + return True + + return False + + +def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: + """ + Update ACL graph default capture sizes, so that new sizes + are more friendly to ascend ops and hardware. + """ + + if vllm_config.model_config is None or \ + vllm_config.model_config.enforce_eager or \ + not _is_default_capture_sizes(vllm_config): + return + + # modify the default capture_sizes for Qwen3-MoE models on dp settings. + # this is mainly because performance of _npu_paged_attention might degrade + # on special shapes. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. + if vllm_config.model_config and vllm_config.model_config.hf_config.model_type == "qwen3_moe" \ + and vllm_config.parallel_config.tensor_parallel_size == 1 \ + and vllm_config.parallel_config.data_parallel_size > 1 : + max_capture_size = vllm_config.scheduler_config.cuda_graph_sizes[0] + new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [ + i for i in range(24, max_capture_size + 1, 8) + ] + + vllm_config.compilation_config.cudagraph_capture_sizes = new_cudagraph_capture_sizes + vllm_config.compilation_config.init_with_cudagraph_sizes( + new_cudagraph_capture_sizes) + + +def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: """Update ACL graph capture sizes based on hardware limitations""" # NOTE: Currently, we can only capture 1800 graphs at most,