[CP&SP] Integrate FIA operator in mla_cp._forward_decode (#5641)
### What this PR does / why we need it?
Replace the npu_multi_head_latent_attention with FIA operator in
mla_cp.py _forward_decode.
Adjust mla_attn_dpc_pcp in acl_graph.py
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Signed-off-by: tongyuzhou <t00886357@china.huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tongyuzhou <t00886357@china.huawei.com>
This commit is contained in:
@@ -450,11 +450,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
|
||||
|
||||
@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
|
||||
@patch("torch_npu.atb.npu_multi_head_latent_attention")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
@patch('torch_npu.npu_attention_update')
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
|
||||
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
|
||||
mock_npu_multi_head_latent_attention,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_get_forward_context):
|
||||
self.impl.dcp_size = 2
|
||||
self.impl.pcp_size = 2
|
||||
@@ -470,8 +470,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
|
||||
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
|
||||
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
|
||||
k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)
|
||||
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
|
||||
@@ -485,7 +485,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
mock_npu_attention_update.return_value = (torch.randn(
|
||||
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
|
||||
mock_npu_multi_head_latent_attention.return_value = [
|
||||
mock_npu_fused_infer_attention_score.return_value = [
|
||||
torch.randn(B, N, self.impl.kv_lora_rank),
|
||||
torch.randn(B, N, 1)
|
||||
]
|
||||
|
||||
@@ -754,7 +754,7 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
|
||||
@patch('torch.npu.graph_task_update_end', )
|
||||
@patch('torch.npu.graph_task_update_begin', MagicMock())
|
||||
@patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
|
||||
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
|
||||
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
|
||||
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
block_table = torch.zeros(2, 5, dtype=torch.long)
|
||||
@@ -793,16 +793,20 @@ class TestPCPDCPGraphParams(TestBase):
|
||||
qk_rope_head_dim = 32
|
||||
qk_nope_head_dim = 64
|
||||
query = torch.randn(4, num_heads, qk_head_dim)
|
||||
q_pe = query[..., qk_nope_head_dim:]
|
||||
|
||||
q_nope = query[..., :qk_nope_head_dim]
|
||||
q_pe = query[..., qk_rope_head_dim:]
|
||||
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
|
||||
k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
actual_seq_lengths_kv = [1, 1]
|
||||
out = torch.randn(2, 16, 128)
|
||||
lse = torch.randn(2, 16, 8)
|
||||
self.graph_params.attn_params[4] = []
|
||||
self.graph_params.attn_params[4].append(
|
||||
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
|
||||
scale, num_kv_heads, out, lse))
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
|
||||
out, lse))
|
||||
|
||||
with patch("torch_npu._C._npu_setStream", return_value=None):
|
||||
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
|
||||
|
||||
Reference in New Issue
Block a user