fix bug when max_seqs=14 in mtp=2 scenario and raise error when cudagraph_capture_sizes can't be an integer multiple of uniform_decode_query_len (#3909)

### What this PR does / why we need it?
1. Revert [bugfix for mtp in
fullgraph](0948483642)
and support it when vllm supports
2. raise error when cudagraph_capture_sizes can't be an integer multiple
of uniform_decode_query_len
3. bugfix when max_num_seqs=14 in mtp=2 scenario

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-10-31 09:25:06 +08:00
committed by GitHub
parent 387ce1cc5b
commit 90aca84e60
4 changed files with 59 additions and 54 deletions

View File

@@ -263,7 +263,6 @@ class NPUPlatform(Platform):
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
update_aclgraph_sizes(vllm_config)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",

View File

@@ -103,10 +103,20 @@ class NPUTorchairModelRunner(NPUModelRunner):
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
# maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
max_num_tokens, tp_size)
self.mc2_tokens_capacity = max_graph_batch_size
if get_ascend_soc_version(
) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
logger.error(
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
)
if get_ascend_soc_version(
) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
logger.error(
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
)
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
@@ -460,6 +470,17 @@ class NPUTorchairModelRunner(NPUModelRunner):
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
)
def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
tp_size):
cur_graph_batch_size = (old_graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
# Both adapter multi-dp and FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
// math.gcd(tp_size, old_graph_batch_size)
return cur_graph_batch_size
def update_torchair_graph_batch_sizes(self):
# return graph_batch_sizes according to the max number of tokens
# first pad according to the number of requests
@@ -501,13 +522,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
tp_size = self.parallel_config.tensor_parallel_size
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
# Both adapter multi-dp and FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * graph_batch_size) \
// math.gcd(tp_size, graph_batch_size)
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
graph_batch_size, tp_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)

View File

@@ -314,13 +314,6 @@ def get_max_hidden_layers(hf_config) -> int:
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
"""Update ACL graph capture sizes based on hardware limitations"""
from vllm.config.compilation import CUDAGraphMode
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)
return
# NOTE: Currently, we can only capture 1800 graphs at most,
# due to the limitation of ACL graph. This number is bounded by
# the number of streams, which is 2048, we save 248 streams
@@ -428,43 +421,25 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)
def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
# the maximum size cudagraph_capture_sizes[0] should be greater or equal than
# (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
from vllm.config.compilation import CUDAGraphMode
compilation_config = vllm_config.compilation_config
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
uniform_decode_query_len = num_speculative_tokens + 1
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_tokens = max_num_seqs * uniform_decode_query_len
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
not all(size % uniform_decode_query_len == 0 for size in original_sizes):
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
if max_num_tokens >= size >= uniform_decode_query_len
]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
elif original_sizes[0] < max_num_tokens:
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
else:
compilation_config.cudagraph_capture_sizes = original_sizes
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
enlarged_sizes = [(num_speculative_tokens + 1) * size
for size in original_sizes]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info(
"Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
else:
compilation_config.cudagraph_capture_sizes = original_sizes
# TODO(wxy): Move to ops module

View File

@@ -3529,8 +3529,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(self.aclgraph_batch_sizes))
reversed(decode_cudagraph_batch_sizes))
if not all(x % self.uniform_decode_query_len == 0
for x in decode_cudagraph_batch_sizes):
raise ValueError(
"In the MTP fullgraph scenario, each graph size must be an integer multiple of "
f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
"For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
)
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,