From 153eeaa6212503bc49cba439e04bf1c8ea71a7b1 Mon Sep 17 00:00:00 2001 From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com> Date: Wed, 17 Dec 2025 09:20:44 +0800 Subject: [PATCH] [Bugfix] Fix DeepSeek FIA error in async_scheduling with mtp (#5046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? When async_scheduling is enabled in a large-scale EP scenario, the MTP module falls back to eagle mode, which results in a mismatch of seq_lens_list and block_table. So the check before the draft-model forward pass is adapted accordingly. fix #4986 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: hust17yixuan <303660421@qq.com> --- vllm_ascend/spec_decode/mtp_proposer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 8cb46fa2..a152aa47 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -725,7 +725,6 @@ class MtpProposer(Proposer): has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 aclgraph_runtime_mode, batch_descriptor = \ self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) - original_aclgraph_runtime_mode = aclgraph_runtime_mode if self.use_async_scheduling: # there is synchronization between mtp steps when enabling aclgraph, # disable aclgraph when use async scheduling to avoid the @@ -779,8 +778,8 @@ class MtpProposer(Proposer): hidden_states = torch.ops.vllm.maybe_pad_and_reduce( hidden_states) - if original_aclgraph_runtime_mode == CUDAGraphMode.FULL and \ - self.use_async_scheduling and attn_metadata[layer_name].decode is not None: + if self.use_async_scheduling and attn_metadata[ 
layer_name].decode is not None: for layer_name in self.attn_layer_name: actual_size = len(attn_metadata[layer_name].decode. actual_seq_lengths_q)