[Refactor] Replace npu_ring_mla with FIA in MLA prefill (#5704)
### What this PR does / why we need it?

**Refactor: Replace npu_ring_mla with FIA in MLA prefill**

This PR refactors the MLA (Multi-head Latent Attention) prefill implementation by replacing `npu_ring_mla` with the `npu_fused_infer_attention_score` (FIA) operator, unifying the attention backend with the standard attention implementation.

**Key changes:**

1. **Core prefill refactoring (`mla_v1.py`)**
   - Replace `npu_ring_mla` with `npu_fused_infer_attention_score` in `_forward_prefill` and `_compute_prefill_context`
   - Use the TND layout with `softmax_lse_flag=True` for prefill attention
   - Use `npu_attention_update` to merge multiple chunk outputs via their LSE (log-sum-exp) values (see the reference sketch below)
   - Change `attn_mask` from `get_final_mla_mask()` to `get_splitfuse_attn_mask()` for FIA compatibility
2. **Data type handling**
   - Add automatic float16 → bfloat16 conversion (FIA with the TND layout only supports bfloat16)
   - Convert the output back to the original dtype after the FIA computation
3. **Metadata optimization**
   - Pre-calculate `actual_seq_lengths_q` in `AscendMLAPrefillMetadata`
   - Pre-calculate `chunk_actual_seq_lengths_kv_list` in `ChunkedContextMetadata`
   - Move `torch.cumsum` operations from the forward pass to the metadata building phase
4. **CP compatibility (`mla_cp.py`)**
   - Add `_ring_mla_mask_builder` to provide `npu_ring_mla`-compatible masks for Context Parallel scenarios
   - Add a `chunk_actual_seq_lengths_kv_list` field to `CPChunkedContextMetadata`

**Why we need it:**

- **Backend unification**: aligns MLA prefill with the standard attention implementation (`attention_v1.py`)
- **Better chunked-context support**: FIA + `npu_attention_update` provides native LSE-based output merging
- **Future compatibility**: prepares for the eventual removal of `npu_ring_mla` across the codebase

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes: same behavior, unified backend.

---

- Related issue: #5463 (item 7)
- vLLM version: v0.14.1

Signed-off-by: lico67373 <918688502@qq.com>
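The chunked-context path relies on merging partial attention outputs through their log-sum-exp values, which is the step `npu_attention_update` performs natively on NPU. As a point of reference only, here is a minimal pure-PyTorch sketch of that merge rule; the helper name `merge_attn_outputs` and the tensor layout are illustrative assumptions, not code from this PR.

```python
import torch

def merge_attn_outputs(o1: torch.Tensor, lse1: torch.Tensor,
                       o2: torch.Tensor, lse2: torch.Tensor):
    """Combine two partial attention outputs using their log-sum-exp values.

    o1, o2:     [num_tokens, num_heads, head_dim] partial attention outputs
    lse1, lse2: [num_tokens, num_heads]           per-query log-sum-exp of the scores
    """
    lse = torch.logaddexp(lse1, lse2)          # normalizer of the combined softmax
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial result
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of the second partial result
    return w1 * o1 + w2 * o2, lse
```

Because this merge is associative, any number of context chunks can be folded in one at a time, which is how a chunked prefill accumulates its final output.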
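For the data type handling item, a small sketch of the float16 round trip under the stated constraint that FIA's TND layout only accepts bfloat16; `run_fia` is a placeholder callable standing in for the actual operator invocation, not a real function.

```python
import torch

def fia_with_fp16_fallback(query, key, value, run_fia):
    # Upcast float16 inputs for the bfloat16-only kernel, then cast the
    # result back so callers still see their original dtype.
    orig_dtype = query.dtype
    if orig_dtype == torch.float16:
        query, key, value = (t.to(torch.bfloat16) for t in (query, key, value))
    out = run_fia(query, key, value)  # e.g. the npu_fused_infer_attention_score call
    return out.to(orig_dtype) if out.dtype != orig_dtype else out
```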
```diff
@@ -53,6 +53,7 @@ class CPChunkedContextMetadata:
     workspace: torch.Tensor
     chunk_seq_lens: torch.Tensor
     chunk_seq_lens_npu: torch.Tensor
+    chunk_actual_seq_lengths_kv_list: list[list[int]]
     # for mla DCP & PCP
     padded_chunk_seq_lens_npu: torch.Tensor = None
     padded_local_chunk_seq_lens: list[list[int]] | None = None
```
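The hunk above adds `chunk_actual_seq_lengths_kv_list`, which carries per-chunk cumulative KV lengths so the forward pass no longer needs to run `torch.cumsum`. A rough sketch of how such a list could be produced at metadata-build time; the helper name and shapes are illustrative assumptions, not the PR's builder code.

```python
import torch

def build_chunk_actual_seq_lengths_kv(chunk_seq_lens: torch.Tensor) -> list[list[int]]:
    """chunk_seq_lens: [num_chunks, num_reqs] KV length of every request in each chunk.

    Returns running end offsets per chunk, the cumulative format that TND-layout
    kernels expect for their actual_seq_lengths_kv argument.
    """
    return [torch.cumsum(lens, dim=0).tolist() for lens in chunk_seq_lens]

# Example: torch.tensor([[512, 512, 256]]) -> [[512, 1024, 1280]]
```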
```diff
@@ -30,6 +30,7 @@ from vllm_ascend.attention.mla_v1 import (
 # isort: on
 
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata,
     CPChunkedContextMetadata,
```
```diff
@@ -189,6 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 max_seq_lens=chunked_context_metadata.max_seq_lens,
                 chunk_seq_lens=self.chunk_seq_lens,
                 chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
+                chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list,
                 workspace=chunked_context_metadata.workspace,
                 padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
                 padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
```
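For context on how the pre-computed lengths are consumed, here is a hedged sketch of a prefill call to FIA with the TND layout and `softmax_lse_flag=True`. The keyword names follow the public `torch_npu` documentation for `npu_fused_infer_attention_score`, but they and the surrounding shapes and variable names are assumptions here, not the literal code of this PR.

```python
import torch
import torch_npu  # Ascend PyTorch adapter; requires an NPU environment

def prefill_attention(q, k, v, num_heads, scale, attn_mask,
                      actual_seq_lengths_q, actual_seq_lengths_kv):
    # q/k/v are token-packed TND tensors: [total_tokens, num_heads, head_dim].
    # actual_seq_lengths_* are the cumulative end offsets built in the metadata phase.
    attn_out, lse = torch_npu.npu_fused_infer_attention_score(
        q, k, v,
        num_heads=num_heads,
        input_layout="TND",
        scale=scale,
        atten_mask=attn_mask,
        actual_seq_lengths=actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        softmax_lse_flag=True,  # also return per-query log-sum-exp for chunk merging
    )
    return attn_out, lse
```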
```diff
@@ -276,6 +278,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
             **kwargs,
         )
 
+        # npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask
+        # TODO: Remove this when mla_cp.py also migrates to FIA
+        self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu"))
+
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
         self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
```
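The comment in the hunk above contrasts the two mask flavors. Purely as an illustration with plain `torch` (the real masks come from `AttentionMaskBuilder`), and with the value conventions assumed rather than taken from the PR:

```python
import torch

def make_ring_mla_mask(size: int = 512) -> torch.Tensor:
    # npu_ring_mla-style mask: bfloat16, 512x512, additive causal mask where the
    # strictly upper triangle is disabled (-inf convention assumed for illustration).
    mask = torch.triu(torch.full((size, size), float("-inf")), diagonal=1)
    return mask.to(torch.bfloat16)

def make_fia_mask(size: int = 2048) -> torch.Tensor:
    # FIA-style mask: int8, 2048x2048, where 1 marks a masked (future) position.
    return torch.triu(torch.ones(size, size, dtype=torch.int8), diagonal=1)
```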
```diff
@@ -484,6 +490,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
         tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
+        # Use ring_mla-specific mask (bfloat16, 512x512)
+        # TODO: Remove this when mla_cp.py migrates to FIA
+        ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype)
+
         output_head, lse_head = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
```
```diff
@@ -494,7 +504,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=head_attn_nomask_seqlens,
-            mask=attn_metadata.attn_mask,
+            mask=ring_mla_mask,
         )
 
         output_tail, lse_tail = self._attention_with_mask_and_nomask(
```
```diff
@@ -507,7 +517,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=tail_attn_nomask_seqlens,
-            mask=attn_metadata.attn_mask,
+            mask=ring_mla_mask,
         )
 
         q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
```