bugfix for mtp fullgraph (#3845)
### What this PR does / why we need it?
bugfix for mtp fullgraph
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -3885,7 +3885,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
|
||||
|
||||
compilation_cases = list(reversed(self.aclgraph_batch_sizes))
|
||||
compilation_cases = sorted(self.aclgraph_batch_sizes)
|
||||
|
||||
try:
|
||||
self._capture_aclgraphs(
|
||||
@@ -3914,14 +3914,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||
aclgraph_mode.separate_routine():
|
||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
||||
self.uniform_decode_query_len
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
|
||||
and x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = list(
|
||||
reversed(decode_cudagraph_batch_sizes))
|
||||
compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
|
||||
self._capture_aclgraphs(
|
||||
compilation_cases=compilation_cases_decode,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
|
||||
Reference in New Issue
Block a user