fix bug when max_seqs=14 in mtp=2 scenario and raise error when cudagraph_capture_sizes can't be an integer multiple of uniform_decode_query_len (#3909)
### What this PR does / why we need it?
1. Revert [bugfix for mtp in
fullgraph](0948483642)
and support it when vllm supports
2. raise error when cudagraph_capture_sizes can't be an integer multiple
of uniform_decode_query_len
3. bugfix when max_num_seqs=14 in mtp=2 scenario
---------
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -314,13 +314,6 @@ def get_max_hidden_layers(hf_config) -> int:
|
||||
|
||||
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
"""Update ACL graph capture sizes based on hardware limitations"""
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
if vllm_config.speculative_config is not None and \
|
||||
vllm_config.speculative_config.num_speculative_tokens > 1:
|
||||
_update_spec_aclgraph_sizes(vllm_config)
|
||||
return
|
||||
|
||||
# NOTE: Currently, we can only capture 1800 graphs at most,
|
||||
# due to the limitation of ACL graph. This number is bounded by
|
||||
# the number of streams, which is 2048, we save 248 streams
|
||||
@@ -428,43 +421,25 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
vllm_config.model_config.architectures[0], num_hidden_layers,
|
||||
len(original_sizes))
|
||||
|
||||
if vllm_config.speculative_config is not None and \
|
||||
vllm_config.speculative_config.num_speculative_tokens > 1:
|
||||
_update_spec_aclgraph_sizes(vllm_config)
|
||||
|
||||
|
||||
def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
# default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
|
||||
# the maximum size cudagraph_capture_sizes[0] should be greater or equal than
|
||||
# (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
compilation_config = vllm_config.compilation_config
|
||||
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
||||
uniform_decode_query_len = num_speculative_tokens + 1
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_tokens = max_num_seqs * uniform_decode_query_len
|
||||
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||
compilation_config.cudagraph_capture_sizes, None
|
||||
assert len(original_sizes) > 0
|
||||
|
||||
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
|
||||
not all(size % uniform_decode_query_len == 0 for size in original_sizes):
|
||||
enlarged_sizes = [
|
||||
size * uniform_decode_query_len for size in original_sizes
|
||||
if max_num_tokens >= size >= uniform_decode_query_len
|
||||
]
|
||||
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
|
||||
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
|
||||
original_sizes, enlarged_sizes)
|
||||
elif original_sizes[0] < max_num_tokens:
|
||||
enlarged_sizes = [
|
||||
size * uniform_decode_query_len for size in original_sizes
|
||||
]
|
||||
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
|
||||
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
|
||||
original_sizes, enlarged_sizes)
|
||||
else:
|
||||
compilation_config.cudagraph_capture_sizes = original_sizes
|
||||
if vllm_config.speculative_config is not None and \
|
||||
vllm_config.speculative_config.num_speculative_tokens > 1:
|
||||
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||
compilation_config.cudagraph_capture_sizes, None
|
||||
assert len(original_sizes) > 0
|
||||
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
|
||||
enlarged_sizes = [(num_speculative_tokens + 1) * size
|
||||
for size in original_sizes]
|
||||
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
|
||||
logger.info(
|
||||
"Adjusted ACL graphs: %s → %s for speculative decoding",
|
||||
original_sizes, enlarged_sizes)
|
||||
else:
|
||||
compilation_config.cudagraph_capture_sizes = original_sizes
|
||||
|
||||
|
||||
# TODO(wxy): Move to ops module
|
||||
|
||||
Reference in New Issue
Block a user