[Refactor] Replace npu_ring_mla with FIA in MLA prefill (#5704)

### What this PR does / why we need it?

**Refactor: Replace npu_ring_mla with FIA in MLA prefill**

This PR refactors the MLA (Multi-head Latent Attention) prefill implementation
by replacing `npu_ring_mla` with `npu_fused_infer_attention_score` (FIA)
operator, unifying the attention backend with the standard attention
implementation.

**Key changes:**

1. **Core prefill refactoring (`mla_v1.py`)**
- Replace `npu_ring_mla` with `npu_fused_infer_attention_score` in
`_forward_prefill` and `_compute_prefill_context`
   - Use TND layout with `softmax_lse_flag=True` for prefill attention
- Use `npu_attention_update` to merge multiple chunk outputs with LSE
(Log-Sum-Exp)
- Change `attn_mask` from `get_final_mla_mask()` to
`get_splitfuse_attn_mask()` for FIA compatibility

2. **Data type handling**
- Add automatic float16 → bfloat16 conversion (FIA with TND layout only
supports bfloat16)
   - Convert output back to original dtype after FIA computation

3. **Metadata optimization**
   - Pre-calculate `actual_seq_lengths_q` in `AscendMLAPrefillMetadata`
- Pre-calculate `chunk_actual_seq_lengths_kv_list` in
`ChunkedContextMetadata`
- Move `torch.cumsum` operations from forward pass to metadata building
phase

4. **CP compatibility (`mla_cp.py`)**
- Add `_ring_mla_mask_builder` to get `npu_ring_mla`-compatible masks
for Context Parallel scenarios
- Add `chunk_actual_seq_lengths_kv_list` field to
`CPChunkedContextMetadata`

**Why we need it:**
- **Backend unification**: Aligns MLA prefill with standard attention
implementation (`attention_v1.py`)
- **Better chunked context support**: FIA + `npu_attention_update`
provides native LSE-based output merging
- **Future compatibility**: Prepares for eventual `npu_ring_mla` removal
across the codebase

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes - same
behavior, unified backend.

---
- Related issue: #5463 (item 7)
- vLLM version: v0.14.1

Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
LICO67373
2026-03-16 10:33:09 +08:00
committed by GitHub
parent e20f0b1a0d
commit 71c21f76f5
6 changed files with 183 additions and 79 deletions

View File

@@ -85,8 +85,8 @@ CASE_DS_FULL_DECODE_ONLY = LLMTestCase(
prompts=PROMPTS_LONG, prompts=PROMPTS_LONG,
golden_answers=[ golden_answers=[
"\n\nSelect an assignment template", "\n\nSelect an assignment template",
"\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use", "\n\nI'm not sure how to approach this problem. I'm thinking that the area of the triangle is $1/2$ times the area",
"\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations", "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x = \\alpha$ be the common root",
], ],
) )
@@ -106,8 +106,8 @@ CASE_DS_EX = LLMTestCase(
prompts=PROMPTS_LONG, prompts=PROMPTS_LONG,
golden_answers=[ golden_answers=[
"\n\nSelect an assignment template", "\n\nSelect an assignment template",
"\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use", "\n\nI'm not sure how to approach this problem. I'm thinking that the area of the triangle is $1/2$ times the area",
"\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations", "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x = \\alpha$ be the common root",
], ],
) )

View File

@@ -130,6 +130,10 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32, dtype=torch.int32,
) )
chunk_actual_seq_lengths_kv_list = [
torch.cumsum(chunk_seq_lens[i], dim=0).tolist()
for i in range(num_chunks)
]
chunked_context_metadata = CPChunkedContextMetadata( chunked_context_metadata = CPChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True), cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
starts=local_chunk_starts.to(non_blocking=True), starts=local_chunk_starts.to(non_blocking=True),
@@ -137,6 +141,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens, chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens, chunk_seq_lens_npu=chunk_seq_lens,
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
workspace=None, workspace=None,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens, padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -500,19 +505,23 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[1], self.impl.v_head_dim) self.assertEqual(result.shape[1], self.impl.v_head_dim)
@patch("torch_npu.atb.npu_paged_cache_load") @patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla") @patch("torch_npu.npu_attention_update")
@patch("torch_npu.npu_fused_infer_attention_score")
@patch_distributed_groups(dcp_size=2, pcp_size=2) @patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp, def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
mock_pcp, mock_ring, mock_pcp, mock_fia,
mock_load): mock_update, mock_load):
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen, def mock_fia_attn(*args, **kwargs):
head_num, kv_head_num, pre_out, prev_lse, qk_scale, q = args[0]
kernel_type, mask_type, input_layout, calc_type, v = args[2]
output, softmax_lse): return (torch.randn(q.shape[0],
return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2]) v.shape[1],
v.shape[2],
dtype=torch.float16),
torch.randn(v.shape[1], q.shape[0], dtype=torch.float16))
mock_ring.side_effect = mock_ring_attn mock_fia.side_effect = mock_fia_attn
def mock_kv_b_proj(kv_c_normed): def mock_kv_b_proj(kv_c_normed):
return (torch.randn(kv_c_normed.shape[0], return (torch.randn(kv_c_normed.shape[0],
@@ -534,6 +543,13 @@ class TestAscendMLAImpl(TestBase):
# mock proj # mock proj
self.impl.kv_b_proj.side_effect = mock_kv_b_proj self.impl.kv_b_proj.side_effect = mock_kv_b_proj
def mock_update_fn(lse_list, out_list, mode):
total_len = out_list[0].shape[0]
D = out_list[0].shape[1]
return (torch.randn(total_len, D, dtype=torch.float32), None)
mock_update.side_effect = mock_update_fn
NUM_BLOCKS, BLOCK_SIZE = 10, 32 # fixed NUM_BLOCKS, BLOCK_SIZE = 10, 32 # fixed
USED_BLOCKS = 3 USED_BLOCKS = 3
# pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp # pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp
@@ -586,8 +602,8 @@ class TestAscendMLAImpl(TestBase):
self.impl.num_heads, self.impl.num_heads,
self.impl.v_head_dim, self.impl.v_head_dim,
dtype=torch.float16) dtype=torch.float16)
prefix_lse = torch.randn(sum(nums_tokens_per_rank), prefix_lse = torch.randn(self.impl.num_heads,
self.impl.num_heads, sum(nums_tokens_per_rank),
dtype=torch.float16) dtype=torch.float16)
chunk_ctx = get_chunk_metadata( chunk_ctx = get_chunk_metadata(
pcp_size, pcp_size,
@@ -602,7 +618,7 @@ class TestAscendMLAImpl(TestBase):
cp_local_block_size=cp_local_block_size) cp_local_block_size=cp_local_block_size)
meta = MagicMock() meta = MagicMock()
prefill_meta = MagicMock() prefill_meta = MagicMock()
prefill_meta.query_lens = nums_tokens_per_rank prefill_meta.query_lens = torch.tensor(nums_tokens_per_rank)
prefill_meta.block_table = torch.randint( prefill_meta.block_table = torch.randint(
0, USED_BLOCKS, (1, 64)) # (batch, max_blocks) 0, USED_BLOCKS, (1, 64)) # (batch, max_blocks)
prefill_meta.chunked_context = chunk_ctx prefill_meta.chunked_context = chunk_ctx
@@ -621,14 +637,14 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(mock_reorg.call_count, self.assertEqual(mock_reorg.call_count,
iters * (1 if dcp_size * pcp_size > 1 else 0)) iters * (1 if dcp_size * pcp_size > 1 else 0))
self.assertEqual(mock_load.call_count, iters) self.assertEqual(mock_load.call_count, iters)
self.assertEqual(mock_ring.call_count, iters) self.assertEqual(mock_fia.call_count, iters)
mock_reorg.reset_mock() mock_reorg.reset_mock()
mock_load.reset_mock() mock_load.reset_mock()
mock_ring.reset_mock() mock_fia.reset_mock()
mock_update.reset_mock()
mock_dcp.reset_mock() mock_dcp.reset_mock()
mock_pcp.reset_mock() mock_pcp.reset_mock()
self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch_distributed_groups(dcp_size=2, pcp_size=2) @patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp, def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,

View File

@@ -102,7 +102,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens=max_seq_lens, max_seq_lens=max_seq_lens,
workspace=workspace, workspace=workspace,
chunk_seq_lens=chunk_seq_lens, chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens) chunk_seq_lens_npu=chunk_seq_lens,
chunk_actual_seq_lengths_kv_list=[[2, 4]])
metadata = AscendMLAPrefillMetadata( metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -886,8 +887,9 @@ class TestAscendMLAImpl(TestBase):
self.assertTrue(torch.equal(prefix_lse, lse)) self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load") @patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla") @patch("torch_npu.npu_attention_update")
def test_compute_prefill_context(self, mock_ring, mock_load): @patch("torch_npu.npu_fused_infer_attention_score")
def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank latent_kv_dim = self.impl.kv_lora_rank
@@ -898,11 +900,16 @@ class TestAscendMLAImpl(TestBase):
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D) kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1] kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128) prefix_out = torch.randn(S, N, VD)
prefix_lse = torch.randn(S, N) prefix_lse = torch.randn(N, S)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), ) self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
# Mock FIA to return output and lse
mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
# Mock attention_update to return merged output
mock_update.return_value = (torch.randn(S * N, VD), None)
chunk_ctx = MagicMock() chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8] chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])] chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
@@ -911,7 +918,7 @@ class TestAscendMLAImpl(TestBase):
prefill_meta = MagicMock() prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8] prefill_meta.query_lens = torch.tensor([S])
prefill_meta.block_table = torch.randint(0, 100, (S, 4)) prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock() meta = MagicMock()
@@ -924,10 +931,10 @@ class TestAscendMLAImpl(TestBase):
prefix_lse) prefix_lse)
mock_load.assert_called_once() mock_load.assert_called_once()
mock_ring.assert_called_once() mock_fia.assert_called_once()
mock_update.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")

View File

@@ -53,6 +53,7 @@ class CPChunkedContextMetadata:
workspace: torch.Tensor workspace: torch.Tensor
chunk_seq_lens: torch.Tensor chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor chunk_seq_lens_npu: torch.Tensor
chunk_actual_seq_lengths_kv_list: list[list[int]]
# for mla DCP & PCP # for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: list[list[int]] | None = None padded_local_chunk_seq_lens: list[list[int]] | None = None

View File

@@ -30,6 +30,7 @@ from vllm_ascend.attention.mla_v1 import (
# isort: on # isort: on
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import ( from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, AscendPCPMetadata,
CPChunkedContextMetadata, CPChunkedContextMetadata,
@@ -189,6 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
max_seq_lens=chunked_context_metadata.max_seq_lens, max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens, chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu, chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list,
workspace=chunked_context_metadata.workspace, workspace=chunked_context_metadata.workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -276,6 +278,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
**kwargs, **kwargs,
) )
# npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask
# TODO: Remove this when mla_cp.py also migrates to FIA
self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu"))
self.pcp_size = get_pcp_group().world_size self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
@@ -484,6 +490,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
# Use ring_mla-specific mask (bfloat16, 512x512)
# TODO: Remove this when mla_cp.py migrates to FIA
ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype)
output_head, lse_head = self._attention_with_mask_and_nomask( output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx), q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -494,7 +504,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_head_nomask_idx, kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens, attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens, attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=attn_metadata.attn_mask, mask=ring_mla_mask,
) )
output_tail, lse_tail = self._attention_with_mask_and_nomask( output_tail, lse_tail = self._attention_with_mask_and_nomask(
@@ -507,7 +517,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_tail_nomask_idx, kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens, attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens, attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=attn_metadata.attn_mask, mask=ring_mla_mask,
) )
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx

View File

@@ -112,6 +112,7 @@ class ChunkedContextMetadata:
workspace: torch.Tensor workspace: torch.Tensor
chunk_seq_lens: torch.Tensor chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor chunk_seq_lens_npu: torch.Tensor
chunk_actual_seq_lengths_kv_list: list[list[int]]
@dataclass @dataclass
@@ -131,6 +132,7 @@ class AscendMLAPrefillMetadata:
sin: torch.Tensor = None sin: torch.Tensor = None
cos: torch.Tensor = None cos: torch.Tensor = None
pcp_metadata: AscendPCPMetadata | None = None pcp_metadata: AscendPCPMetadata | None = None
actual_seq_lengths_q: list[int] | None = None
@dataclass @dataclass
@@ -447,7 +449,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
num_decodes=self.num_decodes, num_decodes=self.num_decodes,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
attn_state=common_attn_metadata.attn_state, attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata, prefill=prefill_metadata,
decode=decode_metadata, decode=decode_metadata,
@@ -486,6 +488,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True) self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32) torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
chunk_actual_seq_lengths_kv_list = [
torch.cumsum(self.chunk_seq_lens[i], dim=0).tolist() for i in range(self.num_chunks)
]
return ChunkedContextMetadata( return ChunkedContextMetadata(
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True), cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
starts=chunk_starts.pin_memory().to(self.device, non_blocking=True), starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
@@ -494,6 +499,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
chunk_seq_lens=self.chunk_seq_lens, chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=self.chunk_seq_lens.npu(), chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace, workspace=self.chunked_prefill_workspace,
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
) )
def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int): def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
@@ -527,9 +533,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
prefill_input_positions = input_positions[tokens_start:] prefill_input_positions = input_positions[tokens_start:]
cos, sin = get_cos_and_sin_mla(prefill_input_positions) cos, sin = get_cos_and_sin_mla(prefill_input_positions)
prefill_query_lens = self.query_lens[reqs_start:].to(torch.int32)
actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist()
return AscendMLAPrefillMetadata( return AscendMLAPrefillMetadata(
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
query_lens=self.query_lens[reqs_start:].to(torch.int32), query_lens=prefill_query_lens,
seq_lens=self.seq_lens, seq_lens=self.seq_lens,
context_lens=self.seq_lens[reqs_start:], context_lens=self.seq_lens[reqs_start:],
input_positions=prefill_input_positions, input_positions=prefill_input_positions,
@@ -540,6 +548,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
sin=sin, sin=sin,
cos=cos, cos=cos,
actual_seq_lengths_q=actual_seq_lengths_q,
) )
def build_decode_metadata( def build_decode_metadata(
@@ -887,8 +896,11 @@ class AscendMLAImpl(MLAAttentionImpl):
post_process_after_loading_for_shard_weight_series(layer) post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] assert self.fused_qkv_a_proj is not None
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] assert self.q_a_layernorm is not None
assert self.kv_a_layernorm is not None
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -990,17 +1002,18 @@ class AscendMLAImpl(MLAAttentionImpl):
return prefix_output, prefix_lse return prefix_output, prefix_lse
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0] cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1] cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2) num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
chunk_outputs = []
chunk_lses = []
for i in range(iters): for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i] toks = prefill_metadata.chunked_context.seq_tot[i]
# chunk_seq_lens will be padded when pcp&dcp
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([current_seq_len, context_seq_len])
context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata) context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device) k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
@@ -1026,27 +1039,61 @@ class AscendMLAImpl(MLAAttentionImpl):
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
mask = attn_metadata.attn_mask actual_seq_lengths_kv = prefill_metadata.chunked_context.chunk_actual_seq_lengths_kv_list[i]
torch_npu.atb.npu_ring_mla(
q_nope=q_nope, chunk_out, chunk_lse = torch_npu.npu_fused_infer_attention_score(
q_rope=q_pe, q_nope,
k_nope=k_nope, k_nope,
k_rope=k_pe, v,
value=v, query_rope=q_pe,
mask=mask, key_rope=k_pe,
seqlen=seq_len, num_heads=self.num_heads,
head_num=self.num_heads, num_key_value_heads=self.num_heads,
kv_head_num=self.num_heads, input_layout="TND",
pre_out=prefix_output, atten_mask=None,
prev_lse=prefix_lse, sparse_mode=0,
qk_scale=self.scale, scale=self.scale,
kernel_type="kernel_type_high_precision", antiquant_mode=0,
mask_type="no_mask", antiquant_scale=None,
input_layout="type_bsnd", softmax_lse_flag=True,
calc_type="calc_type_default", actual_seq_lengths=actual_seq_lengths_q,
output=prefix_output, actual_seq_lengths_kv=actual_seq_lengths_kv,
softmax_lse=prefix_lse,
) )
chunk_outputs.append(chunk_out)
chunk_lses.append(chunk_lse)
if len(chunk_outputs) > 0:
num_tokens = q_nope.size(0)
D = self.v_head_dim
H = self.num_heads
# Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
prefix_output = prefix_output.to(torch.float32)
prefix_lse = prefix_lse.to(torch.float32)
if prefix_lse.dim() == 2:
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
# Concat output and lse: [num_tokens, H, D+1]
all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
chunk_out = chunk_out.to(torch.float32)
chunk_lse = chunk_lse.to(torch.float32)
if chunk_lse.dim() == 2:
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
# Stack and split: [N, num_tokens, H, D+1]
all_out_lse = torch.stack(all_out_lse, dim=0)
N = all_out_lse.size(0)
out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
# Flatten and unbind for npu_attention_update
out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
return output_final.view(num_tokens, H, D), None
return prefix_output, prefix_lse return prefix_output, prefix_lse
def _forward_prefill( def _forward_prefill(
@@ -1062,33 +1109,54 @@ class AscendMLAImpl(MLAAttentionImpl):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1 assert len(kv_c_and_k_pe_cache) > 1
num_tokens = q_nope.size(0) num_tokens = q_nope.size(0)
prefill_meta = attn_metadata.prefill
actual_seq_lengths_q = prefill_meta.actual_seq_lengths_q
actual_seq_lengths_kv = actual_seq_lengths_q.copy()
# FIA with TND layout only supports bfloat16, convert if needed
original_dtype = q_nope.dtype
need_dtype_convert = original_dtype != torch.bfloat16
if need_dtype_convert:
q_nope = q_nope.to(torch.bfloat16)
q_pe = q_pe.to(torch.bfloat16)
k_nope = k_nope.to(torch.bfloat16)
k_pe = k_pe.to(torch.bfloat16)
value = value.to(torch.bfloat16)
attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device) attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope, common_kwargs = {
q_rope=q_pe, "query_rope": q_pe,
k_nope=k_nope, "key_rope": k_pe,
k_rope=k_pe, "num_heads": self.num_heads,
value=value, "num_key_value_heads": self.num_heads,
mask=attn_metadata.attn_mask, "input_layout": "TND",
seqlen=attn_metadata.prefill.query_lens, "atten_mask": prefill_meta.attn_mask,
head_num=self.num_heads, "sparse_mode": 3,
kv_head_num=self.num_heads, "scale": self.scale,
pre_out=None, "antiquant_mode": 0,
prev_lse=None, "antiquant_scale": None,
qk_scale=self.scale, "block_table": None,
kernel_type="kernel_type_high_precision", "block_size": 0,
mask_type="mask_type_triu", "softmax_lse_flag": True,
input_layout="type_bsnd", "actual_seq_lengths": actual_seq_lengths_q,
calc_type="calc_type_first_ring", "actual_seq_lengths_kv": actual_seq_lengths_kv,
output=attn_output, }
softmax_lse=attn_lse,
) attn_output, attn_lse = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, value, **common_kwargs)
attn_output, attn_lse = self._compute_prefill_context( attn_output, attn_lse = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
) )
attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim]) attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])
# Convert back to original dtype if needed
if need_dtype_convert:
attn_output = attn_output.to(original_dtype)
return attn_output return attn_output
def exec_kv_decode( def exec_kv_decode(
@@ -1099,6 +1167,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache: tuple, kv_cache: tuple,
slots: torch.Tensor, slots: torch.Tensor,
): ):
assert self.kv_a_layernorm is not None
B = kv_no_split.shape[0] B = kv_no_split.shape[0]
N = self.num_kv_heads N = self.num_kv_heads
S = 1 S = 1
@@ -1126,6 +1195,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache: tuple, kv_cache: tuple,
slots: torch.Tensor, slots: torch.Tensor,
): ):
assert self.kv_a_layernorm is not None
B = kv_no_split.shape[0] B = kv_no_split.shape[0]
N = self.num_kv_heads N = self.num_kv_heads
S = 1 S = 1