[Bugfix] fix bug when graph_size is not divisible by tp_size (#2719)

### What this PR does / why we need it?
fix https://github.com/vllm-project/vllm-ascend/issues/2702
- A2: skip the graph_size update that rounds it up to a multiple of tp_size,
because the dispatch/combine ops support different batch sizes across EP ranks
- A3: add `max_num_reqs = max(new_graph_batch_sizes)` to fix the mismatch
between graph_size and max_num_reqs (the alignment arithmetic is sketched below)
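
To make the alignment concrete, here is a minimal standalone sketch of the arithmetic the patch gates behind the A3 check: round each captured graph batch size up to a multiple of `tp_size`, or to `lcm(batch_size, tp_size)` when MTP > 1. All values here (`tp_size`, `graph_batch_sizes`, `num_speculative_tokens`) are illustrative, not the runner's real config objects.

```python
import math

tp_size = 4                      # illustrative TP degree
num_speculative_tokens = 2       # MTP > 1 triggers the LCM path
graph_batch_sizes = [6, 10, 16]  # made-up captured sizes

aligned = []
for bs in graph_batch_sizes:
    # Round up to the next multiple of tp_size so every EP rank
    # sees the same input shape for the dispatch/combine op.
    cur = (bs + tp_size - 1) // tp_size * tp_size
    if num_speculative_tokens > 1:
        # With MTP > 1, take lcm(bs, tp_size) instead of a plain round-up.
        cur = tp_size * bs // math.gcd(tp_size, bs)
    if cur not in aligned:
        aligned.append(cur)

print(aligned)  # [12, 20, 16]
```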

### Does this PR introduce _any_ user-facing change?
Nope
### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
e599e2c65e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Author: realliujiaxu
Date: 2025-09-08 14:52:33 +08:00 (committed by GitHub)
Parent: dd087effcc · Commit: d3c3538ddc
2 changed files with 40 additions and 24 deletions

```diff
@@ -41,7 +41,8 @@ from vllm_ascend.torchair.utils import (
     register_torchair_model, torchair_ops_patch,
     torchair_quant_method_register, write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_310p)
+                               is_310p, get_ascend_soc_version,
+                               AscendSocVersion)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -59,7 +60,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()
-        self.check_torchair_graph_batch_sizes()
+        self.update_torchair_graph_batch_sizes()
         torch._dynamo.cache_size.config.cache_size_limit += len(
             self.torchair_graph_batch_sizes)
@@ -397,7 +398,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
         )
 
-    def check_torchair_graph_batch_sizes(self):
+    def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
         if len(self.torchair_graph_batch_sizes) == 0:
@@ -421,27 +422,43 @@ class NPUTorchairModelRunner(NPUModelRunner):
             for graph_batch_size in self.torchair_graph_batch_sizes
         ]
-        # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
+        # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`,
+        # because we use x_active_mask for the dispatch/combine op on A3, which requires the input shape to be the
+        # same on all EP ranks
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
+            self._align_graph_size_divisible_by_tp_size()
+
+    def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
-        if self.parallel_config.enable_expert_parallel:
-            new_graph_batch_sizes = []
-            for graph_batch_size in self.torchair_graph_batch_sizes:
-                cur_graph_batch_size = (graph_batch_size + tp_size -
-                                        1) // tp_size * tp_size
-                # MTP > 1: calculate the LCM (Least Common Multiple) of graph_batch_size and tp_size,
-                # to adapt both multi-dp and the FIA operator
-                if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
-                    cur_graph_batch_size = (tp_size * graph_batch_size) \
-                        // math.gcd(tp_size, graph_batch_size)
-                if cur_graph_batch_size not in new_graph_batch_sizes and \
-                        cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
-                    new_graph_batch_sizes.append(cur_graph_batch_size)
-                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
-                        and self.decode_token_per_req > 1:
-                    logger.warning(
-                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
-                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
-                    )
+        new_graph_batch_sizes = []
+        for graph_batch_size in self.torchair_graph_batch_sizes:
+            cur_graph_batch_size = (graph_batch_size + tp_size -
+                                    1) // tp_size * tp_size
+            # MTP > 1: calculate the LCM (Least Common Multiple) of graph_batch_size and tp_size,
+            # to adapt both multi-dp and the FIA operator
+            if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+                cur_graph_batch_size = (tp_size * graph_batch_size) \
+                    // math.gcd(tp_size, graph_batch_size)
+            if cur_graph_batch_size not in new_graph_batch_sizes and \
+                    cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                new_graph_batch_sizes.append(cur_graph_batch_size)
+            elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                    and self.decode_token_per_req > 1:
+                logger.warning(
+                    f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
+                    f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
+                )
+        new_max_num_reqs = max(new_graph_batch_sizes)
+        if self.max_num_reqs != new_max_num_reqs:
+            logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
+            self.max_num_reqs = new_max_num_reqs
+            self.scheduler_config.max_num_seqs = new_max_num_reqs
         if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
             logger.warning(
                 f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
             )
         self.torchair_graph_batch_sizes = new_graph_batch_sizes
 
     def _build_drafter_prepare_inputs_torchair_param(self):
```
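
As a hedged worked example of why the new `max_num_reqs` sync is needed (numbers made up; the LCM and token-cap branches are omitted): rounding a configured size up to a multiple of `tp_size` can push the largest captured graph size past the old scheduler limit.

```python
tp_size = 4                      # illustrative
max_num_reqs = 30                # scheduler limit before alignment
graph_batch_sizes = [8, 16, 30]

# Round each size up to a multiple of tp_size, dropping duplicates,
# mirroring the simple branch of the A3 alignment above.
new_graph_batch_sizes = sorted(
    {(bs + tp_size - 1) // tp_size * tp_size for bs in graph_batch_sizes})

print(new_graph_batch_sizes)  # [8, 16, 32]
# 30 rounded up to 32, so max(new_graph_batch_sizes) now exceeds
# max_num_reqs; without the sync, the biggest captured graph would have
# no matching request budget. Hence:
max_num_reqs = max(new_graph_batch_sizes)  # 32
```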

```diff
@@ -1905,7 +1905,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         max_query_len = self.uniform_decode_query_len if uniform_decode else \
             num_tokens
-        max_num_reqs = self.scheduler_config.max_num_seqs
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
```
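
The context comment above describes spreading `num_tokens` over the dummy-run requests; a rough sketch of that idea, assuming an even split (the runner's actual distribution logic may differ):

```python
# Illustrative only: spread num_tokens over at most max_num_reqs dummy
# requests so the per-request counts sum to num_tokens.
num_tokens, max_num_reqs = 48, 8
base, rem = divmod(num_tokens, max_num_reqs)
num_scheduled_tokens = [base + (1 if i < rem else 0)
                        for i in range(max_num_reqs)]
assert sum(num_scheduled_tokens) == num_tokens
```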