From f286265791cc6a770ce07a48daf6dab096c0cc3c Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Mon, 30 Jun 2025 16:51:20 +0800 Subject: [PATCH] [BugFix] Address PrefillCacheHit state to fix prefix cache accuracy bug (#1498) When using AscendScheduler with prefix-cache enabled and chunk-prefill disabled, there will be an accuracy problem because there is no branch in mla_v1 to process this scenario. This PR fixes it. Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/attention/mla_v1.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index d9471b2..ad1fd7a 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -758,7 +758,8 @@ class AscendMLAImpl(MLAAttentionImpl): if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output_torch = torch.empty(num_tokens, self.num_heads * self.v_head_dim, @@ -783,7 +784,8 @@ class AscendMLAImpl(MLAAttentionImpl): causal=True) elif attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ]: attn_lse = torch.empty(self.num_heads, num_tokens, @@ -835,13 +837,14 @@ class AscendMLAImpl(MLAAttentionImpl): attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) else: raise RuntimeError( "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" 
) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output = attn_output_torch