From 379ce599d0de92bdd4faf964d8c11f314f5fd03c Mon Sep 17 00:00:00 2001 From: LICO67373 <110013619+LICO1314@users.noreply.github.com> Date: Wed, 28 Jan 2026 14:41:18 +0800 Subject: [PATCH] [Bugfix] Add missing draft_attn_metadatas parameter to fix MTP test (#6232) ### What this PR does / why we need it? Fix the MTP test failure caused by accessing the non-existent attribute `forward_context.draft_attn_metadatas`. **Root cause:** In `AscendAttentionBackendImpl.update_graph_params`, the code incorrectly accessed `forward_context.draft_attn_metadatas`, but the `ForwardContext` class doesn't have this attribute. The original code passed this value via a function parameter. **Fix:** Add `draft_attn_metadatas` parameter to the entire call chain: - `update_full_graph_params` function in `acl_graph.py` - All `update_graph_params` methods in attention backends - Pass the parameter correctly in `eagle_proposer.py` Also applied Gemini's suggestion to make `vllm_config=None` in `AscendAttentionCPImpl.update_graph_params` for API consistency. Related to item 9 in #5463 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
This fixes the CI test failure: `test_deepseek_mtp_correctness[True-FULL_DECODE_ONLY-2-wemaster/deepseek_mtp_main_random_bf16]` Signed-off-by: lico67373 <918688502@qq.com> --- vllm_ascend/attention/attention_v1.py | 3 ++- vllm_ascend/attention/context_parallel/attention_cp.py | 3 ++- vllm_ascend/attention/context_parallel/mla_cp.py | 1 + vllm_ascend/attention/mla_v1.py | 1 + vllm_ascend/compilation/acl_graph.py | 2 ++ vllm_ascend/spec_decode/eagle_proposer.py | 3 ++- 6 files changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5a8f6a10..768b9082 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -379,6 +379,7 @@ class AscendAttentionBackendImpl(AttentionImpl): vllm_config, speculative_config=None, num_dcp_pcp_tokens=None, + draft_attn_metadatas=None, ): if using_paged_attention(num_tokens, vllm_config): # Paged Attention update logic @@ -436,7 +437,7 @@ class AscendAttentionBackendImpl(AttentionImpl): # FIA update logic if forward_context.is_draft_model: graph_params = get_draft_graph_params() - attn_metadata = forward_context.draft_attn_metadatas + attn_metadata = draft_attn_metadatas attn_keys = list(attn_metadata[0].keys()) else: graph_params = get_graph_params() diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 04f3d956..f97798f3 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -281,9 +281,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): update_stream, forward_context, num_tokens, - vllm_config, + vllm_config=None, speculative_config=None, num_dcp_pcp_tokens=None, + draft_attn_metadatas=None, ): graph_params = get_graph_params() # FIXME: Behold! 
We are using a temporary hack here to update the args diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index e0cd7998..a53dfb58 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -292,6 +292,7 @@ class AscendMlaCPImpl(AscendMLAImpl): vllm_config=None, speculative_config=None, num_dcp_pcp_tokens=None, + draft_attn_metadatas=None, ): if forward_context.is_draft_model: graph_params = get_draft_graph_params() diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 4c8831bf..5414ee58 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -733,6 +733,7 @@ class AscendMLAImpl(MLAAttentionImpl): vllm_config=None, speculative_config=None, num_dcp_pcp_tokens=None, + draft_attn_metadatas=None, ): if forward_context.is_draft_model: graph_params = get_draft_graph_params() diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 18db774c..ed8673d2 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -218,6 +218,7 @@ def update_full_graph_params( vllm_config, speculative_config=None, num_dcp_pcp_tokens=None, + draft_attn_metadatas=None, ): impl_cls = attn_backend.get_impl_cls() impl_cls.update_graph_params( @@ -227,6 +228,7 @@ def update_full_graph_params( vllm_config, speculative_config, num_dcp_pcp_tokens, + draft_attn_metadatas, ) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 0f6e7c3c..505326c8 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1184,7 +1184,8 @@ class EagleProposer(VllmEagleProposer): def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None): update_full_graph_params( self.runner.attn_backend, self.update_stream, forward_context, num_tokens, - 
self.vllm_config, self.vllm_config.speculative_config) + self.vllm_config, self.vllm_config.speculative_config, + draft_attn_metadatas=draft_attn_metadatas) # padding tensor into desired size def _pad_tensor(self, tensor, pad_size):