[Perf] vectorize PCP/DCP loops in mla_v1.py (#5003)

### What this PR does / why we need it?
- Replace the nested PCP/DCP Python loops with fully vectorized tensor operations (see the sketch after this list)

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
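
For context, the core of the change is a layout switch rather than a new algorithm: instead of looping over every (PCP, DCP) rank pair and collecting small per-rank chunks in a Python list, the attention output and its log-sum-exp are kept in one packed tensor and combined with plain tensor ops. The sketch below is illustrative only; the function names, shapes, and helpers are assumptions, not the actual mla_v1.py code:

```python
# Minimal sketch of the loop -> vectorized pattern; NOT the mla_v1.py implementation.
# Assumed shapes: attn_out [B, H, D], softmax_lse [B, H, 1] on the local rank.
import torch


def combine_out_lse_looped(attn_out, softmax_lse, pcp_size, dcp_size):
    """Old style: nested Python loops building one small chunk per (PCP, DCP) rank."""
    B, H, D = attn_out.shape
    chunks = []
    for _ in range(pcp_size):
        for d in range(dcp_size):
            h0, h1 = d * (H // dcp_size), (d + 1) * (H // dcp_size)
            chunks.append(
                torch.cat([attn_out[:, h0:h1], softmax_lse[:, h0:h1]], dim=-1))
    # pcp_size * dcp_size tensors, each [B, H // dcp_size, D + 1]
    return chunks


def combine_out_lse_vectorized(attn_out, softmax_lse, pcp_size):
    """New style: one packed tensor, combined with tensor ops only."""
    out_lse = torch.cat([attn_out, softmax_lse], dim=-1)   # [B, H, D + 1]
    # The PCP all_gather returns the concatenation across ranks; emulated here
    # with torch.cat, exactly as the updated unit tests mock it.
    return torch.cat([out_lse] * pcp_size, dim=0)          # [PCP * B, H, D + 1]
```

The updated shape assertions in the tests below ([B * pcp_size, N, kv_lora_rank + 1] and [PCP * S, DCP * H, D + 1]) reflect this packed layout.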

---------

Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
Author: Feng Liu
Date: 2025-12-22 11:06:30 +08:00
Committed by: GitHub
Parent: 49838d4bec
Commit: e117b3d693
2 changed files with 70 additions and 84 deletions


@@ -451,10 +451,10 @@ class TestAscendMLAImpl(TestBase):
self.assertIsNone(decode_res)
self.assertIsNotNone(prefill_res)
@patch("torch.distributed.all_gather")
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("torch.distributed.all_to_all_single")
def test_process_attn_out_lse(self, mock_all_to_all_single,
mock_all_gather):
def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
@@ -468,11 +468,10 @@ class TestAscendMLAImpl(TestBase):
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
def make_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
mock_all_gather.side_effect = mock_all_gather_func
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
@@ -483,11 +482,12 @@ class TestAscendMLAImpl(TestBase):
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_metadata)
self.assertEqual(result[0].shape[0], B)
self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)
self.assertEqual(result.shape[0], B * self.impl.pcp_size)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
@patch("torch.distributed.all_gather")
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_get_forward_context,
mock_all_to_all_single, mock_all_gather):
mock_all_to_all_single, mock_pcp):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
self.impl.num_kv_heads = 1
@@ -534,11 +534,10 @@ class TestAscendMLAImpl(TestBase):
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
def make_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
mock_all_gather.side_effect = mock_all_gather_func
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(
@@ -562,9 +561,6 @@ class TestAscendMLAImpl(TestBase):
def mock_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
mock_dcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
mock_pcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
head_num, kv_head_num, pre_out, prev_lse, qk_scale,
kernel_type, mask_type, input_layout, calc_type,
@@ -624,6 +620,10 @@ class TestAscendMLAImpl(TestBase):
torch.ones(10, 10, dtype=torch.float16), 1)
for test_case in test_cases:
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
mock_dcp.all_gather = MagicMock(
side_effect=mock_all_gather(dcp_size))
mock_pcp.all_gather = MagicMock(
side_effect=mock_all_gather(pcp_size))
assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
nums_context_per_rank = []
for num_all_rank_context in nums_all_rank_context:
@@ -804,11 +804,10 @@ class TestAscendMLAImpl(TestBase):
attn_lse_split_cp[0])
mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
attn_out_lse_list = [
torch.randn(NUM_TOKENS, num_heads, head_dim)
for _ in range(self.impl.pcp_size * self.impl.dcp_size)
]
out = self.impl._npu_attention_update(attn_out_lse_list)
attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
self.impl.dcp_size * num_heads,
head_dim)
out = self.impl._npu_attention_update(attn_out_lse)
self.impl.dcp_size = 1
self.impl.pcp_size = 1
assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
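
A note on the _npu_attention_update change in the hunk above: the function now receives one packed tensor of shape [pcp_size * NUM_TOKENS, dcp_size * num_heads, head_dim] instead of a list of pcp_size * dcp_size per-rank chunks. A standalone sketch of how the old list maps onto that layout; the PCP-major ordering here is an assumption for illustration, not taken from mla_v1.py:

```python
# Illustrative packing of per-rank chunks into one tensor; ordering is assumed.
import torch

pcp_size, dcp_size, num_tokens, num_heads, head_dim = 2, 2, 16, 8, 64

# Old input: a Python list with one chunk per (PCP, DCP) rank.
chunks = [torch.randn(num_tokens, num_heads, head_dim)
          for _ in range(pcp_size * dcp_size)]

# New input: a single tensor [PCP * NUM_TOKENS, DCP * num_heads, head_dim].
packed = (torch.stack(chunks, dim=0)                       # [PCP*DCP, T, H, D]
          .view(pcp_size, dcp_size, num_tokens, num_heads, head_dim)
          .permute(0, 2, 1, 3, 4)                          # [PCP, T, DCP, H, D]
          .reshape(pcp_size * num_tokens, dcp_size * num_heads, head_dim))
assert packed.shape == (pcp_size * num_tokens, dcp_size * num_heads, head_dim)
```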
@@ -908,12 +907,14 @@ class TestAscendMLAImpl(TestBase):
mock_npu_ring_mla.reset_mock()
@patch("torch.distributed.all_to_all_single")
@patch("torch.distributed.all_gather")
def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_gather,
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
mock_all_to_all):
B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim # total: [4, 4, 8]
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
for test_case in test_cases:
print(test_case)
self.impl.dcp_size = test_case[0]
self.impl.pcp_size = test_case[1]
# Inputs
@@ -928,26 +929,17 @@ class TestAscendMLAImpl(TestBase):
mock_all_to_all.side_effect = mock_all_to_all_side_effect
def mock_all_gather_side_effect(tensor_list, tensor, group=None):
for i in range(len(tensor_list)):
tensor_list[i].copy_(tensor)
def mock_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
mock_all_gather.side_effect = mock_all_gather_side_effect
mock_pcp.all_gather = MagicMock(
side_effect=mock_all_gather(self.impl.pcp_size))
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_meta)
self.assertIsInstance(result, list)
if self.impl.dcp_size == 1 and self.impl.pcp_size == 1:
self.assertEqual(len(result), 0)
else:
self.assertEqual(len(result),
self.impl.dcp_size * self.impl.pcp_size) # 4
for tensor in result:
self.assertEqual(tensor.dtype, torch.float32)
self.assertEqual(tensor.shape,
(B, H // self.impl.dcp_size, D + 1))
# [PCP * S, DCP * H, D + 1]
self.assertIsInstance(result, torch.Tensor)
assert result.shape == (B * self.impl.pcp_size, H, D + 1)
self.impl.dcp_size = 1
self.impl.pcp_size = 1
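
Finally, a note on the mocking change that recurs throughout this diff: the tests no longer patch torch.distributed.all_gather (which fills a caller-provided list of tensors in place) but instead patch vllm.distributed.parallel_state._PCP with a MagicMock(spec=GroupCoordinator) whose all_gather returns a single tensor concatenated along a dimension. A small test-double sketch of the convention the mocks emulate; FakeGroup is hypothetical, not vLLM's class:

```python
# Test-double sketch of the GroupCoordinator-style all_gather the mocks emulate.
import torch


class FakeGroup:
    """Hypothetical stand-in for the MagicMock(spec=GroupCoordinator) patched over _PCP."""

    def __init__(self, world_size: int):
        self.world_size = world_size

    def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # Single-process emulation: every "rank" contributes the same tensor,
        # mirroring make_all_gather(ws) in the tests above.
        return torch.cat([tensor] * self.world_size, dim=dim)


x = torch.randn(4, 8)
assert FakeGroup(2).all_gather(x, dim=0).shape == (8, 8)   # world_size * 4 rows
```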