[Bugfix] Add missing draft_attn_metadatas parameter to fix MTP test (#6232)
### What this PR does / why we need it? Fix the MTP test failure caused by accessing non-existent attribute `forward_context.draft_attn_metadatas`. **Root cause:** In `AscendAttentionBackendImpl.update_graph_params`, the code incorrectly accessed `forward_context.draft_attn_metadatas`, but `ForwardContext` class doesn't have this attribute. The original code passed this value via function parameter. **Fix:** Add `draft_attn_metadatas` parameter to the entire call chain: - `update_full_graph_params` function in `acl_graph.py` - All `update_graph_params` methods in attention backends - Pass the parameter correctly in `eagle_proposer.py` Also applied Gemini's suggestion to make `vllm_config=None` in `AscendAttentionCPImpl.update_graph_params` for API consistency. Related to item 9 in #5463 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This fixes the CI test failure: `test_deepseek_mtp_correctness[True-FULL_DECODE_ONLY-2-wemaster/deepseek_mtp_main_random_bf16]` Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
@@ -379,6 +379,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
vllm_config,
|
vllm_config,
|
||||||
speculative_config=None,
|
speculative_config=None,
|
||||||
num_dcp_pcp_tokens=None,
|
num_dcp_pcp_tokens=None,
|
||||||
|
draft_attn_metadatas=None,
|
||||||
):
|
):
|
||||||
if using_paged_attention(num_tokens, vllm_config):
|
if using_paged_attention(num_tokens, vllm_config):
|
||||||
# Paged Attention update logic
|
# Paged Attention update logic
|
||||||
@@ -436,7 +437,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
# FIA update logic
|
# FIA update logic
|
||||||
if forward_context.is_draft_model:
|
if forward_context.is_draft_model:
|
||||||
graph_params = get_draft_graph_params()
|
graph_params = get_draft_graph_params()
|
||||||
attn_metadata = forward_context.draft_attn_metadatas
|
attn_metadata = draft_attn_metadatas
|
||||||
attn_keys = list(attn_metadata[0].keys())
|
attn_keys = list(attn_metadata[0].keys())
|
||||||
else:
|
else:
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
|
|||||||
@@ -281,9 +281,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
update_stream,
|
update_stream,
|
||||||
forward_context,
|
forward_context,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
vllm_config,
|
vllm_config=None,
|
||||||
speculative_config=None,
|
speculative_config=None,
|
||||||
num_dcp_pcp_tokens=None,
|
num_dcp_pcp_tokens=None,
|
||||||
|
draft_attn_metadatas=None,
|
||||||
):
|
):
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||||
|
|||||||
@@ -292,6 +292,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
vllm_config=None,
|
vllm_config=None,
|
||||||
speculative_config=None,
|
speculative_config=None,
|
||||||
num_dcp_pcp_tokens=None,
|
num_dcp_pcp_tokens=None,
|
||||||
|
draft_attn_metadatas=None,
|
||||||
):
|
):
|
||||||
if forward_context.is_draft_model:
|
if forward_context.is_draft_model:
|
||||||
graph_params = get_draft_graph_params()
|
graph_params = get_draft_graph_params()
|
||||||
|
|||||||
@@ -733,6 +733,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
vllm_config=None,
|
vllm_config=None,
|
||||||
speculative_config=None,
|
speculative_config=None,
|
||||||
num_dcp_pcp_tokens=None,
|
num_dcp_pcp_tokens=None,
|
||||||
|
draft_attn_metadatas=None,
|
||||||
):
|
):
|
||||||
if forward_context.is_draft_model:
|
if forward_context.is_draft_model:
|
||||||
graph_params = get_draft_graph_params()
|
graph_params = get_draft_graph_params()
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ def update_full_graph_params(
|
|||||||
vllm_config,
|
vllm_config,
|
||||||
speculative_config=None,
|
speculative_config=None,
|
||||||
num_dcp_pcp_tokens=None,
|
num_dcp_pcp_tokens=None,
|
||||||
|
draft_attn_metadatas=None,
|
||||||
):
|
):
|
||||||
impl_cls = attn_backend.get_impl_cls()
|
impl_cls = attn_backend.get_impl_cls()
|
||||||
impl_cls.update_graph_params(
|
impl_cls.update_graph_params(
|
||||||
@@ -227,6 +228,7 @@ def update_full_graph_params(
|
|||||||
vllm_config,
|
vllm_config,
|
||||||
speculative_config,
|
speculative_config,
|
||||||
num_dcp_pcp_tokens,
|
num_dcp_pcp_tokens,
|
||||||
|
draft_attn_metadatas,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1184,7 +1184,8 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
|
def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
|
||||||
update_full_graph_params(
|
update_full_graph_params(
|
||||||
self.runner.attn_backend, self.update_stream, forward_context, num_tokens,
|
self.runner.attn_backend, self.update_stream, forward_context, num_tokens,
|
||||||
self.vllm_config, self.vllm_config.speculative_config)
|
self.vllm_config, self.vllm_config.speculative_config,
|
||||||
|
draft_attn_metadatas=draft_attn_metadatas)
|
||||||
|
|
||||||
# padding tensor into desired size
|
# padding tensor into desired size
|
||||||
def _pad_tensor(self, tensor, pad_size):
|
def _pad_tensor(self, tensor, pad_size):
|
||||||
|
|||||||
Reference in New Issue
Block a user