mfix bug when max_seqs=14 in mtp=2 scenario and raise error when cudagraph_capture_sizes can't be an integer multiple of uniform_decode_query_lentp (#3910)
### What this PR does / why we need it? 1. Revert [bugfix for mtp in fullgraph](0948483642) and support it when vllm supports 2. raise error when cudagraph_capture_sizes can't be an integer multiple of uniform_decode_query_len 3. bugfix when max_num_seqs=14 in mtp=2 scenario ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main:83f478bb19--------- Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -4035,7 +4035,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||
aclgraph_mode.separate_routine():
|
||||
compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
|
||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
||||
self.uniform_decode_query_len
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
|
||||
and x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
|
||||
# TODO: refactor this when vLLM supports mtp>1
|
||||
if not all(x % self.uniform_decode_query_len == 0
|
||||
for x in decode_cudagraph_batch_sizes):
|
||||
raise ValueError(
|
||||
"In the MTP fullgraph scenario, each graph size must be an integer multiple of "
|
||||
f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
|
||||
f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
|
||||
f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
|
||||
"For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
|
||||
)
|
||||
self._capture_aclgraphs(
|
||||
compilation_cases=compilation_cases_decode,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
|
||||
Reference in New Issue
Block a user