[CP&SP] Integrate FIA operator in mla_cp._forward_decode (#5641)

### What this PR does / why we need it?
Replace the `npu_multi_head_latent_attention` operator with the FIA (fused infer attention) operator `npu_fused_infer_attention_score` in `_forward_decode` of `mla_cp.py`, and adjust `update_mla_attn_dcp_pcp_params` in `acl_graph.py` to match the new operator's parameter list.
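
For illustration, a minimal sketch of the new decode-path call, mirroring the kwargs this diff introduces (the helper name `fia_decode` is hypothetical, `torch_npu` is Ascend-only, and optional arguments such as `actual_seq_lengths` are omitted; check the torch_npu docs for the authoritative signature):

```python
def fia_decode(q_nope, q_pe, k_nope, k_pe, block_table, cp_seq_len,
               num_heads, num_kv_heads, scale, block_size):
    """Hypothetical helper mirroring the new FIA call in _forward_decode.

    q_nope/k_nope carry the latent ("nope") halves; the rotary halves go in
    via query_rope/key_rope. k_nope is passed as both key and value because
    MLA keeps a single latent KV cache.
    """
    import torch_npu  # Ascend-only dependency

    attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
        q_nope, k_nope, k_nope,          # query, key, value
        query_rope=q_pe, key_rope=k_pe,  # rotary position parts
        num_heads=num_heads, num_key_value_heads=num_kv_heads,
        input_layout="BNSD", scale=scale,
        antiquant_mode=0, antiquant_scale=None,
        block_table=block_table, block_size=block_size,
        actual_seq_lengths_kv=cp_seq_len,  # per-request KV lengths (a list)
        softmax_lse_flag=True)             # also return the log-sum-exp
    return attn_output, softmax_lse
```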

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Signed-off-by: tongyuzhou <t00886357@china.huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tongyuzhou <t00886357@china.huawei.com>
Author: Bai Yongbin
Date: 2026-01-22 20:02:30 +08:00 (committed by GitHub)
Parent: 88632cf976
Commit: 7f91ac2649
4 changed files with 123 additions and 81 deletions


```diff
@@ -450,11 +450,11 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

     @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
-    @patch("torch_npu.atb.npu_multi_head_latent_attention")
+    @patch("torch_npu.npu_fused_infer_attention_score")
     @patch('torch_npu.npu_attention_update')
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
-                                    mock_npu_multi_head_latent_attention,
+                                    mock_npu_fused_infer_attention_score,
                                     mock_get_forward_context):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
```
```diff
@@ -470,8 +470,8 @@ class TestAscendMLAImpl(TestBase):
         q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
         q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
-        k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
-        k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
+        k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)
         attn_metadata = MagicMock()
         attn_metadata.attn_state = AscendAttentionState.SpecDecoding
```
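
The `k_nope`/`k_pe` fixtures switch from a `(num_blocks, block_size, num_kv_heads, head_dim)` layout to `(num_blocks, num_kv_heads, block_size, head_dim)`, matching the BNSD cache views in the new kernel call. A pure-torch sketch of the swap (sizes illustrative):

```python
import torch

NB, BS = 2, 128        # num_blocks, block_size (illustrative)
kv_lora_rank = 512     # latent head dim (illustrative)

# Old fixture layout: (num_blocks, block_size, num_kv_heads, head_dim)
k_nope_old = torch.randn(NB, BS, 1, kv_lora_rank)

# New fixture layout expected by the BNSD cache view:
# (num_blocks, num_kv_heads, block_size, head_dim)
k_nope_new = k_nope_old.permute(0, 2, 1, 3).contiguous()
assert k_nope_new.shape == (NB, 1, BS, kv_lora_rank)
```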
```diff
@@ -485,7 +485,7 @@ class TestAscendMLAImpl(TestBase):
         mock_npu_attention_update.return_value = (torch.randn(
             B, self.impl.num_heads, self.impl.kv_lora_rank), None)
-        mock_npu_multi_head_latent_attention.return_value = [
+        mock_npu_fused_infer_attention_score.return_value = [
             torch.randn(B, N, self.impl.kv_lora_rank),
             torch.randn(B, N, 1)
         ]
```


```diff
@@ -754,7 +754,7 @@ class TestPCPDCPGraphParams(TestBase):
     @patch('torch.npu.graph_task_update_end', )
     @patch('torch.npu.graph_task_update_begin', MagicMock())
-    @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
+    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
     def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
         input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
         block_table = torch.zeros(2, 5, dtype=torch.long)
```
```diff
@@ -793,16 +793,20 @@ class TestPCPDCPGraphParams(TestBase):
         qk_rope_head_dim = 32
         qk_nope_head_dim = 64
         query = torch.randn(4, num_heads, qk_head_dim)
+        q_pe = query[..., qk_nope_head_dim:]
         q_nope = query[..., :qk_nope_head_dim]
-        q_pe = query[..., qk_rope_head_dim:]
         k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
         k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
+        input_layout = "BNSD"
+        actual_seq_lengths_kv = [1, 1]
         out = torch.randn(2, 16, 128)
         lse = torch.randn(2, 16, 8)
         self.graph_params.attn_params[4] = []
         self.graph_params.attn_params[4].append(
-            (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
-             scale, num_kv_heads, out, lse))
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
+             out, lse))
         with patch("torch_npu._C._npu_setStream", return_value=None):
             update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
```
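
The tuple stored in `graph_params.attn_params` grows from 11 to 16 fields and must unpack in exactly this order on the `acl_graph.py` side. A runnable sketch of the pairing (dummy stand-in values; the real ones come from `_forward_decode`):

```python
import torch

# Dummy stand-ins; the real tensors come from _forward_decode.
q = k = qp = kp = bt = o = l = torch.empty(0)

# 16-field record in the exact order the diff stores it:
param = (q, k, qp, kp, 16, 1, "BNSD", None, 0, 0.1, bt, 128, None, [1, 1], o, l)

# Mirror of the unpacking in update_mla_attn_dcp_pcp_params:
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
 spec_attn_mask, sparse_mode, scale, block_table, block_size,
 actual_seq_lengths, actual_seq_lengths_kv, attn_output, softmax_lse) = param
```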


```diff
@@ -14,6 +14,8 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
 # isort: off
 from vllm_ascend.attention.mla_v1 import (
     AscendMLADecodeMetadata,
```
```diff
@@ -244,8 +246,12 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
            self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True)
            batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]]
        cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
-        decode_metadata.cp_seq_len = cp_seq_len
+        decode_metadata.cp_seq_len = cp_seq_len.tolist()
        decode_metadata.batch_seq_mask = batch_seq_mask
+        actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
+        decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q
        return decode_metadata
```
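
If I read the TND convention correctly (an assumption, not stated in the diff), `actual_seq_lengths_q` holds the cumulative end offset of each request's query tokens in the flattened batch, so `arange(n) + 1` encodes one decode token per request:

```python
import torch

num_decodes_flatten = 4  # illustrative: four flattened decode positions
actual_seq_lengths_q = torch.arange(num_decodes_flatten) + 1
print(actual_seq_lengths_q.tolist())  # [1, 2, 3, 4]
```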
```diff
@@ -535,18 +541,53 @@ class AscendMlaCPImpl(AscendMLAImpl):
             num_heads = self.num_heads * self.dcp_size
         else:
             num_heads = self.num_heads
-        k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank)
-        k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim)
-        q_nope = q_nope.view(num_tokens, num_heads, -1)
-        q_pe = q_pe.view(num_tokens, num_heads, -1)
         # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        seq_len = decode_meta.cp_seq_len
+        k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank)
+        k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim)
+        actual_seq_lengths = None
+        input_layout = "BNSD"
+        if (
+            attn_metadata.attn_state
+            in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
+            ]
+            and self.speculative_config is not None
+        ):
+            input_layout = "TND"
+            # TODO: If the driver is upgraded later, the contiguous function can be deleted.
+            q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous()
+            q_pe = q_pe.view(num_tokens, num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
+        else:
+            q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous()
+            q_pe = q_pe.view(num_tokens, num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
         common_kwargs = {
-            "return_lse": True,
-            "calc_type": "calc_type_ring",
+            "query_rope": q_pe,
+            "key_rope": k_pe,
+            "num_heads": num_heads,
+            "num_key_value_heads": self.num_kv_heads,
+            "input_layout": input_layout,
+            "atten_mask": spec_attn_mask,
+            "sparse_mode": sparse_mode,
+            "scale": self.scale,
+            "antiquant_mode": 0,
+            "antiquant_scale": None,
+            "block_table": decode_meta.block_table,
+            "block_size": block_size,
+            "actual_seq_lengths": actual_seq_lengths,
+            "actual_seq_lengths_kv": decode_meta.cp_seq_len,
+            "softmax_lse_flag": True,
         }
         forward_context: ForwardContext = get_forward_context()
         if forward_context.is_draft_model:
             graph_params = get_draft_graph_params()
```
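
The two branches reshape the query differently: TND keeps a flat (tokens, heads, dim) tensor, while BNSD inserts a singleton sequence axis so each decode token becomes its own batch entry. A pure-torch sketch with illustrative sizes:

```python
import torch

num_tokens, num_heads, head_dim = 8, 16, 512  # illustrative sizes
q = torch.randn(num_tokens, num_heads * head_dim)

# TND (spec-decode path): flat per-token rows.
q_tnd = q.view(num_tokens, num_heads, -1).contiguous()      # (T, N, D)

# BNSD (plain decode path): singleton sequence axis per token.
q_bnsd = q.view(num_tokens, num_heads, 1, -1).contiguous()  # (B, N, 1, D)

assert q_tnd.shape == (8, 16, 512)
assert q_bnsd.shape == (8, 16, 1, 512)
```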
```diff
@@ -560,72 +601,58 @@ class AscendMlaCPImpl(AscendMLAImpl):
             graph_params.events[num_tokens].append(event)
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
-                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope,
-                    q_pe,
                     k_nope,
-                    k_pe,
-                    decode_meta.block_table,
-                    seq_len,
-                    num_heads,
-                    self.scale,
-                    self.num_kv_heads,
+                    k_nope,
                     **common_kwargs,
                 )
                 update_graph_params_workspaces(num_tokens, workspace)
             attn_output = torch.empty_like(q_nope)
-            softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
+            if input_layout == "BNSD":
+                softmax_lse = torch.empty((num_tokens, num_heads, 1, 1), dtype=torch.float, device=q_nope.device)
+            else:
+                softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=torch.float, device=q_nope.device)
             graph_params.attn_params[num_tokens].append(
                 (
                     weak_ref_tensors(q_nope),
-                    weak_ref_tensors(q_pe),
                     weak_ref_tensors(k_nope),
+                    weak_ref_tensors(q_pe),
                     weak_ref_tensors(k_pe),
-                    decode_meta.block_table,
-                    seq_len,
                     num_heads,
-                    self.scale,
                     self.num_kv_heads,
+                    input_layout,
+                    weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None else None,
+                    sparse_mode,
+                    self.scale,
+                    weak_ref_tensors(decode_meta.block_table),
+                    block_size,
+                    actual_seq_lengths,
+                    decode_meta.cp_seq_len,
                     weak_ref_tensors(attn_output),
                     weak_ref_tensors(softmax_lse),
                 )
             )
             torch.npu.graph_task_group_begin(stream)
-            torch_npu.atb.npu_multi_head_latent_attention(
-                q_nope,
-                q_pe,
-                k_nope,
-                k_pe,
-                decode_meta.block_table,
-                seq_len,
-                num_heads,
-                self.scale,
-                self.num_kv_heads,
-                **common_kwargs,
-                workspace=workspace,
-                output=attn_output,
-                lse=softmax_lse,
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
             )
             handle = torch.npu.graph_task_group_end(stream)
             graph_params.handles[num_tokens].append(handle)
         else:
-            attn_output = torch.empty_like(q_nope)
-            softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
-            torch_npu.atb.npu_multi_head_latent_attention(
+            attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
-                q_pe,
                 k_nope,
-                k_pe,
-                decode_meta.block_table,
-                seq_len,
-                num_heads,
-                self.scale,
-                self.num_kv_heads,
-                return_lse=True,
-                calc_type="calc_type_ring",
-                output=attn_output,
-                lse=softmax_lse,
+                k_nope,
+                **common_kwargs,
             )
+            if input_layout == "BNSD":
+                B_attn, N_attn, S, D = attn_output.shape
+                B_lse, N_lse, Q_S, _ = softmax_lse.shape
+                attn_output = attn_output.permute(0, 2, 1, 3).reshape(B_attn * S, N_attn, D)
+                softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B_lse * Q_S, N_lse, 1)
         # Update out&lse
         attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask)
```
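
In the eager (non-graph) BNSD path, the 4-D operator output is folded back to the flat per-token layout that `_process_attn_out_lse` expects. A pure-torch sketch of the fold with illustrative sizes:

```python
import torch

B, N, S, D = 2, 16, 4, 512  # illustrative BNSD sizes
attn_output = torch.randn(B, N, S, D)
softmax_lse = torch.randn(B, N, S, 1)

# Fold the sequence axis back into the token axis so the downstream
# out/lse merge sees the flat (tokens, heads, dim) layout it expects.
attn_output = attn_output.permute(0, 2, 1, 3).reshape(B * S, N, D)
softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B * S, N, 1)

assert attn_output.shape == (8, 16, 512)
assert softmax_lse.shape == (8, 16, 1)
```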


```diff
@@ -468,45 +468,56 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape
    ):
        (
            q_nope,
-            q_pe,
            k_nope,
+            q_pe,
            k_pe,
-            block_table,
-            seq_len,
            num_heads,
-            scale,
            num_kv_heads,
+            input_layout,
+            spec_attn_mask,
+            sparse_mode,
+            scale,
+            block_table,
+            block_size,
+            actual_seq_lengths,
+            actual_seq_lengths_kv,
            attn_output,
            softmax_lse,
        ) = param
        decode_meta = forward_context.attn_metadata[key].decode
        seq_len = decode_meta.cp_seq_len
-        # For pcp + spec decode, we flatten seq_lens
-        # to avoid irregular attn_mask shape,
-        # so there's no need to divide runtime_shape by spec_multiple
-        pad_length = runtime_shape - len(seq_len)
-        pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
-        seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+        if isinstance(seq_len, torch.Tensor):
+            seq_len = seq_len.tolist()
+        actual_seq_lengths_kv = seq_len
+        pad_length = runtime_shape - len(actual_seq_lengths_kv)
+        if pad_length > 0:
+            actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (runtime_shape - len(actual_seq_lengths_kv))
        torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu.atb.npu_multi_head_latent_attention(
+        torch_npu.npu_fused_infer_attention_score.out(
            q_nope,
-            q_pe,
            k_nope,
-            k_pe,
-            block_table,
-            seq_len,
-            num_heads,
-            scale,
-            num_kv_heads,
-            return_lse=True,
-            calc_type="calc_type_ring",
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=num_heads,
+            num_key_value_heads=num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            softmax_lse_flag=True,
+            block_table=block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            actual_seq_lengths=actual_seq_lengths,
            workspace=graph_params.workspaces.get(runtime_shape),
-            output=attn_output,
-            lse=softmax_lse,
+            out=[attn_output, softmax_lse],
        )
        torch.npu.graph_task_update_end(update_stream)
```
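
For graph replay, the captured kernel has a fixed batch size (`runtime_shape`), so the live `cp_seq_len` list is zero-padded up to it. The padding logic above, as a standalone runnable sketch:

```python
import torch

runtime_shape = 8                # captured graph batch size
seq_len = torch.tensor([5, 3])   # live cp_seq_len for two requests

# Same logic as in the diff: normalize to a list, then zero-pad to the
# fixed size the captured kernel was recorded with.
if isinstance(seq_len, torch.Tensor):
    seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
    actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * pad_length

print(actual_seq_lengths_kv)  # [5, 3, 0, 0, 0, 0, 0, 0]
```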