[v0.18.0][BugFix]Revert the code: Replace npu_ring_mla wit FIA with MLA prefill. (#7961)

This pull request reverts previous changes to switch to FIA and instead
implements npu_ring_mla for MLA prefill operations(#5704 ). The change
streamlines the attention mechanism by removing unnecessary metadata
tracking and updating the underlying NPU operations to use the
ring-based MLA kernel. This adjustment ensures better compatibility and
performance for MLA prefill tasks within the vLLM Ascend backend.

Highlights

- Migration to npu_ring_mla: Replaced the usage of
npu_fused_infer_attention_score (FIA) with npu_ring_mla for MLA prefill
operations across the codebase to improve performance and alignment with
the intended architecture.
- Cleanup of redundant metadata: Removed
chunk_actual_seq_lengths_kv_list and actual_seq_lengths_q from various
metadata structures as they are no longer required for the updated
attention implementation.
- Test suite updates: Updated unit tests in test_mla_cp.py and
test_mla_v1.py to mock npu_ring_mla instead of the deprecated FIA
functions and adjusted test assertions to reflect the new implementation
details.

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2026-04-09 17:00:25 +08:00
committed by GitHub
parent 7c9aa498d6
commit f668ff9ef0
5 changed files with 73 additions and 151 deletions

View File

@@ -30,7 +30,6 @@ from vllm_ascend.attention.mla_v1 import (
# isort: on
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata,
CPChunkedContextMetadata,
@@ -190,7 +189,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list,
workspace=chunked_context_metadata.workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -278,10 +276,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
**kwargs,
)
# npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask
# TODO: Remove this when mla_cp.py also migrates to FIA
self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu"))
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
@@ -490,10 +484,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
# Use ring_mla-specific mask (bfloat16, 512x512)
# TODO: Remove this when mla_cp.py migrates to FIA
ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype)
output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -504,7 +494,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=ring_mla_mask,
mask=attn_metadata.attn_mask,
)
output_tail, lse_tail = self._attention_with_mask_and_nomask(
@@ -517,7 +507,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=ring_mla_mask,
mask=attn_metadata.attn_mask,
)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx