diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 5a8f6a10..768b9082 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -379,6 +379,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         vllm_config,
         speculative_config=None,
         num_dcp_pcp_tokens=None,
+        draft_attn_metadatas=None,
     ):
         if using_paged_attention(num_tokens, vllm_config):
             # Paged Attention update logic
@@ -436,7 +437,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             # FIA update logic
             if forward_context.is_draft_model:
                 graph_params = get_draft_graph_params()
-                attn_metadata = forward_context.draft_attn_metadatas
+                attn_metadata = draft_attn_metadatas
                 attn_keys = list(attn_metadata[0].keys())
             else:
                 graph_params = get_graph_params()
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py
index 04f3d956..f97798f3 100644
--- a/vllm_ascend/attention/context_parallel/attention_cp.py
+++ b/vllm_ascend/attention/context_parallel/attention_cp.py
@@ -281,9 +281,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         update_stream,
         forward_context,
         num_tokens,
-        vllm_config,
+        vllm_config=None,
         speculative_config=None,
         num_dcp_pcp_tokens=None,
+        draft_attn_metadatas=None,
     ):
         graph_params = get_graph_params()
         # FIXME: Behold! We are using a temporary hack here to update the args
diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py
index e0cd7998..a53dfb58 100644
--- a/vllm_ascend/attention/context_parallel/mla_cp.py
+++ b/vllm_ascend/attention/context_parallel/mla_cp.py
@@ -292,6 +292,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
         vllm_config=None,
         speculative_config=None,
         num_dcp_pcp_tokens=None,
+        draft_attn_metadatas=None,
     ):
         if forward_context.is_draft_model:
             graph_params = get_draft_graph_params()
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 4c8831bf..5414ee58 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -733,6 +733,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         vllm_config=None,
         speculative_config=None,
         num_dcp_pcp_tokens=None,
+        draft_attn_metadatas=None,
     ):
         if forward_context.is_draft_model:
             graph_params = get_draft_graph_params()
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 18db774c..ed8673d2 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -218,6 +218,7 @@ def update_full_graph_params(
     vllm_config,
     speculative_config=None,
     num_dcp_pcp_tokens=None,
+    draft_attn_metadatas=None,
 ):
     impl_cls = attn_backend.get_impl_cls()
     impl_cls.update_graph_params(
@@ -227,6 +228,7 @@ def update_full_graph_params(
         vllm_config,
         speculative_config,
         num_dcp_pcp_tokens,
+        draft_attn_metadatas,
     )


diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 0f6e7c3c..505326c8 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1184,7 +1184,8 @@ class EagleProposer(VllmEagleProposer):
     def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
         update_full_graph_params(
             self.runner.attn_backend, self.update_stream, forward_context, num_tokens,
-            self.vllm_config, self.vllm_config.speculative_config)
+            self.vllm_config, self.vllm_config.speculative_config,
+            draft_attn_metadatas=draft_attn_metadatas)

     # padding tensor into desired size
     def _pad_tensor(self, tensor, pad_size):
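
For review context, here is a minimal runnable sketch of the call path this diff changes: the draft attention metadata is now threaded explicitly from the proposer through `update_full_graph_params` into the backend impl, instead of being read from `forward_context.draft_attn_metadatas` inside the impl. The stub classes and the metadata key below are hypothetical stand-ins for the real vLLM-Ascend objects; only the parameter names and the `draft_attn_metadatas` plumbing come from the diff.

```python
# Hypothetical, self-contained stubs; only the signatures and the
# draft_attn_metadatas threading mirror the diff. Real vLLM-Ascend objects
# (attn_backend, forward_context, streams, configs) are simplified stand-ins.

class StubImpl:
    @staticmethod
    def update_graph_params(update_stream, forward_context, num_tokens,
                            vllm_config, speculative_config=None,
                            num_dcp_pcp_tokens=None, draft_attn_metadatas=None):
        # After this diff, the impl receives the draft metadata as an explicit
        # argument rather than reading forward_context.draft_attn_metadatas.
        if forward_context.is_draft_model:
            attn_metadata = draft_attn_metadatas
            attn_keys = list(attn_metadata[0].keys())
            print("draft attention keys:", attn_keys)


class StubBackend:
    @staticmethod
    def get_impl_cls():
        return StubImpl


class StubForwardContext:
    is_draft_model = True


def update_full_graph_params(attn_backend, update_stream, forward_context,
                             num_tokens, vllm_config, speculative_config=None,
                             num_dcp_pcp_tokens=None, draft_attn_metadatas=None):
    # Mirrors vllm_ascend/compilation/acl_graph.py: the new argument is
    # forwarded verbatim to the backend impl.
    impl_cls = attn_backend.get_impl_cls()
    impl_cls.update_graph_params(update_stream, forward_context, num_tokens,
                                 vllm_config, speculative_config,
                                 num_dcp_pcp_tokens, draft_attn_metadatas)


# Proposer-side call, as in EagleProposer._update_full_graph_params. The
# metadata is a sequence whose first entry is a dict (the impl reads its
# keys); the layer-name key below is a made-up placeholder.
update_full_graph_params(
    StubBackend(), update_stream=None, forward_context=StubForwardContext(),
    num_tokens=8, vllm_config=None, speculative_config=None,
    draft_attn_metadatas=[{"model.layers.0.self_attn.attn": object()}],
)
```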