diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index c52ef569..3047b3d9 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -97,7 +97,6 @@ class TestMtpProposer: proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) assert proposer.use_aclgraph is True - assert proposer.cudagraph_batch_sizes == [1, 2, 4, 8] @patch("vllm.config.get_layers_from_vllm_config") @patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader") diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 4deef9a2..14a61a79 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -107,11 +107,6 @@ class MtpProposer(Proposer): self.use_aclgraph = self.runner._use_aclgraph() - self.cudagraph_batch_sizes = (list( - sorted( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) - if self.use_aclgraph else []) - # persistent buffers for aclgraph graph self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, @@ -697,11 +692,13 @@ class MtpProposer(Proposer): assert self.runner is not None - if self.runner.use_aclgraph and num_scheduled_tokens <= self.cudagraph_batch_sizes[ + # Note(qcs): We may need to refactor this check logic. + if self.runner.use_aclgraph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[ -1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) - elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: + elif self.use_aclgraph and num_tokens <= self.runner.cudagraph_batch_sizes[ + -1]: # Acl graph mode, add padding to the batch size num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: