[BugFix] Address PrefillCacheHit state to fix prefix cache accuracy bug (#1498)
When use AscendScheduler with prefix-cache enabled and chunk-prefill disabled, there will be accuray problem because there is no branch in mla_v1 to process this scenario. This PR fixes it. Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -758,7 +758,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.SpecDecoding
|
AscendAttentionState.SpecDecoding,
|
||||||
|
AscendAttentionState.PrefillCacheHit
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
] and not ascend_config.chunked_prefill_for_mla:
|
||||||
attn_output_torch = torch.empty(num_tokens,
|
attn_output_torch = torch.empty(num_tokens,
|
||||||
self.num_heads * self.v_head_dim,
|
self.num_heads * self.v_head_dim,
|
||||||
@@ -783,7 +784,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
causal=True)
|
causal=True)
|
||||||
elif attn_metadata.attn_state in [
|
elif attn_metadata.attn_state in [
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.SpecDecoding
|
AscendAttentionState.SpecDecoding,
|
||||||
|
AscendAttentionState.PrefillCacheHit
|
||||||
]:
|
]:
|
||||||
attn_lse = torch.empty(self.num_heads,
|
attn_lse = torch.empty(self.num_heads,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -835,13 +837,14 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.SpecDecoding
|
AscendAttentionState.SpecDecoding,
|
||||||
|
AscendAttentionState.PrefillCacheHit
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
] and not ascend_config.chunked_prefill_for_mla:
|
||||||
attn_output = attn_output_torch
|
attn_output = attn_output_torch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user