diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py
index 22c58369..9a426ec2 100644
--- a/vllm_ascend/attention/attention_cp.py
+++ b/vllm_ascend/attention/attention_cp.py
@@ -549,6 +549,11 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
             query, k_nope, value, **common_kwargs)
+        out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
+                                                            None].expand_as(
+                                                                attn_out)
+        attn_out = torch.where(out_mask, 0, attn_out)
+
         lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
                                                             None].expand_as(
                                                                 attn_lse)