From d3c3538ddc67ba8f4873637e2bc1052f9eb09e93 Mon Sep 17 00:00:00 2001
From: realliujiaxu <72550220+realliujiaxu@users.noreply.github.com>
Date: Mon, 8 Sep 2025 14:52:33 +0800
Subject: [PATCH] [Bugfix] fix bug when graph_size is not divisible by tp_size
 (#2719)

### What this PR does / why we need it?
Fixes https://github.com/vllm-project/vllm-ascend/issues/2702
- A2: skip the graph_size update that aligns it to `tp_size`, because the dispatch/combine ops support different batch sizes across EP ranks
- A3: add `max_num_reqs = max(new_graph_batch_sizes)` to fix the mismatch between graph_size and max_num_reqs

### Does this PR introduce _any_ user-facing change?
Nope

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/e599e2c65ee32abcc986733ab0a55becea158bb4

---------

Signed-off-by: realliujiaxu
---
 vllm_ascend/torchair/torchair_model_runner.py | 63 ++++++++++++-------
 vllm_ascend/worker/model_runner_v1.py         |  1 -
 2 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 71315b1..d730c44 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -41,7 +41,8 @@ from vllm_ascend.torchair.utils import (
     register_torchair_model, torchair_ops_patch, torchair_quant_method_register,
     write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_310p)
+                               is_310p, get_ascend_soc_version,
+                               AscendSocVersion)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -59,7 +60,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()

-        self.check_torchair_graph_batch_sizes()
+        self.update_torchair_graph_batch_sizes()

         torch._dynamo.cache_size.config.cache_size_limit += len(
             self.torchair_graph_batch_sizes)
@@ -397,7 +398,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
                 f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
             )

-    def check_torchair_graph_batch_sizes(self):
+    def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
         if len(self.torchair_graph_batch_sizes) == 0:
@@ -421,27 +422,43 @@ class NPUTorchairModelRunner(NPUModelRunner):
             for graph_batch_size in self.torchair_graph_batch_sizes
         ]

-        # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
+        # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`,
+        # because we use x_active_mask for the dispatch/combine op on A3, which requires that the input shape
+        # be the same on all EP ranks
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
+            self._align_graph_size_divisible_by_tp_size()
+
+    def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
-        if self.parallel_config.enable_expert_parallel:
-            new_graph_batch_sizes = []
-            for graph_batch_size in self.torchair_graph_batch_sizes:
-                cur_graph_batch_size = (graph_batch_size + tp_size -
-                                        1) // tp_size * tp_size
-                # MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
-                # Both adapter multi-dp and FIA operator
-                if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
-                    cur_graph_batch_size = (tp_size * graph_batch_size) \
-                        // math.gcd(tp_size, graph_batch_size)
-                if cur_graph_batch_size not in new_graph_batch_sizes and \
-                    cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
-                    new_graph_batch_sizes.append(cur_graph_batch_size)
-                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
-                        and self.decode_token_per_req > 1:
-                    logger.warning(
-                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
-                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
-                    )
+        new_graph_batch_sizes = []
+        for graph_batch_size in self.torchair_graph_batch_sizes:
+            cur_graph_batch_size = (graph_batch_size + tp_size -
+                                    1) // tp_size * tp_size
+            # MTP > 1: compute the LCM (least common multiple) of graph_batch_size and tp_size,
+            # to adapt both multi-dp and the FIA operator
+            if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+                cur_graph_batch_size = (tp_size * graph_batch_size) \
+                    // math.gcd(tp_size, graph_batch_size)
+            if cur_graph_batch_size not in new_graph_batch_sizes and \
+                cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                new_graph_batch_sizes.append(cur_graph_batch_size)
+            elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                    and self.decode_token_per_req > 1:
+                logger.warning(
+                    f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens "
+                    f"{self.scheduler_config.max_num_batched_tokens}, will skip this batch size."
+                )
+        new_max_num_reqs = max(new_graph_batch_sizes)
+        if self.max_num_reqs != new_max_num_reqs:
+            logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
+            self.max_num_reqs = new_max_num_reqs
+            self.scheduler_config.max_num_seqs = new_max_num_reqs
+
+        if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
+            logger.warning(
+                f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
+            )
         self.torchair_graph_batch_sizes = new_graph_batch_sizes

     def _build_drafter_prepare_inputs_torchair_param(self):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index dee3731..b80fa94 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1905,7 +1905,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         max_query_len = self.uniform_decode_query_len if uniform_decode else \
             num_tokens

-        max_num_reqs = self.scheduler_config.max_num_seqs
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
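
---

Reviewer note: below is a minimal standalone sketch of the batch-size alignment the patch applies on A3. The helper name `align_graph_batch_sizes` and the sample numbers are illustrative only and not part of the vllm-ascend API; the real logic lives in `_align_graph_size_divisible_by_tp_size` above, which also consults `speculative_config` and `decode_token_per_req`.

```python
import math
from typing import List


def align_graph_batch_sizes(graph_batch_sizes: List[int],
                            tp_size: int,
                            max_num_batched_tokens: int,
                            num_speculative_tokens: int = 0) -> List[int]:
    """Align captured graph batch sizes so each one is divisible by tp_size."""
    aligned: List[int] = []
    for size in graph_batch_sizes:
        if num_speculative_tokens > 1:
            # MTP > 1: the LCM keeps the result divisible by both tp_size and size.
            cur = tp_size * size // math.gcd(tp_size, size)
        else:
            # Round up to the next multiple of tp_size.
            cur = (size + tp_size - 1) // tp_size * tp_size
        # Drop duplicates and sizes that exceed the token budget.
        if cur <= max_num_batched_tokens and cur not in aligned:
            aligned.append(cur)
    return aligned


if __name__ == "__main__":
    sizes = [1, 8, 24, 48]
    aligned = align_graph_batch_sizes(sizes, tp_size=16,
                                      max_num_batched_tokens=256)
    print(aligned)       # [16, 32, 48]
    # The A3 part of the fix: the request cap has to follow the aligned sizes,
    # otherwise graph_size and max_num_reqs disagree.
    max_num_reqs = max(aligned)
    print(max_num_reqs)  # 48
```

The last two lines mirror why the patch updates `max_num_reqs` and `scheduler_config.max_num_seqs` from `max(new_graph_batch_sizes)`: once graph sizes are rounded up to multiples of `tp_size`, a request count that the scheduler still allows might otherwise have no matching captured graph.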